Skip to content

Commit ad84e90

Browse files
ColtAllen, williambdean, juanitorduz
authored
Refactor ModelBuilder into smaller classes (#1870)
* break out into smaller classes
* change the default to True
* WIP unit testing
* requires_model methods
* test_graphviz
* TODOs and fix ModelBuilder tests
* remove test model file
* fix customer and MMM tests
* WIP fix clv tests
* fix clv base tests
* fix remaining clv tests
* docstrings
* docstring edit
* remove idata_to_init_kwargs
* CLV models inherit from BaseModelBuilder
* WIP test cleanup
* fix base model tests
* clean up tests
* testing coverage
* add abstract methods to ModelBuilder
* docstrings
* docstrings return model_config
* remove duplicated self.data attr
---------
Co-authored-by: Will Dean <[email protected]>
Co-authored-by: Juan Orduz <[email protected]>
1 parent 7b56b6e commit ad84e90

File tree

14 files changed

+1128
-655
lines changed

14 files changed

+1128
-655
lines changed

pymc_marketing/clv/models/basic.py

Lines changed: 25 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
# limitations under the License.
1414
"""CLV Model base class."""
1515

16-
import json
1716
import warnings
1817
from collections.abc import Sequence
19-
from pathlib import Path
2018
from typing import Literal, cast
2119

2220
import arviz as az
@@ -27,9 +25,8 @@
2725
from pymc.backends.base import MultiTrace
2826
from pymc.model.core import Model
2927

30-
from pymc_marketing.model_builder import ModelBuilder
28+
from pymc_marketing.model_builder import DifferentModelError, ModelBuilder
3129
from pymc_marketing.model_config import ModelConfig, parse_model_config
32-
from pymc_marketing.utils import from_netcdf
3330

3431

3532
class CLVModel(ModelBuilder):
@@ -46,6 +43,7 @@ def __init__(
4643
sampler_config: dict | None = None,
4744
non_distributions: list[str] | None = None,
4845
):
46+
self.data = data
4947
model_config = model_config or {}
5048

5149
deprecated_keys = [key for key in model_config if key.endswith("_prior")]
@@ -60,14 +58,14 @@ def __init__(
6058

6159
model_config[new_key] = model_config.pop(key)
6260

63-
model_config = parse_model_config(
64-
model_config,
61+
super().__init__(model_config, sampler_config)
62+
63+
# Parse model config after merging with defaults
64+
self.model_config = parse_model_config(
65+
self.model_config,
6566
non_distributions=non_distributions,
6667
)
6768

68-
super().__init__(model_config, sampler_config)
69-
self.data = data
70-
7169
@staticmethod
7270
def _validate_cols(
7371
data: pd.DataFrame,
@@ -260,59 +258,39 @@ def _fit_approx(
260258
)
261259

262260
@classmethod
263-
def load(cls, fname: str):
264-
"""Create a ModelBuilder instance from a file.
265-
266-
Loads inference data for the model.
267-
268-
Parameters
269-
----------
270-
fname : string
271-
This denotes the name with path from where idata should be loaded from.
261+
def idata_to_init_kwargs(cls, idata: az.InferenceData) -> dict:
262+
"""Create the initialization kwargs from an InferenceData object."""
263+
kwargs = cls.attrs_to_init_kwargs(idata.attrs)
264+
kwargs["data"] = idata.fit_data.to_dataframe()
272265

273-
Returns
274-
-------
275-
Returns an instance of ModelBuilder.
276-
277-
Raises
278-
------
279-
ValueError
280-
If the inference data that is loaded doesn't match with the model.
281-
282-
Examples
283-
--------
284-
>>> class MyModel(ModelBuilder):
285-
>>> ...
286-
>>> name = "./mymodel.nc"
287-
>>> imported_model = MyModel.load(name)
288-
289-
"""
290-
filepath = Path(str(fname))
291-
idata = from_netcdf(filepath)
292-
return cls._build_with_idata(idata)
266+
return kwargs
293267

294268
@classmethod
295-
def _build_with_idata(cls, idata: az.InferenceData):
296-
dataset = idata.fit_data.to_dataframe()
269+
def build_from_idata(cls, idata: az.InferenceData) -> None:
270+
"""Build the model from the InferenceData object."""
271+
kwargs = cls.idata_to_init_kwargs(idata)
297272
with warnings.catch_warnings():
298273
warnings.filterwarnings(
299274
"ignore",
300275
category=DeprecationWarning,
301276
)
302-
model = cls(
303-
dataset,
304-
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
305-
sampler_config=json.loads(idata.attrs["sampler_config"]),
306-
)
277+
model = cls(**kwargs)
307278

308279
model.idata = idata
309280
model._rename_posterior_variables()
310281

311282
model.build_model() # type: ignore
312283
if model.id != idata.attrs["id"]:
313-
raise ValueError(f"Inference data not compatible with {cls._model_type}")
284+
msg = (
285+
"The model id in the InferenceData does not match the model id. "
286+
"There was no error loading the inference data, but the model may "
287+
"be different. "
288+
"Investigate if the model structure or configuration has changed."
289+
)
290+
raise DifferentModelError(msg)
314291
return model
315292

293+
# TODO: Remove in 2026Q1?
316294
def _rename_posterior_variables(self):
317295
"""Rename variables in the posterior group to remove the _prior suffix.
318296
@@ -355,7 +333,7 @@ def thin_fit_result(self, keep_every: int):
355333
self.fit_result # noqa: B018 (Raise Error if fit didn't happen yet)
356334
assert self.idata is not None # noqa: S101
357335
new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy()
358-
return type(self)._build_with_idata(new_idata)
336+
return self.build_from_idata(new_idata)
359337

360338
@property
361339
def default_sampler_config(self) -> dict:
@@ -378,8 +356,3 @@ def fit_summary(self, **kwargs):
378356
return res["mean"].rename("value")
379357
else:
380358
return az.summary(self.fit_result, **kwargs)
381-
382-
@property
383-
def output_var(self):
384-
"""Output variable of the model."""
385-
pass

pymc_marketing/customer_choice/mnl_logit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
import pytensor.tensor as pt
2626
from pymc_extras.prior import Prior
2727

28-
from pymc_marketing.model_builder import ModelBuilder
28+
from pymc_marketing.model_builder import RegressionModelBuilder
2929
from pymc_marketing.model_config import parse_model_config
3030

3131
HDI_ALPHA = 0.5
3232

3333

34-
class MNLogit(ModelBuilder):
34+
class MNLogit(RegressionModelBuilder):
3535
"""
3636
Multinomial Logit class.
3737

pymc_marketing/customer_choice/mv_its.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
from xarray import DataArray
2828

2929
from pymc_marketing.mmm.additive_effect import MuEffect
30-
from pymc_marketing.model_builder import ModelBuilder, create_idata_accessor
30+
from pymc_marketing.model_builder import RegressionModelBuilder, create_idata_accessor
3131
from pymc_marketing.model_config import parse_model_config
3232

3333
HDI_ALPHA = 0.5
3434

3535

36-
class MVITS(ModelBuilder):
36+
class MVITS(RegressionModelBuilder):
3737
"""Multivariate Interrupted Time Series class.
3838
3939
Class to perform a multivariate interrupted time series analysis with the
@@ -251,7 +251,7 @@ def _generate_and_preprocess_model_data(
251251
],
252252
}
253253

254-
def build_model(
254+
def build_model( # type: ignore[override]
255255
self,
256256
X: pd.DataFrame,
257257
y: pd.Series | np.ndarray,

pymc_marketing/customer_choice/nested_logit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
from pymc_extras.prior import Prior
2727
from pytensor.tensor.variable import TensorVariable
2828

29-
from pymc_marketing.model_builder import ModelBuilder
29+
from pymc_marketing.model_builder import RegressionModelBuilder
3030
from pymc_marketing.model_config import parse_model_config
3131

3232
HDI_ALPHA = 0.5
3333

3434

35-
class NestedLogit(ModelBuilder):
35+
class NestedLogit(RegressionModelBuilder):
3636
"""
3737
Nested Logit class.
3838

pymc_marketing/mmm/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@
4545
ValidateDateColumn,
4646
ValidateTargetColumn,
4747
)
48-
from pymc_marketing.model_builder import ModelBuilder
48+
from pymc_marketing.model_builder import RegressionModelBuilder
4949

5050
__all__ = ["BaseValidateMMM", "MMMModelBuilder"]
5151

5252
from pydantic import Field, validate_call
5353

5454

55-
class MMMModelBuilder(ModelBuilder):
55+
class MMMModelBuilder(RegressionModelBuilder):
5656
"""Base class for Marketing Mix Models (MMM)."""
5757

5858
model: pm.Model

pymc_marketing/mmm/mmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def forward_pass(self, x: pt.TensorVariable | npt.NDArray) -> pt.TensorVariable:
503503

504504
return second.apply(x=first.apply(x=x, dims="channel"), dims="channel")
505505

506-
def build_model(
506+
def build_model( # type: ignore[override]
507507
self,
508508
X: pd.DataFrame,
509509
y: pd.Series | np.ndarray,

pymc_marketing/mmm/multidimensional.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@
6262
add_noise_to_channel_allocation,
6363
create_zero_dataset,
6464
)
65-
from pymc_marketing.model_builder import ModelBuilder, _handle_deprecate_pred_argument
65+
from pymc_marketing.model_builder import (
66+
RegressionModelBuilder,
67+
_handle_deprecate_pred_argument,
68+
)
6669
from pymc_marketing.model_config import parse_model_config
6770
from pymc_marketing.model_graph import deterministics_to_flat
6871

@@ -75,7 +78,7 @@
7578
warnings.warn(warning_msg, FutureWarning, stacklevel=1)
7679

7780

78-
class MMM(ModelBuilder):
81+
class MMM(RegressionModelBuilder):
7982
"""Marketing Mix Model class for estimating the impact of marketing channels on a target variable.
8083
8184
This class implements the core functionality of a Marketing Mix Model (MMM), allowing for the

0 commit comments

Comments
 (0)