Skip to content

Commit 92f1154

Browse files
author
Nabil Fayak
committed
added in checkmates features needed to remove evalml dependencies in tempo
1 parent 514b3cd commit 92f1154

16 files changed

+2325
-5
lines changed

checkmates/exceptions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@
55
ValidationErrorCode,
66
ObjectiveCreationError,
77
ObjectiveNotFoundError,
8+
MethodPropertyNotFoundError,
9+
ComponentNotYetFittedError,
810
)

checkmates/exceptions/exceptions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ class ObjectiveNotFoundError(Exception):
1313

1414
pass
1515

16+
class MethodPropertyNotFoundError(Exception):
17+
"""Exception to raise when a class is does not have an expected method or property."""
18+
19+
pass
20+
21+
class ComponentNotYetFittedError(Exception):
22+
"""An exception to be raised when predict/predict_proba/transform is called on a component without fitting first."""
23+
24+
pass
1625

1726
class ObjectiveCreationError(Exception):
1827
"""Exception when get_objective tries to instantiate an objective and required args are not provided."""

checkmates/objectives/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
from checkmates.objectives.objective_base import ObjectiveBase
44
from checkmates.objectives.regression_objective import RegressionObjective
55

6-
from checkmates.objectives.utils import get_objective
7-
from checkmates.objectives.utils import get_default_primary_search_objective
8-
from checkmates.objectives.utils import get_non_core_objectives
9-
from checkmates.objectives.utils import get_core_objectives
6+
from checkmates.objectives.utils import (
7+
get_objective,
8+
get_default_primary_search_objective,
9+
get_non_core_objectives,
10+
get_core_objectives,
11+
get_problem_type,
12+
)
1013

1114

1215
from checkmates.objectives.standard_metrics import RootMeanSquaredLogError

checkmates/objectives/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
"""Utility methods for CheckMates objectives."""
2+
import pandas as pd
3+
from typing import Optional
4+
25
from checkmates import objectives
36
from checkmates.exceptions import ObjectiveCreationError, ObjectiveNotFoundError
47
from checkmates.objectives.objective_base import ObjectiveBase
58
from checkmates.problem_types import handle_problem_types
69
from checkmates.utils.gen_utils import _get_subclasses
10+
from checkmates.problem_types import ProblemTypes
11+
12+
from checkmates.utils.logger import get_logger
13+
14+
logger = get_logger(__file__)
715

816

917
def get_non_core_objectives():
@@ -89,6 +97,34 @@ def get_objective(objective, return_instance=False, **kwargs):
8997

9098
return objective_class
9199

100+
def get_problem_type(
101+
input_problem_type: Optional[str],
102+
target_data: pd.Series,
103+
) -> ProblemTypes:
104+
"""helper function to determine if classification problem is binary or multiclass dependent on target variable values."""
105+
if not input_problem_type:
106+
raise ValueError("problem type is required")
107+
if input_problem_type.lower() == "classification":
108+
values: pd.Series = target_data.value_counts()
109+
if values.size == 2:
110+
return ProblemTypes.BINARY
111+
elif values.size > 2:
112+
return ProblemTypes.MULTICLASS
113+
else:
114+
message: str = "The target field contains less than two unique values. It cannot be used for modeling."
115+
logger.error(message, exc_info=True)
116+
raise ValueError(message)
117+
118+
if input_problem_type.lower() == "regression":
119+
return ProblemTypes.REGRESSION
120+
121+
if input_problem_type.lower() == "time series regression":
122+
return ProblemTypes.TIME_SERIES_REGRESSION
123+
124+
message = f"Unexpected problem type provided in configuration: {input_problem_type}"
125+
logger.error(message, exc_info=True)
126+
raise ValueError(message)
127+
92128

