1
+ """Base class for all components."""
2
+ import copy
3
+ from abc import ABC , abstractmethod
4
+
5
+ import cloudpickle
6
+
7
+ from checkmates .exceptions import MethodPropertyNotFoundError
8
+ from checkmates .pipelines .component_base_meta import ComponentBaseMeta
9
+ from checkmates .utils import (
10
+ _downcast_nullable_X ,
11
+ _downcast_nullable_y ,
12
+ classproperty ,
13
+ infer_feature_types ,
14
+ log_subtitle ,
15
+ safe_repr ,
16
+ )
17
+ from checkmates .utils .logger import get_logger
18
+
19
+
20
class ComponentBase(ABC, metaclass=ComponentBaseMeta):
    """Base class for all components.

    Subclasses must provide ``name``, ``modifies_features``,
    ``modifies_target`` and ``training_only`` (declared abstract below).

    Args:
        parameters (dict): Dictionary of parameters for the component. Defaults to None.
        component_obj (obj): Third-party objects useful in component implementation. Defaults to None.
        random_seed (int): Seed for the random number generator. Defaults to 0.
    """

    # Per-class cache filled lazily by ``default_parameters``.
    _default_parameters = None
    # When False, ``_handle_partial_dependence_fast_mode`` raises TypeError.
    _can_be_used_for_fast_partial_dependence = True
    # Referring to the pandas nullable dtypes; not just woodwork logical types.
    # ``_handle_nullable_types`` checks these lists for the entries "X" and "y"
    # to decide which inputs need downcasting.
    _integer_nullable_incompatibilities = []
    _boolean_nullable_incompatibilities = []
    is_multiseries = False

    def __init__(self, parameters=None, component_obj=None, random_seed=0, **kwargs):
        """Base class for all components.

        Args:
            parameters (dict): Dictionary of parameters for the component. Defaults to None.
            component_obj (obj): Third-party objects useful in component implementation. Defaults to None.
            random_seed (int): Seed for the random number generator. Defaults to 0.
            kwargs (Any): Any keyword arguments to pass into the component.
                Accepted for signature compatibility; this base initializer does not use them.
        """
        self.random_seed = random_seed
        self._component_obj = component_obj
        # A falsy value (None or an empty dict) is replaced by a fresh dict.
        self._parameters = parameters or {}
        self._is_fitted = False

    @property
    @classmethod
    @abstractmethod
    def name(cls):
        """Returns string name of this component."""

    @property
    @classmethod
    @abstractmethod
    def modifies_features(cls):
        """Returns whether this component modifies (subsets or transforms) the features variable during transform.

        For Estimator objects, this attribute determines if the return
        value from `predict` or `predict_proba` should be used as
        features or targets.
        """

    @property
    @classmethod
    @abstractmethod
    def modifies_target(cls):
        """Returns whether this component modifies (subsets or transforms) the target variable during transform.

        For Estimator objects, this attribute determines if the return
        value from `predict` or `predict_proba` should be used as
        features or targets.
        """

    @property
    @classmethod
    @abstractmethod
    def training_only(cls):
        """Returns whether or not this component should be evaluated during training-time only, or during both training and prediction time."""

    @classproperty
    def needs_fitting(cls):
        """Returns boolean determining if component needs fitting before calling predict, predict_proba, transform, or feature_importances.

        This can be overridden to False for components that do not need to be fit or whose fit methods do nothing.

        Returns:
            True.
        """
        return True

    @property
    def parameters(self):
        """Returns the parameters which were used to initialize the component.

        Returns:
            dict: A shallow copy of the component's parameter dictionary, so
            adding or removing keys on the result does not affect the component.
        """
        return copy.copy(self._parameters)

    @classproperty
    def default_parameters(cls):
        """Returns the default parameters for this component.

        Our convention is that Component.default_parameters == Component().parameters.

        The value is computed once per class by instantiating the class with
        no arguments, then cached on that class.

        Returns:
            dict: Default parameters for this component.
        """
        # Look only at this class's own namespace. A plain attribute lookup
        # would also find a value cached on a parent class and wrongly report
        # the parent's defaults for a subclass with a different __init__.
        own_defaults = vars(cls).get("_default_parameters")
        if own_defaults is None:
            own_defaults = cls().parameters
            cls._default_parameters = own_defaults

        return own_defaults

    @classproperty
    def _supported_by_list_API(cls):
        """Returns True when the component does not modify the target."""
        return not cls.modifies_target

    def _handle_partial_dependence_fast_mode(
        self,
        pipeline_parameters,
        X=None,
        target=None,
    ):
        """Determines whether or not a component can be used with partial dependence's fast mode.

        Args:
            pipeline_parameters (dict): Pipeline parameters that will be used to create the pipelines
                used in partial dependence fast mode.
            X (pd.DataFrame, optional): Holdout data being used for partial dependence calculations.
            target (str, optional): The target whose values we are trying to predict.

        Returns:
            dict: ``pipeline_parameters``, unchanged, when fast mode is supported.

        Raises:
            TypeError: If the component is flagged as incompatible with fast mode.
        """
        if self._can_be_used_for_fast_partial_dependence:
            return pipeline_parameters

        raise TypeError(
            f"Component {self.name} cannot run partial dependence fast mode.",
        )

    def clone(self):
        """Constructs a new component with the same parameters and random state.

        The parameters are forwarded to the class constructor as keyword
        arguments, together with ``random_seed``.

        Returns:
            A new instance of this component with identical parameters and random state.
        """
        return self.__class__(**self.parameters, random_seed=self.random_seed)

    def fit(self, X, y=None):
        """Fits component to data.

        Args:
            X (pd.DataFrame): The input training data of shape [n_samples, n_features]
            y (pd.Series, optional): The target training data of length [n_samples]

        Returns:
            self

        Raises:
            MethodPropertyNotFoundError: If component does not have a fit method or a component_obj that implements fit.
        """
        X = infer_feature_types(X)
        if y is not None:
            y = infer_feature_types(y)

        # Resolve the method before calling it so that only a *missing* fit is
        # reported as MethodPropertyNotFoundError. An AttributeError raised
        # inside the wrapped object's own fit now propagates unchanged instead
        # of being misreported as "no fit method".
        try:
            fit_method = self._component_obj.fit
        except AttributeError as err:
            raise MethodPropertyNotFoundError(
                "Component requires a fit method or a component_obj that implements fit",
            ) from err

        fit_method(X, y)
        return self

    def describe(self, print_name=False, return_dict=False):
        """Describe a component and its parameters.

        Each parameter is logged at INFO level on the ``<module>.describe`` logger.

        Args:
            print_name(bool, optional): whether to print name of component
            return_dict(bool, optional): whether to return description as dictionary in the format {"name": name, "parameters": parameters}

        Returns:
            None or dict: Returns dictionary if return_dict is True, else None.
        """
        logger = get_logger(f"{__name__}.describe")
        if print_name:
            log_subtitle(logger, self.name)

        # ``parameters`` builds a new copy on every access; take one snapshot.
        params = self.parameters
        for parameter, value in params.items():
            logger.info("\t * {} : {}".format(parameter, value))

        if return_dict:
            return {"name": self.name, "parameters": params}

    def save(self, file_path, pickle_protocol=cloudpickle.DEFAULT_PROTOCOL):
        """Saves component at file path.

        Args:
            file_path (str): Location to save file.
            pickle_protocol (int): The pickle data stream format.
        """
        with open(file_path, "wb") as f:
            cloudpickle.dump(self, f, protocol=pickle_protocol)

    @staticmethod
    def load(file_path):
        """Loads component at file path.

        Args:
            file_path (str): Location to load file.

        Returns:
            ComponentBase object
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files that
        # come from a trusted source.
        with open(file_path, "rb") as f:
            return cloudpickle.load(f)

    def __eq__(self, other):
        """Check for equality.

        Two components are equal when ``other`` is an instance of this
        component's class and both share the same random seed, parameters and
        fitted flag.
        """
        if not isinstance(other, self.__class__):
            return False
        random_seed_eq = self.random_seed == other.random_seed
        if not random_seed_eq:
            return False
        attributes_to_check = ["_parameters", "_is_fitted"]
        for attribute in attributes_to_check:
            if getattr(self, attribute) != getattr(other, attribute):
                return False
        return True

    def __str__(self):
        """String representation of a component."""
        return self.name

    def __repr__(self):
        """String representation of a component, in the form ``ClassName(key=value, ...)``."""
        parameters_repr = ", ".join(
            [f"{key}={safe_repr(value)}" for key, value in self.parameters.items()],
        )
        return f"{(type(self).__name__)}({parameters_repr})"

    def update_parameters(self, update_dict, reset_fit=True):
        """Updates the parameter dictionary of the component.

        Args:
            update_dict (dict): A dict of parameters to update.
            reset_fit (bool, optional): If True, will set `_is_fitted` to False.
        """
        self._parameters.update(update_dict)
        if reset_fit:
            self._is_fitted = False

    def _handle_nullable_types(self, X=None, y=None):
        """Transforms X and y to remove any incompatible nullable types according to a component's needs.

        Args:
            X (pd.DataFrame, optional): Input data to a component of shape [n_samples, n_features].
                May contain nullable types.
            y (pd.Series, optional): The target of length [n_samples]. May contain nullable types.

        Returns:
            X, y with any incompatible nullable types downcasted to compatible equivalents.
        """
        X_bool_incompatible = "X" in self._boolean_nullable_incompatibilities
        X_int_incompatible = "X" in self._integer_nullable_incompatibilities
        if X is not None and (X_bool_incompatible or X_int_incompatible):
            X = _downcast_nullable_X(
                X,
                handle_boolean_nullable=X_bool_incompatible,
                handle_integer_nullable=X_int_incompatible,
            )

        y_bool_incompatible = "y" in self._boolean_nullable_incompatibilities
        y_int_incompatible = "y" in self._integer_nullable_incompatibilities
        if y is not None and (y_bool_incompatible or y_int_incompatible):
            y = _downcast_nullable_y(
                y,
                handle_boolean_nullable=y_bool_incompatible,
                handle_integer_nullable=y_int_incompatible,
            )

        return X, y
0 commit comments