Consolidate model fitters into ModelBuilder #1871

@ColtAllen

#1870 should be merged before opening a PR for this issue.

MMMs are the only models that use ModelBuilder.fit directly. The other models override it with their own fitters and samplers because their data-loading requirements differ, which has led to redundant code for idata extensions and prior/predictive checks. The CLV models also provide fitters other than NUTS, plus a fit_summary wrapper around arviz.summary; both could easily be moved into ModelBuilder as shared functionality.

Let's consolidate all shared boilerplate into ModelBuilder.fit, and pare down the various model.fit(**specific_data_kwargs) methods to the bare minimum needed to call super().fit(**kwargs). We can start with this code created by @williambdean in #1467:

Model fit starter code
# Imports the methods below rely on. The pymc_marketing and pymc.util
# locations are assumed from the current codebase and may need adjusting.
import warnings
from typing import Any

import arviz as az
import pandas as pd
import pymc as pm
import xarray as xr
from pymc.util import RandomState

from pymc_marketing.model_builder import create_sample_kwargs

def create_fit_data(self) -> xr.Dataset:
	"""Create the fit_data group based on the input data."""
	if isinstance(self.data, pd.DataFrame):
		return self.data.to_xarray()
	return self.data

def fit(
	self,
	progressbar: bool | None = None,
	random_seed: RandomState | None = None,
	**kwargs: Any,
) -> az.InferenceData:
	"""Fit a model using the data provided at initialization.

	Sets attrs on the inference data of the model.

	Parameters
	----------
	progressbar : bool, optional
		Whether to display the fit progress bar. Defaults to True.
	random_seed : RandomState, optional
		Initial random seed passed to the sampler for reproducible samples.
	**kwargs : Any
		Custom sampler settings, passed as keyword arguments.

	Returns
	-------
	az.InferenceData
		Inference data of the fitted model.

	Examples
	--------
	>>> model = MyModel()
	>>> idata = model.fit()
	Auto-assigning NUTS sampler...
	Initializing NUTS using jitter+adapt_diag...
	"""
	if not hasattr(self, "model"):
		self.build_model()
		
	sampler_kwargs = create_sample_kwargs(
		self.sampler_config,
		progressbar,
		random_seed,
		**kwargs,
	)
	with self.model:
		idata = pm.sample(**sampler_kwargs)
	
	if self.idata:
		self.idata = self.idata.copy()
		self.idata.extend(idata, join="right")
	else:
		self.idata = idata
	
	if "fit_data" in self.idata:
		del self.idata.fit_data
	
	fit_data = self.create_fit_data()

	with warnings.catch_warnings():
		warnings.filterwarnings(
			"ignore",
			category=UserWarning,
			message="The group fit_data is not defined in the InferenceData scheme",
		)
		self.idata.add_groups(fit_data=fit_data)
	self.set_idata_attrs(self.idata)
	return self.idata # type: ignore
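
To illustrate the target shape, here is a minimal, dependency-free sketch of how a child model's fit could shrink to a thin wrapper once the shared boilerplate lives in the base class. Everything in it is a hypothetical stand-in rather than the actual pymc-marketing API: BaseBuilder, ToyCLVModel, the _fitters helper, and the method argument are invented for illustration, and the real fitters (pm.sample, pm.find_MAP, and so on) are replaced by stubs that return plain dicts.

```python
from typing import Any, Callable, Dict, List, Optional


class BaseBuilder:
    """Hypothetical stand-in for ModelBuilder; not the real class."""

    def __init__(self, data: Dict[str, List[float]]) -> None:
        self.data = data
        self.idata: Optional[Dict[str, Any]] = None

    def build_model(self) -> None:
        # Stand-in for building the PyMC model from self.data.
        self.model = {"n_obs": len(next(iter(self.data.values())))}

    def create_fit_data(self) -> Dict[str, List[float]]:
        # Child classes override this when their data needs reshaping.
        return dict(self.data)

    def _fitters(self) -> Dict[str, Callable[..., Dict[str, Any]]]:
        # Stubs standing in for real fitters such as pm.sample or pm.find_MAP.
        return {
            "mcmc": lambda **kw: {"posterior": "mcmc draws", "kwargs": kw},
            "map": lambda **kw: {"posterior": "map estimate", "kwargs": kw},
        }

    def fit(self, method: str = "mcmc", **kwargs: Any) -> Dict[str, Any]:
        """Shared boilerplate: build, dispatch to a fitter, attach fit_data."""
        if not hasattr(self, "model"):
            self.build_model()
        fitters = self._fitters()
        if method not in fitters:
            raise ValueError(f"Unknown fit method: {method!r}")
        idata = fitters[method](**kwargs)
        idata["fit_data"] = self.create_fit_data()
        self.idata = idata
        return idata


class ToyCLVModel(BaseBuilder):
    """Hypothetical child model with its own data-loading requirement."""

    def create_fit_data(self) -> Dict[str, List[float]]:
        # Only the columns this model actually uses are stored.
        keep = {"frequency", "recency"}
        return {k: v for k, v in self.data.items() if k in keep}

    def fit(self, method: str = "map", **kwargs: Any) -> Dict[str, Any]:
        # The only model-specific part is the default fitter;
        # everything else is inherited from the base class.
        return super().fit(method=method, **kwargs)
```

In this sketch the child class only overrides create_fit_data and the default fitter, so idata handling stays in one place. Whether the real ModelBuilder.fit should take a method argument or expose the alternative fitters some other way is an open design question for the PR.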

Labels: ModelBuilder (Related to the ModelBuilder class and its children), maintenance, model components (Related to the various model components)