Skip to content
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d6e15a8
break out into smaller classes
williambdean Feb 4, 2025
e9f6f5c
change the default to True
williambdean Feb 4, 2025
ad351e2
Merge branch 'main' into break-down-modelbuilder
ColtAllen Jul 27, 2025
b8c1791
Merge branch 'main' into break-down-modelbuilder
ColtAllen Jul 28, 2025
63cad51
Merge branch 'main' into break-down-modelbuilder
ColtAllen Jul 29, 2025
94347b8
WIP unit testing
ColtAllen Jul 29, 2025
3085840
requires_model methods
ColtAllen Jul 29, 2025
9b03b07
test_graphviz
ColtAllen Jul 29, 2025
83dfdc8
TODOs and fix MOdelBuilder tests
ColtAllen Jul 31, 2025
3dd6fce
remove test model file
ColtAllen Jul 31, 2025
f733057
fix customer and MMM tests
ColtAllen Jul 31, 2025
536a252
WIP fix clv tests
ColtAllen Jul 31, 2025
6416c01
Merge branch 'main' into break-down-modelbuilder
ColtAllen Jul 31, 2025
21d87a6
Merge branch 'pymc-labs:break-down-modelbuilder' into break-down-mode…
ColtAllen Jul 31, 2025
5ac567b
fix clv base tests
ColtAllen Aug 1, 2025
59fa1bd
fix remaining clv tests
ColtAllen Aug 1, 2025
c7cfa41
docstrings
ColtAllen Aug 2, 2025
b4c98ee
docstring edit
ColtAllen Aug 2, 2025
f9a62ad
remove idata_to_init_kwargs
ColtAllen Aug 2, 2025
6e1816e
Merge branch 'main' into break-down-modelbuilder
ColtAllen Aug 3, 2025
ab0ae5f
CLV models inherit from BaseModelBuilder
ColtAllen Aug 3, 2025
d3a325e
Merge branch 'pymc-labs:break-down-modelbuilder' into break-down-mode…
ColtAllen Aug 3, 2025
810ac86
WIP test cleanup
ColtAllen Aug 3, 2025
1a9701a
fix base model tests
ColtAllen Aug 4, 2025
58ab9e3
clean up tests
ColtAllen Aug 4, 2025
77aa802
testing coverage
ColtAllen Aug 4, 2025
a06c6ba
add abstract methods to ModelBuilder
ColtAllen Aug 4, 2025
b523bc3
docstrings
ColtAllen Aug 4, 2025
0d5a27c
Merge branch 'main' into break-down-modelbuilder
ColtAllen Aug 9, 2025
760a848
Merge branch 'main' into break-down-modelbuilder
ColtAllen Aug 11, 2025
9fb252e
docstrings return model_config
ColtAllen Aug 11, 2025
040100c
Merge branch 'main' into break-down-modelbuilder
ColtAllen Aug 12, 2025
163d772
remove duplicated self.data attr
ColtAllen Aug 12, 2025
eb26441
Merge branch 'main' into break-down-modelbuilder
ColtAllen Aug 18, 2025
f7295af
Merge branch 'main' into break-down-modelbuilder
juanitorduz Aug 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 25 additions & 50 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# limitations under the License.
"""CLV Model base class."""

import json
import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import Literal, cast

import arviz as az
Expand All @@ -27,9 +25,8 @@
from pymc.backends.base import MultiTrace
from pymc.model.core import Model

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_builder import DifferentModelError, ModelBuilder
from pymc_marketing.model_config import ModelConfig, parse_model_config
from pymc_marketing.utils import from_netcdf


class CLVModel(ModelBuilder):
Expand All @@ -46,6 +43,7 @@ def __init__(
sampler_config: dict | None = None,
non_distributions: list[str] | None = None,
):
self.data = data
model_config = model_config or {}