93129
def get_default_primary_search_objective(problem_type):
94130
"""Get the default primary search objective for a problem type.

checkmates/pipelines/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from checkmates.pipelines.component_base_meta import ComponentBaseMeta
2+
from checkmates.pipelines.component_base import ComponentBase
3+
from checkmates.pipelines.transformers import Transformer
4+
from checkmates.pipelines.components import ( # noqa: F401
5+
DropColumns,
6+
DropRowsTransformer,
7+
PerColumnImputer,
8+
TargetImputer,
9+
TimeSeriesImputer,
10+
TimeSeriesRegularizer,
11+
)
12+
from checkmates.pipelines.utils import _make_component_list_from_actions, split_data, drop_infinity
13+
from checkmates.pipelines.training_validation_split import TrainingValidationSplit
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
"""Base class for all components."""
2+
import copy
3+
from abc import ABC, abstractmethod
4+
5+
import cloudpickle
6+
7+
from checkmates.exceptions import MethodPropertyNotFoundError
8+
from checkmates.pipelines.component_base_meta import ComponentBaseMeta
9+
from checkmates.utils import (
10+
_downcast_nullable_X,
11+
_downcast_nullable_y,
12+
classproperty,
13+
infer_feature_types,
14+
log_subtitle,
15+
safe_repr,
16+
)
17+
from checkmates.utils.logger import get_logger
18+
19+
20+
class ComponentBase(ABC, metaclass=ComponentBaseMeta):
21+
"""Base class for all components.
22+
23+
Args:
24+
parameters (dict): Dictionary of parameters for the component. Defaults to None.
25+
component_obj (obj): Third-party objects useful in component implementation. Defaults to None.
26+
random_seed (int): Seed for the random number generator. Defaults to 0.
27+
"""
28+
29+
_default_parameters = None
30+
_can_be_used_for_fast_partial_dependence = True
31+
# Referring to the pandas nullable dtypes; not just woodwork logical types
32+
_integer_nullable_incompatibilities = []
33+
_boolean_nullable_incompatibilities = []
34+
is_multiseries = False
35+
36+
def __init__(self, parameters=None, component_obj=None, random_seed=0, **kwargs):
37+
"""Base class for all components.
38+
39+
Args:
40+
parameters (dict): Dictionary of parameters for the component. Defaults to None.
41+
component_obj (obj): Third-party objects useful in component implementation. Defaults to None.
42+
random_seed (int): Seed for the random number generator. Defaults to 0.
43+
kwargs (Any): Any keyword arguments to pass into the component.
44+
"""
45+
self.random_seed = random_seed
46+
self._component_obj = component_obj
47+
self._parameters = parameters or {}
48+
self._is_fitted = False
49+
50+
@property
51+
@classmethod
52+
@abstractmethod
53+
def name(cls):
54+
"""Returns string name of this component."""
55+
56+
@property
57+
@classmethod
58+
@abstractmethod
59+
def modifies_features(cls):
60+
"""Returns whether this component modifies (subsets or transforms) the features variable during transform.
61+
62+
For Estimator objects, this attribute determines if the return
63+
value from `predict` or `predict_proba` should be used as
64+
features or targets.
65+
"""
66+
67+
@property
68+
@classmethod
69+
@abstractmethod
70+
def modifies_target(cls):
71+
"""Returns whether this component modifies (subsets or transforms) the target variable during transform.
72+
73+
For Estimator objects, this attribute determines if the return
74+
value from `predict` or `predict_proba` should be used as
75+
features or targets.
76+
"""
77+
78+
@property
79+
@classmethod
80+
@abstractmethod
81+
def training_only(cls):
82+
"""Returns whether or not this component should be evaluated during training-time only, or during both training and prediction time."""
83+
84+
@classproperty
85+
def needs_fitting(self):
86+
"""Returns boolean determining if component needs fitting before calling predict, predict_proba, transform, or feature_importances.
87+
88+
This can be overridden to False for components that do not need to be fit or whose fit methods do nothing.
89+
90+
Returns:
91+
True.
92+
"""
93+
return True
94+
95+
@property
96+
def parameters(self):
97+
"""Returns the parameters which were used to initialize the component."""
98+
return copy.copy(self._parameters)
99+
100+
@classproperty
101+
def default_parameters(cls):
102+
"""Returns the default parameters for this component.
103+
104+
Our convention is that Component.default_parameters == Component().parameters.
105+
106+
Returns:
107+
dict: Default parameters for this component.
108+
"""
109+
if cls._default_parameters is None:
110+
cls._default_parameters = cls().parameters
111+
112+
return cls._default_parameters
113+
114+
@classproperty
115+
def _supported_by_list_API(cls):
116+
return not cls.modifies_target
117+
118+
def _handle_partial_dependence_fast_mode(
119+
self,
120+
pipeline_parameters,
121+
X=None,
122+
target=None,
123+
):
124+
"""Determines whether or not a component can be used with partial dependence's fast mode.
125+
126+
Args:
127+
pipeline_parameters (dict): Pipeline parameters that will be used to create the pipelines
128+
used in partial dependence fast mode.
129+
X (pd.DataFrame, optional): Holdout data being used for partial dependence calculations.
130+
target (str, optional): The target whose values we are trying to predict.
131+
"""
132+
if self._can_be_used_for_fast_partial_dependence:
133+
return pipeline_parameters
134+
135+
raise TypeError(
136+
f"Component {self.name} cannot run partial dependence fast mode.",
137+
)
138+
139+
def clone(self):
140+
"""Constructs a new component with the same parameters and random state.
141+
142+
Returns:
143+
A new instance of this component with identical parameters and random state.
144+
"""
145+
return self.__class__(**self.parameters, random_seed=self.random_seed)
146+
147+
def fit(self, X, y=None):
148+
"""Fits component to data.
149+
150+
Args:
151+
X (pd.DataFrame): The input training data of shape [n_samples, n_features]
152+
y (pd.Series, optional): The target training data of length [n_samples]
153+
154+
Returns:
155+
self
156+
157+
Raises:
158+
MethodPropertyNotFoundError: If component does not have a fit method or a component_obj that implements fit.
159+
"""
160+
X = infer_feature_types(X)
161+
if y is not None:
162+
y = infer_feature_types(y)
163+
try:
164+
self._component_obj.fit(X, y)
165+
return self
166+
except AttributeError:
167+
raise MethodPropertyNotFoundError(
168+
"Component requires a fit method or a component_obj that implements fit",
169+
)
170+
171+
def describe(self, print_name=False, return_dict=False):
172+
"""Describe a component and its parameters.
173+
174+
Args:
175+
print_name(bool, optional): whether to print name of component
176+
return_dict(bool, optional): whether to return description as dictionary in the format {"name": name, "parameters": parameters}
177+
178+
Returns:
179+
None or dict: Returns dictionary if return_dict is True, else None.
180+
"""
181+
logger = get_logger(f"{__name__}.describe")
182+
if print_name:
183+
title = self.name
184+
log_subtitle(logger, title)
185+
for parameter in self.parameters:
186+
parameter_str = ("\t * {} : {}").format(
187+
parameter,
188+
self.parameters[parameter],
189+
)
190+
logger.info(parameter_str)
191+
if return_dict:
192+
component_dict = {"name": self.name}
193+
component_dict.update({"parameters": self.parameters})
194+
return component_dict
195+
196+
def save(self, file_path, pickle_protocol=cloudpickle.DEFAULT_PROTOCOL):
197+
"""Saves component at file path.
198+
199+
Args:
200+
file_path (str): Location to save file.
201+
pickle_protocol (int): The pickle data stream format.
202+
"""
203+
with open(file_path, "wb") as f:
204+
cloudpickle.dump(self, f, protocol=pickle_protocol)
205+
206+
@staticmethod
207+
def load(file_path):
208+
"""Loads component at file path.
209+
210+
Args:
211+
file_path (str): Location to load file.
212+
213+
Returns:
214+
ComponentBase object
215+
"""
216+
with open(file_path, "rb") as f:
217+
return cloudpickle.load(f)
218+
219+
def __eq__(self, other):
220+
"""Check for equality."""
221+
if not isinstance(other, self.__class__):
222+
return False
223+
random_seed_eq = self.random_seed == other.random_seed
224+
if not random_seed_eq:
225+
return False
226+
attributes_to_check = ["_parameters", "_is_fitted"]
227+
for attribute in attributes_to_check:
228+
if getattr(self, attribute) != getattr(other, attribute):
229+
return False
230+
return True
231+
232+
def __str__(self):
233+
"""String representation of a component."""
234+
return self.name
235+
236+
def __repr__(self):
237+
"""String representation of a component."""
238+
parameters_repr = ", ".join(
239+
[f"{key}={safe_repr(value)}" for key, value in self.parameters.items()],
240+
)
241+
return f"{(type(self).__name__)}({parameters_repr})"
242+
243+
def update_parameters(self, update_dict, reset_fit=True):
244+
"""Updates the parameter dictionary of the component.
245+
246+
Args:
247+
update_dict (dict): A dict of parameters to update.
248+
reset_fit (bool, optional): If True, will set `_is_fitted` to False.
249+
"""
250+
self._parameters.update(update_dict)
251+
if reset_fit:
252+
self._is_fitted = False
253+
254+
def _handle_nullable_types(self, X=None, y=None):
255+
"""Transforms X and y to remove any incompatible nullable types according to a component's needs.
256+
257+
Args:
258+
X (pd.DataFrame, optional): Input data to a component of shape [n_samples, n_features].
259+
May contain nullable types.
260+
y (pd.Series, optional): The target of length [n_samples]. May contain nullable types.
261+
262+
Returns:
263+
X, y with any incompatible nullable types downcasted to compatible equivalents.
264+
"""
265+
X_bool_incompatible = "X" in self._boolean_nullable_incompatibilities
266+
X_int_incompatible = "X" in self._integer_nullable_incompatibilities
267+
if X is not None and (X_bool_incompatible or X_int_incompatible):
268+
X = _downcast_nullable_X(
269+
X,
270+
handle_boolean_nullable=X_bool_incompatible,
271+
handle_integer_nullable=X_int_incompatible,
272+
)
273+
274+
y_bool_incompatible = "y" in self._boolean_nullable_incompatibilities
275+
y_int_incompatible = "y" in self._integer_nullable_incompatibilities
276+
if y is not None and (y_bool_incompatible or y_int_incompatible):
277+
y = _downcast_nullable_y(
278+
y,
279+
handle_boolean_nullable=y_bool_incompatible,
280+
handle_integer_nullable=y_int_incompatible,
281+
)
282+
283+
return X, y

0 commit comments

Comments
 (0)