Open
Labels: ModelBuilder, maintenance, model components
Description
#1870 should be merged before opening a PR for this issue.
MMMs are the only models that use `ModelBuilder.fit` directly. Other models have to override it with their own fitters and samplers because of different data-loading requirements, which has led to redundant code for `idata` extensions and prior/predictive checks. The CLVModels also provide fitters other than NUTS and a `fit_summary` wrapper for `arviz.summary`. Both could easily be moved into `ModelBuilder` so that every model shares them.
Let's consolidate all shared boilerplate into `ModelBuilder.fit` and pare down the various `model.fit(**specific_data_kwargs)` methods to the bare minimum needed to call `super().fit(**kwargs)`. We can start with this code created by @williambdean in #1467:
Model fit starter code
```python
import warnings
from typing import Any

import arviz as az
import pandas as pd
import pymc as pm
import xarray as xr
from pymc.util import RandomState

# Helper that merges sampler_config with explicit arguments.
from pymc_marketing.model_builder import create_sample_kwargs


# Methods intended for the ModelBuilder base class.
def create_fit_data(self) -> xr.Dataset:
    """Create the fit_data group based on the input data."""
    if isinstance(self.data, pd.DataFrame):
        return self.data.to_xarray()
    return self.data


def fit(
    self,
    progressbar: bool | None = None,
    random_seed: RandomState | None = None,
    **kwargs: Any,
) -> az.InferenceData:
    """Fit a model using the data provided at initialization.

    Sets attrs to inference data of the model.

    Parameters
    ----------
    progressbar : bool, optional
        Specifies whether the fit progress bar should be displayed. Defaults to True.
    random_seed : RandomState, optional
        Provides sampler with initial random seed for obtaining reproducible samples.
    **kwargs : Any
        Custom sampler settings can be provided in form of keyword arguments.

    Returns
    -------
    idata : az.InferenceData
        Inference data of the fitted model.

    Examples
    --------
    >>> model = MyModel()
    >>> idata = model.fit()
    Auto-assigning NUTS sampler...
    Initializing NUTS using jitter+adapt_diag...
    """
    # Build the PyMC model lazily if it has not been built yet.
    if not hasattr(self, "model"):
        self.build_model()

    # Combine sampler_config with the explicit arguments and any overrides.
    sampler_kwargs = create_sample_kwargs(
        self.sampler_config,
        progressbar,
        random_seed,
        **kwargs,
    )
    with self.model:
        idata = pm.sample(**sampler_kwargs)

    # Extend existing inference data (e.g. prior samples) instead of overwriting it.
    if self.idata:
        self.idata = self.idata.copy()
        self.idata.extend(idata, join="right")
    else:
        self.idata = idata

    # Replace any stale fit_data group with the current training data.
    if "fit_data" in self.idata:
        del self.idata.fit_data

    fit_data = self.create_fit_data()
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            message="The group fit_data is not defined in the InferenceData scheme",
        )
        self.idata.add_groups(fit_data=fit_data)
    self.set_idata_attrs(self.idata)
    return self.idata  # type: ignore
```