deprecated_keys = [key for key in model_config if key.endswith("_prior")]
Expand All @@ -60,12 +58,14 @@ def __init__(

model_config[new_key] = model_config.pop(key)

model_config = parse_model_config(
model_config,
super().__init__(model_config, sampler_config)

# Parse model config after merging with defaults
self.model_config = parse_model_config(
self.model_config,
non_distributions=non_distributions,
)

super().__init__(model_config, sampler_config)
self.data = data

@staticmethod
Expand Down Expand Up @@ -260,59 +260,39 @@ def _fit_approx(
)

@classmethod
def load(cls, fname: str):
"""Create a ModelBuilder instance from a file.

Loads inference data for the model.

Parameters
----------
fname : string
This denotes the name with path from where idata should be loaded from.

Returns
-------
Returns an instance of ModelBuilder.
def idata_to_init_kwargs(cls, idata: az.InferenceData) -> dict:
"""Create the initialization kwargs from an InferenceData object."""
kwargs = cls.attrs_to_init_kwargs(idata.attrs)
kwargs["data"] = idata.fit_data.to_dataframe()

Raises
------
ValueError
If the inference data that is loaded doesn't match with the model.

Examples
--------
>>> class MyModel(ModelBuilder):
>>> ...
>>> name = "./mymodel.nc"
>>> imported_model = MyModel.load(name)

"""
filepath = Path(str(fname))
idata = from_netcdf(filepath)
return cls._build_with_idata(idata)
return kwargs

@classmethod
def _build_with_idata(cls, idata: az.InferenceData):
dataset = idata.fit_data.to_dataframe()
def build_from_idata(cls, idata: az.InferenceData) -> None:
"""Build the model from the InferenceData object."""
kwargs = cls.idata_to_init_kwargs(idata)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
)
model = cls(
dataset,
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model = cls(**kwargs)

model.idata = idata
model._rename_posterior_variables()

model.build_model() # type: ignore
if model.id != idata.attrs["id"]:
raise ValueError(f"Inference data not compatible with {cls._model_type}")
msg = (
"The model id in the InferenceData does not match the model id. "
"There was no error loading the inference data, but the model may "
"be different. "
"Investigate if the model structure or configuration has changed."
)
raise DifferentModelError(msg)
return model

# TODO: Remove in 2026Q1?
def _rename_posterior_variables(self):
"""Rename variables in the posterior group to remove the _prior suffix.

Expand Down Expand Up @@ -355,7 +335,7 @@ def thin_fit_result(self, keep_every: int):
self.fit_result # noqa: B018 (Raise Error if fit didn't happen yet)
assert self.idata is not None # noqa: S101
new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy()
return type(self)._build_with_idata(new_idata)
return self.build_from_idata(new_idata)

@property
def default_sampler_config(self) -> dict:
Expand All @@ -378,8 +358,3 @@ def fit_summary(self, **kwargs):
return res["mean"].rename("value")
else:
return az.summary(self.fit_result, **kwargs)

@property
def output_var(self):
"""Output variable of the model."""
pass
4 changes: 2 additions & 2 deletions pymc_marketing/customer_choice/mnl_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import pytensor.tensor as pt
from pymc_extras.prior import Prior

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_builder import RegressionModelBuilder
from pymc_marketing.model_config import parse_model_config

HDI_ALPHA = 0.5


class MNLogit(ModelBuilder):
class MNLogit(RegressionModelBuilder):
"""
Multinomial Logit class.

Expand Down
6 changes: 3 additions & 3 deletions pymc_marketing/customer_choice/mv_its.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
from xarray import DataArray

from pymc_marketing.mmm.additive_effect import MuEffect
from pymc_marketing.model_builder import ModelBuilder, create_idata_accessor
from pymc_marketing.model_builder import RegressionModelBuilder, create_idata_accessor
from pymc_marketing.model_config import parse_model_config

HDI_ALPHA = 0.5


class MVITS(ModelBuilder):
class MVITS(RegressionModelBuilder):
"""Multivariate Interrupted Time Series class.

Class to perform a multivariate interrupted time series analysis with the
Expand Down Expand Up @@ -251,7 +251,7 @@ def _generate_and_preprocess_model_data(
],
}

def build_model(
def build_model( # type: ignore[override]
self,
X: pd.DataFrame,
y: pd.Series | np.ndarray,
Expand Down
4 changes: 2 additions & 2 deletions pymc_marketing/customer_choice/nested_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
from pymc_extras.prior import Prior
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_builder import RegressionModelBuilder
from pymc_marketing.model_config import parse_model_config

HDI_ALPHA = 0.5


class NestedLogit(ModelBuilder):
class NestedLogit(RegressionModelBuilder):
"""
Nested Logit class.

Expand Down
4 changes: 2 additions & 2 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@
ValidateDateColumn,
ValidateTargetColumn,
)
from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_builder import RegressionModelBuilder

__all__ = ["BaseValidateMMM", "MMMModelBuilder"]

from pydantic import Field, validate_call


class MMMModelBuilder(ModelBuilder):
class MMMModelBuilder(RegressionModelBuilder):
"""Base class for Marketing Mix Models (MMM)."""

model: pm.Model
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def forward_pass(self, x: pt.TensorVariable | npt.NDArray) -> pt.TensorVariable:

return second.apply(x=first.apply(x=x, dims="channel"), dims="channel")

def build_model(
def build_model( # type: ignore[override]
self,
X: pd.DataFrame,
y: pd.Series | np.ndarray,
Expand Down
7 changes: 5 additions & 2 deletions pymc_marketing/mmm/multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@
add_noise_to_channel_allocation,
create_zero_dataset,
)
from pymc_marketing.model_builder import ModelBuilder, _handle_deprecate_pred_argument
from pymc_marketing.model_builder import (
RegressionModelBuilder,
_handle_deprecate_pred_argument,
)
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.model_graph import deterministics_to_flat

Expand All @@ -74,7 +77,7 @@
warnings.warn(warning_msg, FutureWarning, stacklevel=1)


class MMM(ModelBuilder):
class MMM(RegressionModelBuilder):
"""Marketing Mix Model class for estimating the impact of marketing channels on a target variable.

This class implements the core functionality of a Marketing Mix Model (MMM), allowing for the
Expand Down
Loading