@@ -50,7 +50,6 @@ class ModelBuilder:
50
50
51
51
def __init__ (
52
52
self ,
53
- data : Union [np .ndarray , pd .DataFrame , pd .Series ] = None ,
54
53
model_config : Dict = None ,
55
54
sampler_config : Dict = None ,
56
55
):
@@ -77,10 +76,8 @@ def __init__(
77
76
78
77
self .model_config = model_config # parameters for priors etc.
79
78
self .model = None # Set by build_model
80
- self .output_var = "" # Set by build_model
81
79
self .idata : Optional [az .InferenceData ] = None # idata is generated during fitting
82
80
self .is_fitted_ = False
83
- self .data = data
84
81
85
82
def _validate_data (self , X , y = None ):
86
83
if y is not None :
@@ -122,6 +119,19 @@ def _data_setter(
122
119
123
120
raise NotImplementedError
124
121
122
+ @property
123
+ @abstractmethod
124
+ def output_var (self ):
125
+ """
126
+ Returns the name of the output variable of the model.
127
+
128
+ Returns
129
+ -------
130
+ output_var : str
131
+ Name of the output variable of the model.
132
+ """
133
+ raise NotImplementedError
134
+
125
135
@property
126
136
@abstractmethod
127
137
def default_model_config (self ) -> Dict :
@@ -176,39 +186,41 @@ def default_sampler_config(self) -> Dict:
176
186
raise NotImplementedError
177
187
178
188
@abstractmethod
179
- def generate_model_data (
180
- self , data : Union [np . ndarray , pd .DataFrame , pd .Series ] = None
181
- ) -> pd . DataFrame :
189
+ def generate_and_preprocess_model_data (
190
+ self , X : Union [pd .DataFrame , pd .Series ], y : pd . Series
191
+ ) -> None :
182
192
"""
183
- Returns a default dataset for a class, can be used as a hint to data formatting required for the class
184
- If data is not None, dataset will be created from it's content.
193
+ Applies preprocessing to the data before fitting the model.
194
+ if validate is True, it will check if the data is valid for the model.
195
+ sets self.model_coords based on provided dataset
185
196
186
197
Parameters:
187
- data : Union[np.ndarray, pd.DataFrame, pd.Series], optional
188
- dataset that will replace the default sample data
189
-
198
+ X : array, shape (n_obs, n_features)
199
+ y : array, shape (n_obs,)
190
200
191
201
Examples
192
202
--------
193
203
>>> @classmethod
194
- >>> def generate_model_data (self):
204
+ >>> def generate_and_preprocess_model_data (self, X, y ):
195
205
>>> x = np.linspace(start=1, stop=50, num=100)
196
206
>>> y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4
197
- >>> data = pd.DataFrame({'input': x, 'output': y})
207
+ >>> X = pd.DataFrame(x, columns=['x'])
208
+ >>> y = pd.Series(y, name='y')
209
+ >>> self.X = X
210
+ >>> self.y = y
198
211
199
212
Returns
200
213
-------
201
- data : pd.DataFrame
202
- The data we want to train the model on.
214
+ None
203
215
204
216
"""
205
217
raise NotImplementedError
206
218
207
219
@abstractmethod
208
220
def build_model (
209
221
self ,
210
- data : Union [ np . ndarray , pd .DataFrame , pd . Series ] = None ,
211
- model_config : Dict = None ,
222
+ X : pd .DataFrame ,
223
+ y : pd . Series ,
212
224
** kwargs ,
213
225
) -> None :
214
226
"""
@@ -217,22 +229,31 @@ def build_model(
217
229
218
230
Parameters
219
231
----------
220
- data : dict
221
- Preformated data that is going to be used in the model. For efficiency reasons it should contain only the necesary data columns,
222
- not entire available dataset since it's going to be encoded into data used to recreate the model.
223
- If not provided uses data from self.data
224
- model_config : dict
225
- Dictionary where keys are strings representing names of parameters of the model, values are dictionaries of parameters
226
- needed for creating model parameters. If not provided uses data from self.model_config
232
+ X : pd.DataFrame
233
+ The input data that is going to be used in the model. This should be a DataFrame
234
+ containing the features (predictors) for the model. For efficiency reasons, it should
235
+ only contain the necessary data columns, not the entire available dataset, as this
236
+ will be encoded into the data used to recreate the model.
237
+
238
+ y : pd.Series
239
+ The target data for the model. This should be a Series representing the output
240
+ or dependent variable for the model.
241
+
242
+ kwargs : dict
243
+ Additional keyword arguments that may be used for model configuration.
227
244
228
245
See Also
229
246
--------
230
247
default_model_config : returns default model config
231
248
232
- Returns:
233
- ----------
249
+ Returns
250
+ -------
234
251
None
235
252
253
+ Raises
254
+ ------
255
+ NotImplementedError
256
+ This is an abstract method and must be implemented in a subclass.
236
257
"""
237
258
raise NotImplementedError
238
259
@@ -248,7 +269,7 @@ def sample_model(self, **kwargs):
248
269
Returns
249
270
-------
250
271
xarray.Dataset
251
- The PyMC3 samples dataset.
272
+ The PyMC samples dataset.
252
273
253
274
Raises
254
275
------
@@ -383,12 +404,14 @@ def load(cls, fname: str):
383
404
filepath = Path (str (fname ))
384
405
idata = az .from_netcdf (filepath )
385
406
model = cls (
386
- data = idata .fit_data .to_dataframe (),
387
407
model_config = json .loads (idata .attrs ["model_config" ]),
388
408
sampler_config = json .loads (idata .attrs ["sampler_config" ]),
389
409
)
390
410
model .idata = idata
391
- model .build_model ()
411
+ dataset = idata .fit_data .to_dataframe ()
412
+ X = dataset .drop (columns = [model .output_var ])
413
+ y = dataset [model .output_var ]
414
+ model .build_model (X , y )
392
415
# All previously used data is in idata.
393
416
394
417
if model .id != idata .attrs ["id" ]:
@@ -400,8 +423,8 @@ def load(cls, fname: str):
400
423
401
424
def fit (
402
425
self ,
403
- X : Union [ np . ndarray , pd .DataFrame , pd . Series ] ,
404
- y : Union [ np . ndarray , pd .Series ] ,
426
+ X : pd .DataFrame ,
427
+ y : pd .Series ,
405
428
progressbar : bool = True ,
406
429
predictor_names : List [str ] = None ,
407
430
random_seed : RandomState = None ,
@@ -442,25 +465,19 @@ def fit(
442
465
if predictor_names is None :
443
466
predictor_names = []
444
467
445
- X , y = X , y
446
-
447
- self .build_model (data = self .data )
448
- self ._data_setter (X , y )
468
+ y = pd .DataFrame ({self .output_var : y })
469
+ self .generate_and_preprocess_model_data (X , y .values .flatten ())
470
+ self .build_model (self .X , self .y )
449
471
450
472
sampler_config = self .sampler_config .copy ()
451
473
sampler_config ["progressbar" ] = progressbar
452
474
sampler_config ["random_seed" ] = random_seed
453
475
sampler_config .update (** kwargs )
454
-
455
476
self .idata = self .sample_model (** sampler_config )
456
- if type (X ) is np .ndarray :
457
- if len (predictor_names ) > 0 :
458
- X = pd .DataFrame (X , columns = predictor_names )
459
- else :
460
- X = pd .DataFrame (X , columns = [f"predictor{ x } " for x in range (1 , X .shape [1 ] + 1 )])
461
- if type (y ) is np .ndarray :
462
- y = pd .Series (y , name = "target" )
463
- combined_data = pd .concat ([X , y ], axis = 1 )
477
+
478
+ X_df = pd .DataFrame (X , columns = X .columns )
479
+ combined_data = pd .concat ([X_df , y ], axis = 1 )
480
+ assert all (combined_data .columns ), "All columns must have non-empty names"
464
481
self .idata .add_groups (fit_data = combined_data .to_xarray ()) # type: ignore
465
482
return self .idata # type: ignore
466
483
@@ -513,6 +530,7 @@ def predict(
513
530
def sample_prior_predictive (
514
531
self ,
515
532
X_pred ,
533
+ y_pred = None ,
516
534
samples : Optional [int ] = None ,
517
535
extend_idata : bool = False ,
518
536
combined : bool = True ,
@@ -539,13 +557,15 @@ def sample_prior_predictive(
539
557
prior_predictive_samples : DataArray, shape (n_pred, samples)
540
558
Prior predictive samples for each input X_pred
541
559
"""
560
+ if y_pred is None :
561
+ y_pred = np .zeros (len (X_pred ))
542
562
if samples is None :
543
563
samples = self .sampler_config .get ("draws" , 500 )
544
564
545
565
if self .model is None :
546
- self .build_model ()
566
+ self .build_model (X_pred , y_pred )
547
567
548
- self ._data_setter (X_pred )
568
+ self ._data_setter (X_pred , y_pred )
549
569
if self .model is not None :
550
570
with self .model : # sample with new input data
551
571
prior_pred : az .InferenceData = pm .sample_prior_predictive (samples , ** kwargs )
0 commit comments