diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index 97c4ecd13..a670450e9 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -13,10 +13,8 @@ # limitations under the License. """CLV Model base class.""" -import json import warnings from collections.abc import Sequence -from pathlib import Path from typing import Literal, cast import arviz as az @@ -27,9 +25,8 @@ from pymc.backends.base import MultiTrace from pymc.model.core import Model -from pymc_marketing.model_builder import ModelBuilder +from pymc_marketing.model_builder import DifferentModelError, ModelBuilder from pymc_marketing.model_config import ModelConfig, parse_model_config -from pymc_marketing.utils import from_netcdf class CLVModel(ModelBuilder): @@ -46,6 +43,7 @@ def __init__( sampler_config: dict | None = None, non_distributions: list[str] | None = None, ): + self.data = data model_config = model_config or {} deprecated_keys = [key for key in model_config if key.endswith("_prior")] @@ -60,14 +58,14 @@ def __init__( model_config[new_key] = model_config.pop(key) - model_config = parse_model_config( - model_config, + super().__init__(model_config, sampler_config) + + # Parse model config after merging with defaults + self.model_config = parse_model_config( + self.model_config, non_distributions=non_distributions, ) - super().__init__(model_config, sampler_config) - self.data = data - @staticmethod def _validate_cols( data: pd.DataFrame, @@ -260,59 +258,39 @@ def _fit_approx( ) @classmethod - def load(cls, fname: str): - """Create a ModelBuilder instance from a file. - - Loads inference data for the model. - - Parameters - ---------- - fname : string - This denotes the name with path from where idata should be loaded from. + def idata_to_init_kwargs(cls, idata: az.InferenceData) -> dict: + """Create the initialization kwargs from an InferenceData object.""" + kwargs = cls.attrs_to_init_kwargs(idata.attrs) + kwargs["data"] = idata.fit_data.to_dataframe() - Returns - ------- - Returns an instance of ModelBuilder. - - Raises - ------ - ValueError - If the inference data that is loaded doesn't match with the model. - - Examples - -------- - >>> class MyModel(ModelBuilder): - >>> ... - >>> name = "./mymodel.nc" - >>> imported_model = MyModel.load(name) - - """ - filepath = Path(str(fname)) - idata = from_netcdf(filepath) - return cls._build_with_idata(idata) + return kwargs @classmethod - def _build_with_idata(cls, idata: az.InferenceData): - dataset = idata.fit_data.to_dataframe() + def build_from_idata(cls, idata: az.InferenceData) -> None: + """Build the model from the InferenceData object.""" + kwargs = cls.idata_to_init_kwargs(idata) with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=DeprecationWarning, ) - model = cls( - dataset, - model_config=json.loads(idata.attrs["model_config"]), # type: ignore - sampler_config=json.loads(idata.attrs["sampler_config"]), - ) + model = cls(**kwargs) model.idata = idata model._rename_posterior_variables() model.build_model() # type: ignore if model.id != idata.attrs["id"]: - raise ValueError(f"Inference data not compatible with {cls._model_type}") + msg = ( + "The model id in the InferenceData does not match the model id. " + "There was no error loading the inference data, but the model may " + "be different. " + "Investigate if the model structure or configuration has changed." + ) + raise DifferentModelError(msg) return model + # TODO: Remove in 2026Q1? 
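# --- Editor's note (not part of the patch): a hedged sketch of what the CLV load path
# looks like once this change lands. The file name is illustrative; `load(...)` is
# inherited from the ModelIO base added to pymc_marketing/model_builder.py further down.
from pymc_marketing.clv import BetaGeoModel  # any CLVModel subclass
from pymc_marketing.model_builder import DifferentModelError

try:
    model = BetaGeoModel.load("bg_model.nc")
    model.build_model()  # per the updated test_basic.py, .model still needs an explicit build after load
except DifferentModelError:
    # Replaces the old ValueError("Inference data not compatible with ...") raised when
    # the saved InferenceData's id/config does not match the loading model class.
    raise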
def _rename_posterior_variables(self): """Rename variables in the posterior group to remove the _prior suffix. @@ -355,7 +333,7 @@ def thin_fit_result(self, keep_every: int): self.fit_result # noqa: B018 (Raise Error if fit didn't happen yet) assert self.idata is not None # noqa: S101 new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy() - return type(self)._build_with_idata(new_idata) + return self.build_from_idata(new_idata) @property def default_sampler_config(self) -> dict: @@ -378,8 +356,3 @@ def fit_summary(self, **kwargs): return res["mean"].rename("value") else: return az.summary(self.fit_result, **kwargs) - - @property - def output_var(self): - """Output variable of the model.""" - pass diff --git a/pymc_marketing/customer_choice/mnl_logit.py b/pymc_marketing/customer_choice/mnl_logit.py index 53a5acec5..39b577fd5 100644 --- a/pymc_marketing/customer_choice/mnl_logit.py +++ b/pymc_marketing/customer_choice/mnl_logit.py @@ -25,13 +25,13 @@ import pytensor.tensor as pt from pymc_extras.prior import Prior -from pymc_marketing.model_builder import ModelBuilder +from pymc_marketing.model_builder import RegressionModelBuilder from pymc_marketing.model_config import parse_model_config HDI_ALPHA = 0.5 -class MNLogit(ModelBuilder): +class MNLogit(RegressionModelBuilder): """ Multinomial Logit class. diff --git a/pymc_marketing/customer_choice/mv_its.py b/pymc_marketing/customer_choice/mv_its.py index 1e98a2d3b..02e49f32a 100644 --- a/pymc_marketing/customer_choice/mv_its.py +++ b/pymc_marketing/customer_choice/mv_its.py @@ -27,13 +27,13 @@ from xarray import DataArray from pymc_marketing.mmm.additive_effect import MuEffect -from pymc_marketing.model_builder import ModelBuilder, create_idata_accessor +from pymc_marketing.model_builder import RegressionModelBuilder, create_idata_accessor from pymc_marketing.model_config import parse_model_config HDI_ALPHA = 0.5 -class MVITS(ModelBuilder): +class MVITS(RegressionModelBuilder): """Multivariate Interrupted Time Series class. Class to perform a multivariate interrupted time series analysis with the @@ -251,7 +251,7 @@ def _generate_and_preprocess_model_data( ], } - def build_model( + def build_model( # type: ignore[override] self, X: pd.DataFrame, y: pd.Series | np.ndarray, diff --git a/pymc_marketing/customer_choice/nested_logit.py b/pymc_marketing/customer_choice/nested_logit.py index 14927c25f..38a24da56 100644 --- a/pymc_marketing/customer_choice/nested_logit.py +++ b/pymc_marketing/customer_choice/nested_logit.py @@ -26,13 +26,13 @@ from pymc_extras.prior import Prior from pytensor.tensor.variable import TensorVariable -from pymc_marketing.model_builder import ModelBuilder +from pymc_marketing.model_builder import RegressionModelBuilder from pymc_marketing.model_config import parse_model_config HDI_ALPHA = 0.5 -class NestedLogit(ModelBuilder): +class NestedLogit(RegressionModelBuilder): """ Nested Logit class. 
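# --- Editor's note (not part of the patch): a minimal sketch of the class hierarchy this
# refactor introduces, assuming the patch is applied. The X/y-style models (MNLogit, MVITS,
# NestedLogit, and the MMM classes below) now subclass RegressionModelBuilder, while the
# save/load primitives live in the new ModelIO mixin.
from pymc_marketing.clv.models.basic import CLVModel
from pymc_marketing.customer_choice.mnl_logit import MNLogit
from pymc_marketing.model_builder import ModelBuilder, ModelIO, RegressionModelBuilder

assert issubclass(ModelBuilder, ModelIO)                 # saving/loading moves into ModelIO
assert issubclass(RegressionModelBuilder, ModelBuilder)  # scikit-learn-style X/y API layered on top
assert issubclass(MNLogit, RegressionModelBuilder)       # customer-choice models retarget the new base
assert issubclass(CLVModel, ModelBuilder)                # CLV models keep the generic ModelBuilder base
assert not issubclass(CLVModel, RegressionModelBuilder)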
diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index 1893152d6..71ad82181 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -45,14 +45,14 @@ ValidateDateColumn, ValidateTargetColumn, ) -from pymc_marketing.model_builder import ModelBuilder +from pymc_marketing.model_builder import RegressionModelBuilder __all__ = ["BaseValidateMMM", "MMMModelBuilder"] from pydantic import Field, validate_call -class MMMModelBuilder(ModelBuilder): +class MMMModelBuilder(RegressionModelBuilder): """Base class for Marketing Mix Models (MMM).""" model: pm.Model diff --git a/pymc_marketing/mmm/mmm.py b/pymc_marketing/mmm/mmm.py index 20e8b73d3..9afcc787e 100644 --- a/pymc_marketing/mmm/mmm.py +++ b/pymc_marketing/mmm/mmm.py @@ -503,7 +503,7 @@ def forward_pass(self, x: pt.TensorVariable | npt.NDArray) -> pt.TensorVariable: return second.apply(x=first.apply(x=x, dims="channel"), dims="channel") - def build_model( + def build_model( # type: ignore[override] self, X: pd.DataFrame, y: pd.Series | np.ndarray, diff --git a/pymc_marketing/mmm/multidimensional.py b/pymc_marketing/mmm/multidimensional.py index 573b79803..e390912a4 100644 --- a/pymc_marketing/mmm/multidimensional.py +++ b/pymc_marketing/mmm/multidimensional.py @@ -62,7 +62,10 @@ add_noise_to_channel_allocation, create_zero_dataset, ) -from pymc_marketing.model_builder import ModelBuilder, _handle_deprecate_pred_argument +from pymc_marketing.model_builder import ( + RegressionModelBuilder, + _handle_deprecate_pred_argument, +) from pymc_marketing.model_config import parse_model_config from pymc_marketing.model_graph import deterministics_to_flat @@ -75,7 +78,7 @@ warnings.warn(warning_msg, FutureWarning, stacklevel=1) -class MMM(ModelBuilder): +class MMM(RegressionModelBuilder): """Marketing Mix Model class for estimating the impact of marketing channels on a target variable. This class implements the core functionality of a Marketing Mix Model (MMM), allowing for the diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index b212eb563..667965739 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Base class responsible of the high level API for model building, fitting saving and loading.""" +"""Base classes containing primitives and high-level API for model building, fitting, saving, and loading.""" import hashlib import json @@ -164,199 +164,51 @@ class DifferentModelError(Exception): """Error raised when a model loaded is different than one saved.""" -class ModelBuilder(ABC): - """Base class for building models with PyMC-Marketing. +class ModelIO: + """Mixin to handle saving and loading of models.""" - It provides an easy-to-use API (similar to scikit-learn) for models - and help with deployment. - """ - - _model_type = "BaseClass" - version = "None" - - X: pd.DataFrame | None = None - y: pd.Series | np.ndarray | None = None - - def __init__( - self, - model_config: dict | None = None, - sampler_config: dict | None = None, - ): - """Initialize model configuration and sampler configuration for the model. - - Parameters - ---------- - model_config : Dictionary, optional - dictionary of parameters that initialise model configuration. - Class-default defined by the user default_model_config method. 
- sampler_config : Dictionary, optional - dictionary of parameters that initialise sampler configuration. - Class-default defined by the user default_sampler_config method. - - Examples - -------- - >>> class MyModel(ModelBuilder): - >>> ... - >>> model = MyModel(model_config, sampler_config) - - """ - if sampler_config is None: - sampler_config = {} - if model_config is None: - model_config = {} - self.sampler_config = ( - self.default_sampler_config | sampler_config - ) # Parameters for fit sampling - self.model_config = ( - self.default_model_config | model_config - ) # parameters for priors etc. - self.model: pm.Model - self.idata: az.InferenceData | None = None # idata is generated during fitting - self.is_fitted_ = False - - def _validate_data(self, X, y=None): - if y is not None: - return check_X_y( - X, y, accept_sparse=False, y_numeric=True, multi_output=False - ) - else: - return check_array(X, accept_sparse=False) - - def _data_setter( - self, - X: np.ndarray | pd.DataFrame | xr.Dataset | xr.DataArray, - y: np.ndarray | pd.Series | xr.DataArray | None = None, - ) -> None: - """Set new data in the model. - - Parameters - ---------- - X : array, shape (n_obs, n_features) - The training input samples. - y : array, shape (n_obs,) - The target values (real numbers). - - Examples - -------- - Example logic of data_setter method - - .. code-block:: python - - def _data_setter(self, X, y=None): - data = {"X": X} - if y is None: - y = np.zeros(len(X)) - data["y"] = y - - with self.model: - pm.set_data(data) - - """ - msg = "This model doesn't support setting new data, posterior_predictive, or out of sample methods." - raise NotImplementedError(msg) - - @property - @abstractmethod - def output_var(self) -> str: - """Returns the name of the output variable of the model. - - Returns - ------- - output_var : str - Name of the output variable of the model. - - """ + _model_type: str + version: str + idata: az.InferenceData | None + sampler_config: dict + model_config: dict @property - @abstractmethod - def default_model_config(self) -> dict: - """Return a class default configuration dictionary. - - For model builder if no model_config is provided on class initialization - Useful for understanding structure of required model_config to allow its customization by users + def id(self) -> str: + """Generate a unique hash value for the model. - Examples - -------- - >>> @classmethod - >>> def default_model_config(self): - >>> Return { - >>> 'a' : { - >>> 'loc': 7, - >>> 'scale' : 3 - >>> }, - >>> 'b' : { - >>> 'loc': 3, - >>> 'scale': 5 - >>> } - >>> 'obs_error': 2 - >>> } + The hash value is created using the last 16 characters of the SHA256 hash encoding, + based on the model configuration, version, and model type. Returns ------- - model_config : dict - A set of default parameters for predictor distributions that allow to save and recreate the model. - - """ - - @property - @abstractmethod - def default_sampler_config(self) -> dict: - """Return a class default sampler configuration dictionary. - - For model builder if no sampler_config is provided on class initialization - Useful for understanding structure of required sampler_config to allow its customization by users + str + A string of length 16 characters containing a unique hash of the model. 
Examples -------- - >>> @classmethod - >>> def default_sampler_config(self): - >>> Return { - >>> 'draws': 1_000, - >>> 'tune': 1_000, - >>> 'chains': 1, - >>> 'target_accept': 0.95, - >>> } - - Returns - ------- - sampler_config : dict - A set of default settings for used by model in fit process. + >>> model = MyModel() + >>> model.id + '0123456789abcdef' """ + hasher = hashlib.sha256() + hasher.update(str(self.model_config.values()).encode()) + hasher.update(self.version.encode()) + hasher.update(self._model_type.encode()) + return hasher.hexdigest()[:16] + @property @abstractmethod - def build_model( - self, - X: pd.DataFrame | xr.Dataset | xr.DataArray, - y: pd.Series | np.ndarray | xr.DataArray, - **kwargs, - ) -> None: - """Create an instance of `pm.Model` based on provided data and model_config. - - It attaches the model to self.model. - - Parameters - ---------- - X : pd.DataFrame - The input data that is going to be used in the model. This should be a DataFrame - containing the features (predictors) for the model. For efficiency reasons, it should - only contain the necessary data columns, not the entire available dataset, as this - will be encoded into the data used to recreate the model. - - y : Union[pd.Series, np.ndarray] - The target data for the model. This should be a Series representing the output - or dependent variable for the model. - - kwargs : dict - Additional keyword arguments that may be used for model configuration. + def _serializable_model_config(self) -> dict[str, int | float | dict]: + """Converts non-serializable values from model_config to their serializable reversable equivalent. - See Also - -------- - default_model_config : returns default model config + Data types like pandas DataFrame, Series or datetime aren't JSON serializable, + so in order to save the model they need to be formatted. Returns ------- - None + model_config : dict """ @@ -443,7 +295,7 @@ def set_idata_attrs( raise ValueError(msg) init_parameters: set[str] = set(signature(self.__init__).parameters.keys()) # type: ignore - # Remove since this will be stored in the fit_data group of InferenceData + # Remove data attr since it will be stored in the fit_data group of InferenceData init_parameters -= {"data"} if missing_keys := init_parameters - attrs_keys: @@ -533,7 +385,10 @@ def _model_config_formatting(cls, model_config: dict) -> dict: @classmethod def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: - """Convert the model configuration and sampler configuration from the attributes to keyword arguments.""" + """Convert the model configuration and sampler configuration from the attributes to keyword arguments. + + This method must be overridden in child classes if additional keyword arguments are needed. + """ return { "model_config": cls._model_config_formatting( json.loads(attrs["model_config"]) @@ -541,133 +396,494 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: "sampler_config": json.loads(attrs["sampler_config"]), } + @classmethod + def idata_to_init_kwargs(cls, idata: az.InferenceData) -> dict[str, Any]: + """Create the model configuration and sampler configuration from the InferenceData to keyword arguments. + + This method must be overridden in child classes if data is needed as a keyword argument. + """ + return cls.attrs_to_init_kwargs(idata.attrs) + + @abstractmethod def build_from_idata(self, idata: az.InferenceData) -> None: - """Build model from the InferenceData object. 
+ """Build the model from the InferenceData object.""" - This is part of the :func:`load` method. See :func:`load` for more larger context. + @classmethod + def load(cls, fname: str, check: bool = True): + """Create a ModelBuilder instance from a file. - Usually a wrapper around the :func:`build_model` method unless the model - has some additional steps to be built. + Loads inference data for the model. + + This class method has a few steps: + + - Load the InferenceData from the file. + - Construct a new instance of the model using the InferenceData attrs + - Build the model from the InferenceData + - Check if the model id matches the id in the InferenceData loaded. Parameters ---------- - idata : az.InferenceData - The InferenceData object to build the model from. + fname : string + This denotes the name with path from where idata should be loaded from. + check : bool, optional + Whether to check if the model id matches the id in the InferenceData loaded. + Defaults to True. + + Returns + ------- + Returns an instance of ModelBuilder. + + Raises + ------ + DifferentModelError + If the inference data that is loaded doesn't match with the model. + + Examples + -------- + Load a model from a file + + .. code-block:: python + + file_name: str = "./mymodel.nc" + model = MyModel.load(file_name) """ - dataset = idata.fit_data.to_dataframe() # type: ignore - X = dataset.drop(columns=[self.output_var]) - y = dataset[self.output_var] + filepath = Path(str(fname)) + idata = from_netcdf(filepath) - self.build_model(X, y) + try: + return cls.load_from_idata(idata, check=check) + except DifferentModelError as e: + error_msg = ( + f"The file '{fname}' does not contain " + "an InferenceData of the same model " + f"or configuration as '{cls._model_type}'" + ) + raise DifferentModelError(error_msg) from e @classmethod - def load_from_idata(cls, idata: az.InferenceData) -> "ModelBuilder": + def load_from_idata(cls, idata: az.InferenceData, check: bool = True) -> "ModelIO": """Create a ModelBuilder instance from an InferenceData object. This class method has a few steps: - - Construct a new instance of the model using the InferenceData attrs - - Build the model from the InferenceData - - Check if the model id matches the id in the InferenceData loaded. + - Construct a new instance of the model using the InferenceData attrs + - Build the model from the InferenceData + - Check if the model id matches the id in the InferenceData loaded. + + Parameters + ---------- + idata : az.InferenceData + The InferenceData object to load the model from. + check : bool, optional + Whether to check if the model id matches the id in the InferenceData loaded. + Defaults to True. + + Returns + ------- + ModelBuilder + An instance of the ModelBuilder class. + + Raises + ------ + DifferentModelError + If the model id in the InferenceData does not match the model id built. + + """ + init_kwargs = cls.idata_to_init_kwargs(idata) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + model = cls(**init_kwargs) + + model.idata = idata + model.build_from_idata(idata) + + if not check: + return model + + if (model_version := model.version) != ( + loaded_version := idata.attrs["version"] + ): + msg = ( + f"The model version ({loaded_version}) in the InferenceData does not " + f"match the model version ({model_version}). " + "There was no error loading the inference data, but the model structure " + "is different. 
" + ) + raise DifferentModelError(msg) + + if model.id != idata.attrs["id"]: + msg = ( + "The model id in the InferenceData does not match the model id. " + "There was no error loading the inference data, but the model may " + "be different. " + "Investigate if the model structure or configuration has changed." + ) + raise DifferentModelError(msg) + + return model + + +class ModelBuilder(ABC, ModelIO): + """Base class for building PyMC-Marketing models. + + Child classes must implement the following methods: + - default_model_config: Returns a dictionary for default model configuration. + - default_sampler_config: Returns a dictionary for default sampler configuration. + - build_model: Builds the model based on the provided data and model configuration. + - build_from_idata: Builds the model from an InferenceData object. Needed for loading models. + - fit: Fits the model based on the provided data and sampler configurations. + - attrs_to_init_kwargs: Override to add additional init keyword arguments. + - _serializable_model_config: Needed for saving and loading the model. + + """ + + _model_type = "BaseClass" + version = "None" + + def __init__( + self, + model_config: dict | None = None, + sampler_config: dict | None = None, + ): + """Initialize model configuration and sampler configuration for the model. + + Parameters + ---------- + model_config : Dictionary, optional + dictionary of parameters that initialise model configuration. + Class-default defined by the user default_model_config method. + sampler_config : Dictionary, optional + dictionary of parameters that initialise sampler configuration. + Class-default defined by the user default_sampler_config method. + + Examples + -------- + >>> class MyModel(ModelBuilder): + >>> ... + >>> model = MyModel(model_config, sampler_config) + + """ + if sampler_config is None: + sampler_config = {} + if model_config is None: + model_config = {} + + self.sampler_config = ( + self.default_sampler_config | sampler_config + ) # Parameters for fit sampling + self.model_config = ( + self.default_model_config | model_config + ) # parameters for priors etc. + + self.model: pm.Model + self.idata: az.InferenceData | None = None # idata is generated during fitting + self.is_fitted_ = False + + @property + @abstractmethod + def default_model_config(self) -> dict: + """Return a class default configuration dictionary. + + For model builder if no model_config is provided on class initialization + Useful for understanding structure of required model_config to allow its customization by users + + Examples + -------- + >>> @classmethod + >>> def default_model_config(self): + >>> Return { + >>> 'a' : { + >>> 'loc': 7, + >>> 'scale' : 3 + >>> }, + >>> 'b' : { + >>> 'loc': 3, + >>> 'scale': 5 + >>> } + >>> 'obs_error': 2 + >>> } + + Returns + ------- + model_config : dict + A set of default parameters for predictor distributions that allow to save and recreate the model. + + """ + + @property + @abstractmethod + def default_sampler_config(self) -> dict: + """Return a class default sampler configuration dictionary. 
+ + For model builder if no sampler_config is provided on class initialization + Useful for understanding structure of required sampler_config to allow its customization by users + + Examples + -------- + >>> @classmethod + >>> def default_sampler_config(self): + >>> Return { + >>> 'draws': 1_000, + >>> 'tune': 1_000, + >>> 'chains': 1, + >>> 'target_accept': 0.95, + >>> } + + Returns + ------- + sampler_config : dict + A set of default settings for used by model in fit process. + + """ + + @abstractmethod + def build_model( + self, + **kwargs, + ) -> None: + """Create an instance of `pm.Model` based on provided data and model_config. + + It attaches the model to self.model. + + Parameters + ---------- + kwargs : dict + data arguments for model configuration. + + See Also + -------- + default_model_config : returns default model config + + Returns + ------- + None + + """ + + # TODO: Convert from abstract method into a base fitter for all models. + @abstractmethod + def fit( + self, + **kwargs, + ) -> az.InferenceData: + """Fit a model using the data passed as a parameter. + + Sets attrs to inference data of the model. + + Returns + ------- + self : az.InferenceData + Returns inference data of the fitted model. + + """ + + @requires_model + def graphviz(self, **kwargs): + """Get the graphviz representation of the model. + + Parameters + ---------- + **kwargs + Keyword arguments for the `pm.model_to_graphviz` function + + Returns + ------- + graphviz.Digraph + + """ + return pm.model_to_graphviz(self.model, **kwargs) + + @requires_model + def table(self, **model_table_kwargs) -> Table: + """Get the summary table of the model. + + Parameters + ---------- + **model_table_kwargs + Keyword arguments for the `model_table` function + + Returns + ------- + rich.table.Table + A rich table containing the summary of the model. + + """ + return model_table(self.model, **model_table_kwargs) + + @property + def fit_result(self) -> xr.Dataset: + """Get the posterior fit_result. + + Returns + ------- + InferenceData object. + + """ + return create_idata_accessor( + "posterior", "The model hasn't been fit yet, call .fit() first" + ).__get__(self) + + @fit_result.setter + def fit_result(self, res: az.InferenceData) -> None: + """Create a setter method to overwrite the pre-existing fit_result. + + Parameters + ---------- + res : az.InferenceData + The inferencedata object to be set + + Returns + ------- + property + The property setter for the InferenceData object. + + """ + if self.idata is None: + self.idata = res + elif "posterior" in self.idata: + warnings.warn("Overriding pre-existing fit_result", stacklevel=1) + self.idata.posterior = res + else: + self.idata.posterior = res + + prior = create_idata_accessor( + "prior", + "The model hasn't been sampled yet, call .sample_prior_predictive() first", + ) + prior_predictive = create_idata_accessor( + "prior_predictive", + "The model hasn't been sampled yet, call .sample_prior_predictive() first", + ) + posterior = create_idata_accessor( + "posterior", "The model hasn't been fit yet, call .fit() first" + ) + + posterior_predictive = create_idata_accessor( + "posterior_predictive", + "The model hasn't been fit yet, call .sample_posterior_predictive() first", + ) + predictions = create_idata_accessor( + "predictions", + "Call the 'sample_posterior_predictive' method with predictions=True first.", + ) + + +class RegressionModelBuilder(ModelBuilder): + """ModelBuilder class providing an easy-to-use API similar to scikit-learn for regression models. 
+ + Training data is provided in the fit method and must follow the following convention: + - X: Matrix containing predictor variables + - y: Target variable array + """ + + def _validate_data(self, X, y=None): + if y is not None: + return check_X_y( + X, y, accept_sparse=False, y_numeric=True, multi_output=False + ) + else: + return check_array(X, accept_sparse=False) + + @abstractmethod + def _data_setter( + self, + X: np.ndarray | pd.DataFrame | xr.Dataset | xr.DataArray, + y: np.ndarray | pd.Series | xr.DataArray | None = None, + ) -> None: + """Set new data in the model. Parameters ---------- - idata : az.InferenceData - The InferenceData object to load the model from. + X : array, shape (n_obs, n_features) + The training input samples. + y : array, shape (n_obs,) + The target values (real numbers). Returns ------- - ModelBuilder - An instance of the ModelBuilder class. + None - Raises - ------ - DifferentModelError - If the model id in the InferenceData does not match the model id built. + Examples + -------- + >>> def _data_setter(self, data : pd.DataFrame): + >>> with self.model: + >>> pm.set_data({'x': X['x'].values}) + >>> try: # if y values in new data + >>> pm.set_data({'y_data': y.values}) + >>> except: # dummies otherwise + >>> pm.set_data({'y_data': np.zeros(len(data))}) """ - # needs to be converted, because json.loads was changing tuple to list - init_kwargs = cls.attrs_to_init_kwargs(idata.attrs) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=DeprecationWarning) - model = cls(**init_kwargs) + @property + @abstractmethod + def output_var(self) -> str: + """Returns the name of the output variable of the model. - model.idata = idata - model.build_from_idata(idata) - model.post_sample_model_transformation() + Returns + ------- + output_var : str + Name of the output variable of the model. - if (model_version := model.version) != ( - loaded_version := idata.attrs["version"] - ): - msg = ( - f"The model version ({loaded_version}) in the InferenceData does not " - f"match the model version ({model_version}). " - "There was no error loading the inference data, but the model structure " - "is different. " - ) - raise DifferentModelError(msg) - elif model.id != idata.attrs["id"]: - msg = ( - "The model id in the InferenceData does not match the model id. " - "There was no error loading the inference data, but the model may " - "be different. " - "Investigate if the model structure or configuration has changed." - ) - raise DifferentModelError(msg) + """ - return model + @abstractmethod + def build_model( # type: ignore[override] + self, + X: pd.DataFrame | xr.Dataset | xr.DataArray, + y: pd.Series | np.ndarray | xr.DataArray, + **kwargs, + ) -> None: + """Create an instance of `pm.Model` based on provided data and model_config. - @classmethod - def load(cls, fname: str): - """Create a ModelBuilder instance from a file. + It attaches the model to self.model. - Loads inference data for the model. + Parameters + ---------- + X : pd.DataFrame | xr.Dataset | xr.DataArray + The input data that is going to be used in the model. This should be a DataFrame + containing the features (predictors) for the model. For efficiency reasons, it should + only contain the necessary data columns, not the entire available dataset, as this + will be encoded into the data used to recreate the model. - This class method has a few steps: + y : pd.Series | np.ndarray | xr.DataArray + The target data for the model. 
This should be a Series representing the output + or dependent variable for the model. - - Load the InferenceData from the file. - - Construct a new instance of the model using the InferenceData attrs - - Build the model from the InferenceData - - Check if the model id matches the id in the InferenceData loaded. + kwargs : dict + Additional keyword arguments that may be used for model configuration. - Parameters - ---------- - fname : string - This denotes the name with path from where idata should be loaded from. + See Also + -------- + default_model_config : returns default model config Returns ------- - Returns an instance of ModelBuilder. + None - Raises - ------ - DifferentModelError - If the inference data that is loaded doesn't match with the model. + """ - Examples - -------- - Load a model from a file + def build_from_idata(self, idata: az.InferenceData) -> None: + """Build model from the InferenceData object. - .. code-block:: python + This is part of the :func:`load` method. See :func:`load` for more larger context. - file_name: str = "./mymodel.nc" - model = MyModel.load(file_name) + Usually a wrapper around the :func:`build_model` method unless the model + has some additional steps to be built. + + Parameters + ---------- + idata : az.InferenceData + The InferenceData object to build the model from. """ - filepath = Path(str(fname)) - idata = from_netcdf(filepath) + dataset = idata.fit_data.to_dataframe() # type: ignore + X = dataset.drop(columns=[self.output_var]) + y = dataset[self.output_var] - try: - return cls.load_from_idata(idata) - except DifferentModelError as e: - error_msg = ( - f"The file '{fname}' does not contain " - "an InferenceData of the same model " - f"or configuration as '{cls._model_type}'" - ) - raise DifferentModelError(error_msg) from e + self.build_model(X, y) # type: ignore def create_fit_data( self, @@ -690,9 +906,9 @@ def create_fit_data( def post_sample_model_transformation(self) -> None: """Perform transformation on the model after sampling.""" - return + pass - def fit( + def fit( # type: ignore[override] self, X: pd.DataFrame | xr.Dataset | xr.DataArray, y: pd.Series | xr.DataArray | np.ndarray | None = None, @@ -791,42 +1007,6 @@ def fit( self.set_idata_attrs(self.idata) return self.idata # type: ignore - @property - def fit_result(self) -> xr.Dataset: - """Get the posterior fit_result. - - Returns - ------- - InferenceData object. - - """ - return create_idata_accessor( - "posterior", "The model hasn't been fit yet, call .fit() first" - ).__get__(self) - - @fit_result.setter - def fit_result(self, res: az.InferenceData) -> None: - """Create a setter method to overwrite the pre-existing fit_result. - - Parameters - ---------- - res : az.InferenceData - The inferencedata object to be set - - Returns - ------- - property - The property setter for the InferenceData object. - - """ - if self.idata is None: - self.idata = res - elif "posterior" in self.idata: - warnings.warn("Overriding pre-existing fit_result", stacklevel=1) - self.idata.posterior = res - else: - self.idata.posterior = res - def predict( self, X: np.ndarray | pd.DataFrame | pd.Series | None = None, @@ -988,20 +1168,6 @@ def sample_posterior_predictive( return az.extract(post_pred, variable_name, combined=combined) - @property - @abstractmethod - def _serializable_model_config(self) -> dict[str, int | float | dict]: - """Converts non-serializable values from model_config to their serializable reversable equivalent. 
- - Data types like pandas DataFrame, Series or datetime aren't JSON serializable, - so in order to save the model they need to be formatted. - - Returns - ------- - model_config: dict - - """ - def predict_proba( self, X: np.ndarray | pd.DataFrame | pd.Series | None = None, @@ -1052,81 +1218,3 @@ def predict_posterior( ) return posterior_predictive_samples[self.output_var] - - @property - def id(self) -> str: - """Generate a unique hash value for the model. - - The hash value is created using the last 16 characters of the SHA256 hash encoding, - based on the model configuration, version, and model type. - - Returns - ------- - str - A string of length 16 characters containing a unique hash of the model. - - Examples - -------- - >>> model = MyModel() - >>> model.id - '0123456789abcdef' - - """ - hasher = hashlib.sha256() - hasher.update(str(self.model_config.values()).encode()) - hasher.update(self.version.encode()) - hasher.update(self._model_type.encode()) - return hasher.hexdigest()[:16] - - def graphviz(self, **kwargs): - """Get the graphviz representation of the model. - - Parameters - ---------- - **kwargs - Keyword arguments for the `pm.model_to_graphviz` function - - Returns - ------- - graphviz.Digraph - - """ - return pm.model_to_graphviz(self.model, **kwargs) - - @requires_model - def table(self, **model_table_kwargs) -> Table: - """Get the summary table of the model. - - Parameters - ---------- - **model_table_kwargs - Keyword arguments for the `model_table` function - - Returns - ------- - rich.table.Table - A rich table containing the summary of the model. - - """ - return model_table(self.model, **model_table_kwargs) - - prior = create_idata_accessor( - "prior", - "The model hasn't been sampled yet, call .sample_prior_predictive() first", - ) - prior_predictive = create_idata_accessor( - "prior_predictive", - "The model hasn't been sampled yet, call .sample_prior_predictive() first", - ) - posterior = create_idata_accessor( - "posterior", "The model hasn't been fit yet, call .fit() first" - ) - - posterior_predictive = create_idata_accessor( - "posterior_predictive", - "The model hasn't been fit yet, call .sample_posterior_predictive() first", - ) - predictions = create_idata_accessor( - "predictions", - "Call the 'sample_posterior_predictive' method with predictions=True first.", - ) diff --git a/tests/clv/models/test_basic.py b/tests/clv/models/test_basic.py index 86e704dd6..400ba0f11 100644 --- a/tests/clv/models/test_basic.py +++ b/tests/clv/models/test_basic.py @@ -21,6 +21,7 @@ from pymc_extras.prior import Prior from pymc_marketing.clv.models.basic import CLVModel +from pymc_marketing.model_builder import DifferentModelError from tests.conftest import mock_fit_MAP, mock_sample, set_model_fit @@ -178,7 +179,11 @@ def test_load(self, mocker): model.fit(tune=0, chains=2, draws=5) model.save("test_model") model2 = model.load("test_model") + assert model2.fit_result is not None + + # TODO: Add this to the model_builder.py load method? 
+ model2.build_model() assert model2.model is not None os.remove("test_model") @@ -214,8 +219,8 @@ def mock_property(self): # Apply the monkeypatch for the property monkeypatch.setattr(CLVModelTest, "id", property(mock_property)) with pytest.raises( - ValueError, - match="Inference data not compatible with CLVModelTest", + DifferentModelError, + match="The file 'test_model'", ): CLVModelTest.load("test_model") os.remove("test_model") diff --git a/tests/clv/models/test_beta_geo.py b/tests/clv/models/test_beta_geo.py index 79a049d58..a436d283c 100644 --- a/tests/clv/models/test_beta_geo.py +++ b/tests/clv/models/test_beta_geo.py @@ -241,7 +241,7 @@ def test_posterior_distributions(self, fit_type) -> None: map_idata.posterior = map_idata.posterior.isel( chain=slice(None, 1), draw=slice(None, 1) ) - model = self.model._build_with_idata(map_idata) + model = self.model.build_from_idata(map_idata) # We expect 1000 draws to be sampled with MAP expected_shape = (1, 1000) expected_pop_dims = (1, 1000, dim_T, 2) diff --git a/tests/clv/models/test_pareto_nbd.py b/tests/clv/models/test_pareto_nbd.py index 7cdd20696..a22dd084d 100644 --- a/tests/clv/models/test_pareto_nbd.py +++ b/tests/clv/models/test_pareto_nbd.py @@ -325,7 +325,7 @@ def test_posterior_distributions(self, fit_type) -> None: map_idata.posterior = map_idata.posterior.isel( chain=slice(None, 1), draw=slice(None, 1) ) - model = self.model._build_with_idata(map_idata) + model = self.model.build_from_idata(map_idata) # We expect 1000 draws to be sampled with MAP expected_shape = (1, 1000) expected_pop_dims = (1, 1000, dim_T, 2) diff --git a/tests/clv/test_utils.py b/tests/clv/test_utils.py index f02047138..e8e06a141 100644 --- a/tests/clv/test_utils.py +++ b/tests/clv/test_utils.py @@ -213,12 +213,12 @@ def test_map_posterior_mix_fit_types( # Copy model with thinned chain/draw as would be obtained from MAP if transaction_model_map: - transaction_model = transaction_model._build_with_idata( + transaction_model = transaction_model.build_from_idata( transaction_model.idata.sel(chain=slice(0, 1), draw=slice(0, 1)) ) if gg_map: - fitted_gg = fitted_gg._build_with_idata( + fitted_gg = fitted_gg.build_from_idata( fitted_gg.idata.sel(chain=slice(0, 1), draw=slice(0, 1)) ) # create future_spend column from fitted gg diff --git a/tests/customer_choice/test_nestedlogit.py b/tests/customer_choice/test_nestedlogit.py index 8c892a0cf..eee152bf4 100644 --- a/tests/customer_choice/test_nestedlogit.py +++ b/tests/customer_choice/test_nestedlogit.py @@ -137,7 +137,7 @@ def test_preprocess_model_data_sets_attributes(nstL, sample_df, utility_eqs): X, F, y = nstL.preprocess_model_data(sample_df, utility_eqs) # Check main attributes exist - assert hasattr(nstL, "X") + assert hasattr(nstL, "X_data") assert hasattr(nstL, "F") assert hasattr(nstL, "y") assert hasattr(nstL, "alternatives") diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 22a4d9e2e..7b7e1ec4a 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -27,9 +27,13 @@ import xarray as xr from rich.table import Table +from pymc_marketing.hsgp_kwargs import HSGPKwargs from pymc_marketing.model_builder import ( DifferentModelError, ModelBuilder, + ModelIO, + RegressionModelBuilder, + _handle_deprecate_pred_argument, create_sample_kwargs, ) @@ -50,7 +54,7 @@ def toy_y(toy_X): @pytest.fixture(scope="module") -def fitted_model_instance(toy_X, toy_y, mock_pymc_sample): +def fitted_regression_model_instance(toy_X, toy_y, mock_pymc_sample): sampler_config = { 
"draws": 100, "tune": 100, @@ -62,7 +66,7 @@ def fitted_model_instance(toy_X, toy_y, mock_pymc_sample): "b": {"loc": 0, "scale": 10}, "obs_error": 2, } - model = ModelBuilderTest( + model = RegressionModelBuilderTest( model_config=model_config, sampler_config=sampler_config, test_parameter="test_parameter", @@ -78,21 +82,57 @@ def fitted_model_instance(toy_X, toy_y, mock_pymc_sample): @pytest.fixture(scope="module") -def not_fitted_model_instance(): +def not_fitted_regression_model_instance(): sampler_config = {"draws": 100, "tune": 100, "chains": 2, "target_accept": 0.95} model_config = { "a": {"loc": 0, "scale": 10, "dims": ("numbers",)}, "b": {"loc": 0, "scale": 10}, "obs_error": 2, } - return ModelBuilderTest( + return RegressionModelBuilderTest( model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter", ) -class ModelBuilderTest(ModelBuilder): +@pytest.fixture(scope="module") +def toy_data(toy_X, toy_y): + """Create a combined dataset for DataRegressionModelBuilderTest.""" + data = toy_X.copy() + data["output"] = toy_y + return data + + +@pytest.fixture(scope="module") +def fitted_base_model_instance(toy_data, mock_pymc_sample): + sampler_config = { + "draws": 100, + "tune": 100, + "chains": 2, + "target_accept": 0.95, + } + model_config = { + "mu_loc": 0, + "mu_scale": 1, + "sigma_scale": 1, + } + model = ModelBuilderTest( + model_config=model_config, + sampler_config=sampler_config, + test_parameter="test_parameter", + ) + model.fit( + chains=1, + draws=100, + tune=100, + ) + return model + + +class RegressionModelBuilderTest(RegressionModelBuilder): + """Test class for RegressionModelBuilder with X and y data arguments.""" + def __init__(self, model_config=None, sampler_config=None, test_parameter=None): self.test_parameter = test_parameter super().__init__(model_config=model_config, sampler_config=sampler_config) @@ -103,8 +143,6 @@ def __init__(self, model_config=None, sampler_config=None, test_parameter=None): def build_model(self, X: pd.DataFrame, y: pd.Series): coords = {"numbers": np.arange(len(X))} - y = y if isinstance(y, np.ndarray) else y.values - with pm.Model(coords=coords) as self.model: x = pm.Data("x", X["input"].values) y_data = pm.Data("y_data", y) @@ -168,78 +206,409 @@ def default_sampler_config(self) -> dict: } -def test_model_and_sampler_config(): - default = ModelBuilderTest() +class ModelBuilderTest(ModelBuilder): + """Test class for ModelBuilder base class.""" + + def __init__(self, model_config=None, sampler_config=None, test_parameter=None): + self.test_parameter = test_parameter + super().__init__(model_config=model_config, sampler_config=sampler_config) + + _model_type = "base_test_model" + version = "0.1" + + def build_model(self, **kwargs): + # This is a simple model for testing the ModelBuilder base class + with pm.Model() as self.model: + # Very simple model to avoid compilation issues + pm.Normal("test", 0, 1) + + def build_from_idata(self, idata: az.InferenceData) -> None: + self.build_model() + + def create_idata_attrs(self): + attrs = super().create_idata_attrs() + attrs["test_parameter"] = json.dumps(self.test_parameter) + return attrs + + @property + def _serializable_model_config(self): + return self.model_config + + @property + def default_model_config(self) -> dict: + return {"mu_loc": 0, "mu_scale": 1, "sigma_scale": 1} + + @property + def default_sampler_config(self) -> dict: + return { + "draws": 1_000, + "tune": 1_000, + "chains": 3, + "target_accept": 0.95, + } + + def fit(self, **kwargs): + """Override 
fit method for ModelBuilderTest.""" + if not hasattr(self, "model"): + self.build_model() + + sampler_kwargs = create_sample_kwargs( + self.sampler_config, + kwargs.get("progressbar"), + kwargs.get("random_seed"), + **kwargs, + ) + with self.model: + idata = pm.sample(**sampler_kwargs) + + if self.idata: + self.idata = self.idata.copy() + self.idata.extend(idata, join="right") + else: + self.idata = idata + + self.set_idata_attrs(self.idata) + return self.idata + + +@pytest.mark.parametrize( + "model_class,expected_type,test_config", + [ + (RegressionModelBuilderTest, "test_model", {"obs_error": 3}), + (ModelBuilderTest, "base_test_model", {"obs_error": 3}), + (RegressionModelBuilderTest, "test_model", {"mu_loc": 5}), + ], +) +def test_model_configuration(model_class, expected_type, test_config): + """Test model and sampler configuration for all model types.""" + default = model_class() assert default.model_config == default.default_model_config assert default.sampler_config == default.default_sampler_config + assert default._model_type == expected_type - nondefault = ModelBuilderTest( - model_config={"obs_error": 3}, sampler_config={"draws": 42} - ) + nondefault = model_class(model_config=test_config, sampler_config={"draws": 42}) assert nondefault.model_config != nondefault.default_model_config assert nondefault.sampler_config != nondefault.default_sampler_config - assert nondefault.model_config == default.model_config | {"obs_error": 3} + assert nondefault.model_config == default.model_config | test_config assert nondefault.sampler_config == default.sampler_config | {"draws": 42} -def test_save_input_params(fitted_model_instance): - assert fitted_model_instance.idata.attrs["test_parameter"] == '"test_parameter"' +@pytest.mark.parametrize( + "test_case,model_class,method,expected_error,args", + [ + ( + "save_without_fit", + RegressionModelBuilderTest, + "save", + "The model hasn't been fit yet", + ["test"], + ), + ( + "fit_result_error", + RegressionModelBuilderTest, + "fit_result", + "The model hasn't been fit yet", + [], + ), + ( + "graphviz_before_build", + RegressionModelBuilderTest, + "graphviz", + "The model hasn't been built yet", + [], + ), + ( + "table_before_build", + RegressionModelBuilderTest, + "table", + "The model hasn't been built yet", + [], + ), + ], +) +def test_error_handling(test_case, model_class, method, expected_error, args): + """Test various error conditions.""" + model = model_class() + with pytest.raises(RuntimeError, match=expected_error): + getattr(model, method)(*args) + + +def test_model_io_comprehensive(): + """Comprehensive test of ModelIO mixin functionality.""" + # Test with different model types + regression_model = RegressionModelBuilderTest(test_parameter="test_parameter") + base_model = ModelBuilderTest(test_parameter="test_parameter") + + # Test that all have unique IDs + ids = [regression_model.id, base_model.id] + assert len(set(ids)) == 2 + + # Test that all have proper model types and versions + assert regression_model._model_type == "test_model" + assert base_model._model_type == "base_test_model" + assert regression_model.version == "0.1" + assert base_model.version == "0.1" + + # Test attrs creation + attrs = regression_model.create_idata_attrs() + required_keys = {"id", "model_type", "version", "sampler_config", "model_config"} + assert all(key in attrs for key in required_keys) + assert attrs["model_type"] == "test_model" + assert attrs["version"] == "0.1" + assert attrs["test_parameter"] == '"test_parameter"' + + # Test set_idata_attrs + 
with pm.Model() as simple_model: + pm.Normal("test", 0, 1) + + fake_idata = pm.sample_prior_predictive( + draws=10, model=simple_model, random_seed=1234 + ) + fake_idata.add_groups(dict(posterior=fake_idata.prior)) + + result_idata = regression_model.set_idata_attrs(fake_idata) + assert result_idata.attrs["id"] == regression_model.id + assert result_idata.attrs["model_type"] == regression_model._model_type + assert result_idata.attrs["version"] == regression_model.version + + # Test error when no idata provided + with pytest.raises(RuntimeError, match="No idata provided to set attrs on"): + regression_model.set_idata_attrs(None) -def test_has_pymc_marketing_version(fitted_model_instance): - assert "pymc_marketing_version" in fitted_model_instance.posterior.attrs +@pytest.mark.parametrize( + "method_name,deprecated_arg,additional_kwargs", + [ + ("sample_posterior_predictive", "X_pred", {}), + ("predict", "X_pred", {}), + ("sample_prior_predictive", "X_pred", {}), + ( + "sample_prior_predictive", + "y_pred", + {"X": pd.DataFrame({"input": [1, 2, 3]})}, + ), + ], +) +def test_deprecation_warnings( + fitted_regression_model_instance, + toy_X, + toy_y, + method_name, + deprecated_arg, + additional_kwargs, +): + """Test deprecation warnings for various methods.""" + # Clear any existing data that might interfere + if "posterior_predictive" in fitted_regression_model_instance.idata: + del fitted_regression_model_instance.idata.posterior_predictive + if "prior" in fitted_regression_model_instance.idata: + del fitted_regression_model_instance.idata.prior + if "prior_predictive" in fitted_regression_model_instance.idata: + del fitted_regression_model_instance.idata.prior_predictive + + with pytest.warns(DeprecationWarning, match=f"{deprecated_arg} is deprecated"): + method = getattr(fitted_regression_model_instance, method_name) + if deprecated_arg == "y_pred": + method(**additional_kwargs, **{deprecated_arg: toy_y}) + else: + method(**additional_kwargs, **{deprecated_arg: toy_X}) + + +def test_data_validation_comprehensive(): + """Comprehensive test of data validation in RegressionModelBuilder.""" + model = RegressionModelBuilderTest() + + # Test _validate_data method + X = np.array([[1, 2], [3, 4]]) + y = np.array([1, 2]) + + # Test with X and y + X_valid, y_valid = model._validate_data(X, y) + assert isinstance(X_valid, np.ndarray) + assert isinstance(y_valid, np.ndarray) + + # Test with only X + X_valid_only = model._validate_data(X) + assert isinstance(X_valid_only, np.ndarray) + + # Test with pandas DataFrame and Series + X_df = pd.DataFrame(X, columns=["a", "b"]) + y_series = pd.Series(y) + X_valid_df, y_valid_series = model._validate_data(X_df, y_series) + assert isinstance(X_valid_df, np.ndarray) + assert isinstance(y_valid_series, np.ndarray) + + # Test output variable conflict + X_with_output = pd.DataFrame({"input": [1, 2, 3]}) + X_with_output["output"] = pd.Series([1, 2, 3]) + + with pytest.raises(ValueError, match="X includes a column named 'output'"): + model.fit(X_with_output, pd.Series([1, 2, 3])) + + +def test_graphviz_and_requires_model(): + """Test graphviz functionality and requires_model decorator.""" + model = RegressionModelBuilderTest() + + # Test that graphviz and table fail before model is built + with pytest.raises(RuntimeError, match="The model hasn't been built yet"): + model.graphviz() + with pytest.raises(RuntimeError, match="The model hasn't been built yet"): + model.table() -def test_save_load(fitted_model_instance): - rng = np.random.default_rng(42) + # Test that 
they work after model is built
+    model.build_model(pd.DataFrame({"input": [1, 2, 3]}), pd.Series([1, 2, 3]))
+    assert isinstance(model.graphviz(), graphviz.graphs.Digraph)
+    assert isinstance(model.table(), Table)
+
+
+def test_model_config_formatting_comprehensive():
+    """Comprehensive test of model config formatting."""
+    model = RegressionModelBuilderTest()
+
+    # Test with empty config
+    empty_config = {}
+    formatted = model._model_config_formatting(empty_config)
+    assert formatted == {}
+
+    # Test with nested dicts but no lists
+    simple_config = {"a": {"b": "c"}}
+    formatted = model._model_config_formatting(simple_config)
+    assert formatted == simple_config
+
+    # Test with mixed types (original test)
+    model_config = {
+        "a": {
+            "loc": [0, 0],
+            "scale": 10,
+            "dims": [
+                "x",
+            ],
+        },
+    }
+    converted_model_config = model._model_config_formatting(model_config)
+    np.testing.assert_equal(converted_model_config["a"]["dims"], ("x",))
+    np.testing.assert_equal(converted_model_config["a"]["loc"], np.array([0, 0]))
+
+    # Test with mixed types (edge cases)
+    mixed_config = {"a": {"dims": ["x", "y"], "loc": [1, 2], "scale": 10}}
+    formatted = model._model_config_formatting(mixed_config)
+    assert formatted["a"]["dims"] == ("x", "y")
+    assert isinstance(formatted["a"]["loc"], np.ndarray)
+    assert formatted["a"]["scale"] == 10
+
+
+def test_idata_accessors_comprehensive():
+    """Comprehensive test of idata accessor properties."""
+    model = RegressionModelBuilderTest()
+
+    # Test that accessors fail when no idata is available
+    with pytest.raises(RuntimeError, match="The model hasn't been fit yet"):
+        model.posterior
+
+    with pytest.raises(RuntimeError, match="The model hasn't been sampled yet"):
+        model.prior
+
+    with pytest.raises(RuntimeError, match="The model hasn't been sampled yet"):
+        model.prior_predictive
+
+    with pytest.raises(RuntimeError, match="The model hasn't been fit yet"):
+        model.posterior_predictive
+
+    with pytest.raises(
+        RuntimeError, match="Call the 'sample_posterior_predictive' method"
+    ):
+        model.predictions
+
+    # Test fit_result accessor
+    with pytest.raises(RuntimeError, match="The model hasn't been fit yet"):
+        model.fit_result
+
+
+def test_handle_deprecate_pred_argument():
+    """Test the _handle_deprecate_pred_argument utility function."""
+    kwargs = {}
+
+    # Test normal case
+    result = _handle_deprecate_pred_argument("test_value", "test", kwargs)
+    assert result == "test_value"
+
+    # Test deprecated argument
+    kwargs = {"test_pred": "deprecated_value"}
+    with pytest.warns(DeprecationWarning, match="test_pred is deprecated"):
+        result = _handle_deprecate_pred_argument(None, "test", kwargs)
+    assert result == "deprecated_value"
+    assert "test_pred" not in kwargs  # Should be removed
+
+    # Test both arguments provided
+    kwargs = {"test_pred": "deprecated_value"}
+    with pytest.raises(ValueError, match="Both test and test_pred cannot be provided"):
+        _handle_deprecate_pred_argument("test_value", "test", kwargs)
+
+    # Test none allowed (without deprecated argument)
+    kwargs = {}
+    result = _handle_deprecate_pred_argument(None, "test", kwargs, none_allowed=True)
+    assert result is None
+
+    # Test none not allowed
+    with pytest.raises(ValueError, match="Please provide test"):
+        _handle_deprecate_pred_argument(None, "test", kwargs, none_allowed=False)
+
+
+def test_save_input_params(fitted_regression_model_instance):
+    assert (
+        fitted_regression_model_instance.idata.attrs["test_parameter"]
+        == '"test_parameter"'
+    )
+
+
+def test_has_pymc_marketing_version(fitted_regression_model_instance):
+    assert "pymc_marketing_version" in fitted_regression_model_instance.posterior.attrs
+
+
+def test_base_model_save_load(fitted_base_model_instance):
+    """Test save/load functionality for BaseRegressionModelBuilderTest."""
     temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
-    fitted_model_instance.save(temp.name)
-    test_builder2 = ModelBuilderTest.load(temp.name)
+    fitted_base_model_instance.save(temp.name)

-    assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
-    assert fitted_model_instance.id == test_builder2.id
-    assert fitted_model_instance.model_config == test_builder2.model_config
-    assert fitted_model_instance.sampler_config == test_builder2.sampler_config
+    test_builder2 = ModelBuilderTest.load(temp.name)

-    x_pred = rng.uniform(low=0, high=1, size=100)
-    prediction_data = pd.DataFrame({"input": x_pred})
-    pred1 = fitted_model_instance.predict(prediction_data)
-    pred2 = test_builder2.predict(prediction_data)
-    assert pred1.shape == pred2.shape
+    assert fitted_base_model_instance.idata.groups() == test_builder2.idata.groups()
+    assert fitted_base_model_instance.id == test_builder2.id
+    assert fitted_base_model_instance.model_config == test_builder2.model_config
+    assert fitted_base_model_instance.sampler_config == test_builder2.sampler_config

     temp.close()


-def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder:
+def test_initial_build_and_fit(
+    fitted_regression_model_instance, check_idata=True
+) -> RegressionModelBuilder:
     if check_idata:
-        assert fitted_model_instance.idata is not None
-        assert "posterior" in fitted_model_instance.idata.groups()
-
+        assert fitted_regression_model_instance.idata is not None
+        assert "posterior" in fitted_regression_model_instance.idata.groups()


-def test_save_without_fit_raises_runtime_error():
-    model_builder = ModelBuilderTest()
-    match = "The model hasn't been fit yet"
-    with pytest.raises(RuntimeError, match=match):
-        model_builder.save("saved_model")
-
-def test_save_with_kwargs(fitted_model_instance):
+def test_save_with_kwargs(fitted_regression_model_instance):
     """Test that kwargs are properly passed to to_netcdf"""
     import unittest.mock as mock

-    with mock.patch.object(fitted_model_instance.idata, "to_netcdf") as mock_to_netcdf:
+    with mock.patch.object(
+        fitted_regression_model_instance.idata, "to_netcdf"
+    ) as mock_to_netcdf:
         temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)

         # Test with kwargs supported by InferenceData.to_netcdf()
         kwargs = {"engine": "netcdf4", "groups": ["posterior", "log_likelihood"]}
-        fitted_model_instance.save(temp.name, **kwargs)
+        fitted_regression_model_instance.save(temp.name, **kwargs)

         # Verify to_netcdf was called with the correct arguments
         mock_to_netcdf.assert_called_once_with(temp.name, **kwargs)

     temp.close()


-def test_save_with_kwargs_integration(fitted_model_instance):
+def test_save_with_kwargs_integration(fitted_regression_model_instance):
     """Test save function with actual kwargs (integration test)"""
     temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
@@ -248,7 +617,7 @@ def test_save_with_kwargs_integration(fitted_model_instance):

     try:
         # Test with specific groups - this tests that kwargs are passed through
-        fitted_model_instance.save(temp_path, groups=["posterior"])
+        fitted_regression_model_instance.save(temp_path, groups=["posterior"])

         # Verify file was created successfully
         assert os.path.exists(temp_path)
@@ -267,7 +636,7 @@ def test_save_with_kwargs_integration(fitted_model_instance):
         os.unlink(temp_path)


-def test_save_kwargs_backward_compatibility(fitted_model_instance):
+def test_save_kwargs_backward_compatibility(fitted_regression_model_instance):
     """Test that save function still works without kwargs (backward compatibility)"""
     temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
     temp_path = temp.name
@@ -275,11 +644,11 @@ def test_save_kwargs_backward_compatibility(fitted_model_instance):

     try:
         # Test without any kwargs (original behavior)
-        fitted_model_instance.save(temp_path)
+        fitted_regression_model_instance.save(temp_path)

         # Verify file was created and can be loaded
         assert os.path.exists(temp_path)
-        loaded_model = ModelBuilderTest.load(temp_path)
+        loaded_model = RegressionModelBuilderTest.load(temp_path)
         assert loaded_model.idata is not None
         assert "posterior" in loaded_model.idata.groups()

@@ -291,7 +660,7 @@ def test_empty_sampler_config_fit(toy_X, toy_y, mock_pymc_sample):
     sampler_config = {}
-    model_builder = ModelBuilderTest(sampler_config=sampler_config)
+    model_builder = RegressionModelBuilderTest(sampler_config=sampler_config)
     model_builder.idata = model_builder.fit(
         X=toy_X, y=toy_y, chains=1, draws=100, tune=100
     )
@@ -299,50 +668,33 @@ def test_empty_sampler_config_fit(toy_X, toy_y, mock_pymc_sample):
     assert "posterior" in model_builder.idata.groups()


-def test_fit(fitted_model_instance):
+def test_fit(fitted_regression_model_instance):
     rng = np.random.default_rng(42)
-    assert fitted_model_instance.idata is not None
-    assert "posterior" in fitted_model_instance.idata.groups()
-    assert fitted_model_instance.idata.posterior.sizes["draw"] == 100
+    assert fitted_regression_model_instance.idata is not None
+    assert "posterior" in fitted_regression_model_instance.idata.groups()
+    assert fitted_regression_model_instance.idata.posterior.sizes["draw"] == 100

     prediction_data = pd.DataFrame({"input": rng.uniform(low=0, high=1, size=100)})
-    fitted_model_instance.predict(prediction_data)
-    post_pred = fitted_model_instance.sample_posterior_predictive(
+    fitted_regression_model_instance.predict(prediction_data)
+    post_pred = fitted_regression_model_instance.sample_posterior_predictive(
         prediction_data, extend_idata=True, combined=True
     )

     assert (
-        post_pred[fitted_model_instance.output_var].shape[0]
+        post_pred[fitted_regression_model_instance.output_var].shape[0]
         == prediction_data.input.shape[0]
     )


 def test_fit_no_t(toy_X, mock_pymc_sample):
-    model_builder = ModelBuilderTest()
+    model_builder = RegressionModelBuilderTest()
     model_builder.idata = model_builder.fit(X=toy_X, chains=1, draws=100, tune=100)

     assert model_builder.model is not None
     assert model_builder.idata is not None
     assert "posterior" in model_builder.idata.groups()


-def test_fit_dup_Y(toy_X, toy_y):
-    toy_X = pd.concat((toy_X, toy_y), axis=1)
-    model_builder = ModelBuilderTest()
-
-    with pytest.raises(
-        ValueError,
-        match="X includes a column named 'output', which conflicts with the target variable.",
-    ):
-        model_builder.fit(X=toy_X, chains=1, draws=100, tune=100)
-
-
-def test_fit_result_error():
-    model = ModelBuilderTest()
-    with pytest.raises(RuntimeError, match="The model hasn't been fit yet"):
-        model.fit_result
-
-
 def test_set_fit_result(toy_X, toy_y):
-    model = ModelBuilderTest()
+    model = RegressionModelBuilderTest()
     model.build_model(X=toy_X, y=toy_y)
     model.idata = None
     fake_fit = pm.sample_prior_predictive(draws=50, model=model.model, random_seed=1234)
@@ -358,50 +710,36 @@
     sys.platform == "win32",
     reason="Permissions for temp files not granted on windows CI.",
 )
-def test_predict(fitted_model_instance):
+def test_predict(fitted_regression_model_instance):
     rng = np.random.default_rng(42)
     x_pred = rng.uniform(low=0, high=1, size=100)
     prediction_data = pd.DataFrame({"input": x_pred})
-    pred = fitted_model_instance.predict(prediction_data)
+    pred = fitted_regression_model_instance.predict(prediction_data)
     # Perform elementwise comparison using numpy
     assert isinstance(pred, np.ndarray)
     assert len(pred) > 0


 @pytest.mark.parametrize("combined", [True, False])
-def test_sample_posterior_predictive(fitted_model_instance, combined):
+def test_sample_posterior_predictive(fitted_regression_model_instance, combined):
     rng = np.random.default_rng(42)
     n_pred = 100
     x_pred = rng.uniform(low=0, high=1, size=n_pred)
     prediction_data = pd.DataFrame({"input": x_pred})

-    pred = fitted_model_instance.sample_posterior_predictive(
+    pred = fitted_regression_model_instance.sample_posterior_predictive(
         prediction_data, combined=combined, extend_idata=True
     )

-    chains = fitted_model_instance.idata.posterior.sizes["chain"]
-    draws = fitted_model_instance.idata.posterior.sizes["draw"]
+    chains = fitted_regression_model_instance.idata.posterior.sizes["chain"]
+    draws = fitted_regression_model_instance.idata.posterior.sizes["draw"]

     expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
-    assert pred[fitted_model_instance.output_var].shape == expected_shape
-    assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)
-
-
-def test_model_config_formatting():
-    model_config = {
-        "a": {
-            "loc": [0, 0],
-            "scale": 10,
-            "dims": [
-                "x",
-            ],
-        },
-    }
-    model_builder = ModelBuilderTest()
-    converted_model_config = model_builder._model_config_formatting(model_config)
-    np.testing.assert_equal(converted_model_config["a"]["dims"], ("x",))
-    np.testing.assert_equal(converted_model_config["a"]["loc"], np.array([0, 0]))
+    assert pred[fitted_regression_model_instance.output_var].shape == expected_shape
+    assert np.issubdtype(
+        pred[fitted_regression_model_instance.output_var].dtype, np.floating
+    )


 def test_id():
-    model_builder = ModelBuilderTest()
+    model_builder = RegressionModelBuilderTest()
     expected_id = hashlib.sha256(
         str(model_builder.model_config.values()).encode()
         + model_builder.version.encode()
@@ -413,11 +751,11 @@ def test_id():

 @pytest.mark.parametrize("name", ["prior_predictive", "posterior_predictive"])
 def test_sample_xxx_predictive_keeps_second(
-    fitted_model_instance, toy_X, name: str
+    fitted_regression_model_instance, toy_X, name: str
 ) -> None:
     rng = np.random.default_rng(42)
     method_name = f"sample_{name}"
-    method = getattr(fitted_model_instance, method_name)
+    method = getattr(fitted_regression_model_instance, method_name)

     X_pred = toy_X

@@ -433,25 +771,25 @@ def test_sample_xxx_predictive_keeps_second(
     with pytest.raises(AssertionError):
         xr.testing.assert_allclose(first_sample, second_sample)

-    sample = getattr(fitted_model_instance.idata, name)
+    sample = getattr(fitted_regression_model_instance.idata, name)
     xr.testing.assert_allclose(sample, second_sample)


-def test_prediction_kwarg(fitted_model_instance, toy_X):
-    result = fitted_model_instance.sample_posterior_predictive(
+def test_prediction_kwarg(fitted_regression_model_instance, toy_X):
+    result = fitted_regression_model_instance.sample_posterior_predictive(
         toy_X,
         extend_idata=True,
         predictions=True,
     )
-    assert "predictions" in fitted_model_instance.idata
-    assert "predictions_constant_data" in fitted_model_instance.idata
+    assert "predictions" in fitted_regression_model_instance.idata
+    assert "predictions_constant_data" in fitted_regression_model_instance.idata
     assert isinstance(result, xr.Dataset)


 @pytest.fixture(scope="module")
-def model_with_prior_predictive(toy_X) -> ModelBuilderTest:
-    model = ModelBuilderTest()
+def model_with_prior_predictive(toy_X) -> RegressionModelBuilderTest:
+    model = RegressionModelBuilderTest()
     model.sample_prior_predictive(toy_X)
     return model
@@ -482,7 +820,7 @@ def test_fit_after_prior_keeps_prior(


 def test_second_fit(toy_X, toy_y, mock_pymc_sample):
-    model = ModelBuilderTest()
+    model = RegressionModelBuilderTest()
     model.fit(X=toy_X, y=toy_y, chains=1, draws=100, tune=100)

     assert "posterior" in model.idata
@@ -495,7 +833,7 @@ def test_second_fit(toy_X, toy_y, mock_pymc_sample):
     assert id_before != id_after


-class InsufficientModel(ModelBuilder):
+class InsufficientModel(RegressionModelBuilder):
     def __init__(
         self, model_config=None, sampler_config=None, new_parameter=None
     ) -> None:
@@ -531,6 +869,13 @@ def _generate_and_preprocess_model_data(
     def _serializable_model_config(self) -> dict[str, int | float | dict]:
         return {}

+    def _data_setter(
+        self,
+        X,
+        y,
+    ) -> None:
+        pass
+

 def test_insufficient_attrs() -> None:
     model = InsufficientModel()
@@ -542,6 +887,17 @@ def test_insufficient_attrs() -> None:
         model.sample_prior_predictive(X=X_pred)


+def test_abstract_methods():
+    """Test that abstract methods are properly enforced."""
+    # Test that we can't instantiate ModelBuilder directly
+    with pytest.raises(TypeError):
+        ModelBuilder(data=None)
+
+    # Test that we can't instantiate RegressionModelBuilder directly
+    with pytest.raises(TypeError):
+        RegressionModelBuilder()
+
+
 def test_incorrect_set_idata_attrs_override() -> None:
     class IncorrectSetAttrs(InsufficientModel):
         def create_idata_attrs(self) -> dict:
@@ -634,7 +990,7 @@ def test_fit_random_seed_reproducibility(toy_X, toy_y, create_random_seed) -> No
         "draws": 10,
         "tune": 5,
     }
-    model = ModelBuilderTest(sampler_config=sampler_config)
+    model = RegressionModelBuilderTest(sampler_config=sampler_config)

     idata = model.fit(toy_X, toy_y, random_seed=create_random_seed())
     idata2 = model.fit(toy_X, toy_y, random_seed=create_random_seed())
@@ -653,7 +1009,7 @@ def test_fit_sampler_config_seed_reproducibility(toy_X, toy_y) -> None:
         "tune": 5,
         "random_seed": 42,
     }
-    model = ModelBuilderTest(sampler_config=sampler_config)
+    model = RegressionModelBuilderTest(sampler_config=sampler_config)

     idata = model.fit(toy_X, toy_y)
     idata2 = model.fit(toy_X, toy_y)
@@ -668,7 +1024,7 @@ def test_fit_sampler_config_with_rng_fails(toy_X, toy_y, mock_pymc_sample) -> No
         "tune": 5,
         "random_seed": np.random.default_rng(42),
     }
-    model = ModelBuilderTest(sampler_config=sampler_config)
+    model = RegressionModelBuilderTest(sampler_config=sampler_config)

     match = "Object of type Generator is not JSON serializable"
     with pytest.raises(TypeError, match=match):
@@ -676,7 +1032,7 @@ def test_fit_sampler_config_with_rng_fails(toy_X, toy_y, mock_pymc_sample) -> No


 def test_unmatched_index(toy_X, toy_y) -> None:
-    model = ModelBuilderTest()
+    model = RegressionModelBuilderTest()
     toy_X = toy_X.copy()
     toy_X.index = toy_X.index + 1
     match = "Index of X and y must match"
@@ -684,58 +1040,45 @@ def test_unmatched_index(toy_X, toy_y) -> None:
         model.fit(toy_X, toy_y)


-def test_graphviz(toy_X, toy_y):
-    """Test pymc.graphviz utility on model before and after being built"""
-    model = ModelBuilderTest()
+@pytest.fixture(scope="module")
+def stale_idata(fitted_regression_model_instance) -> az.InferenceData:
+    idata = fitted_regression_model_instance.idata.copy()
+    idata.attrs["version"] = "0.0.1"

-    with pytest.raises(
-        AttributeError, match="'ModelBuilderTest' object has no attribute 'model'"
-    ):
-        model.graphviz()
+    return idata

-    model.build_model(X=toy_X, y=toy_y)
-    assert isinstance(model.graphviz(), graphviz.graphs.Digraph)
+
+@pytest.fixture(scope="module")
+def different_configuration_idata(fitted_regression_model_instance) -> az.InferenceData:
+    idata = fitted_regression_model_instance.idata.copy()
+
+    model_config = json.loads(idata.attrs["model_config"])
+    model_config["a"] = {"loc": 1, "scale": 15, "dims": ("numbers",)}
+    idata.attrs["model_config"] = json.dumps(model_config)
+
+    return idata


 @pytest.mark.parametrize(
-    "method_name",
+    "fixture_name, match",
     [
-        "sample_posterior_predictive",
-        "predict",
+        pytest.param(
+            "stale_idata",
+            re.escape("The model version (0.0.1)"),
+            id="different version",
+        ),
+        pytest.param(
+            "different_configuration_idata", "The model id", id="different id"
+        ),
     ],
 )
-def test_X_pred_posterior_deprecation(
-    method_name,
-    fitted_model_instance,
-    toy_X,
-) -> None:
-    if "posterior_predictive" in fitted_model_instance.idata:
-        del fitted_model_instance.idata.posterior_predictive
-
-    with pytest.warns(DeprecationWarning, match="X_pred is deprecated"):
-        method = getattr(fitted_model_instance, method_name)
-        method(X_pred=toy_X)
-
-    assert isinstance(fitted_model_instance.posterior_predictive, xr.Dataset)
-
-
-def test_X_pred_prior_deprecation(fitted_model_instance, toy_X, toy_y) -> None:
-    if "prior" in fitted_model_instance.idata:
-        del fitted_model_instance.idata.prior
-    if "prior_predictive" in fitted_model_instance.idata:
-        del fitted_model_instance.idata.prior_predictive
-
-    with pytest.warns(DeprecationWarning, match="X_pred is deprecated"):
-        fitted_model_instance.sample_prior_predictive(X_pred=toy_X)
-
-    with pytest.warns(DeprecationWarning, match="y_pred is deprecated"):
-        fitted_model_instance.sample_prior_predictive(toy_X, y_pred=toy_y)
-
-    assert isinstance(fitted_model_instance.prior, xr.Dataset)
-    assert isinstance(fitted_model_instance.prior_predictive, xr.Dataset)
+def test_load_from_idata_errors(request, fixture_name, match) -> None:
+    idata = request.getfixturevalue(fixture_name)
+    with pytest.raises(DifferentModelError, match=match):
+        RegressionModelBuilderTest.load_from_idata(idata, check=True)


-class XarrayModel(ModelBuilder):
+class XarrayModel(RegressionModelBuilder):
     """Multivariate Regression model."""

     def build_model(self, X, y, **kwargs):
@@ -828,48 +1171,109 @@ def test_xarray_model_builder(X_is_array, xarray_X, xarray_y, mock_pymc_sample)
     )


-@pytest.fixture(scope="module")
-def stale_idata(fitted_model_instance) -> az.InferenceData:
-    idata = fitted_model_instance.idata.copy()
-    idata.attrs["version"] = "0.0.1"
+def test_check_X_y_and_check_array_fallback(monkeypatch):
+    """Test fallback functions for check_X_y and check_array."""
+    import importlib
+    import sys

-    return idata
+    # Remove sklearn from sys.modules to force fallback
+    sys.modules["sklearn"] = None
+    sys.modules["sklearn.utils"] = None
+    sys.modules["sklearn.utils.validation"] = None
+    import pymc_marketing.model_builder as mb
+    importlib.reload(mb)

-@pytest.fixture(scope="module")
-def different_configuration_idata(fitted_model_instance) -> az.InferenceData:
-    idata = fitted_model_instance.idata.copy()

-    model_config = json.loads(idata.attrs["model_config"])
-    model_config["a"] = {"loc": 1, "scale": 15, "dims": ("numbers",)}
-    idata.attrs["model_config"] = json.dumps(model_config)
+    X = np.array([[1, 2], [3, 4]])
+    y = np.array([1, 2])
+    X2, y2 = mb.check_X_y(X, y)
+    assert np.array_equal(X, X2)
+    assert np.array_equal(y, y2)
+    X3 = mb.check_array(X)
+    assert np.array_equal(X, X3)

-    return idata
+def test_create_idata_attrs_default_to_dict_and_hsgp_kwargs():
+    """Test default function in create_idata_attrs."""

-@pytest.mark.parametrize(
-    "fixture_name, match",
-    [
-        pytest.param(
-            "stale_idata",
-            re.escape("The model version (0.0.1)"),
-            id="different version",
-        ),
-        pytest.param(
-            "different_configuration_idata", "The model id", id="different id"
-        ),
-    ],
-)
-def test_load_from_idata_errors(request, fixture_name, match) -> None:
-    idata = request.getfixturevalue(fixture_name)
-    with pytest.raises(DifferentModelError, match=match):
-        ModelBuilderTest.load_from_idata(idata)
+    class DummyModel(ModelIO):
+        _model_type = "dummy"
+        version = "0.1"
+        sampler_config = {}
+        model_config = {}
+        @property
+        def _serializable_model_config(self):
+            return self.model_config

-def test_table() -> None:
-    model = ModelBuilderTest()
-    match = "The model hasn't been built yet"
-    with pytest.raises(RuntimeError, match=match):
-        model.table()
+        def build_from_idata(self, idata):
+            pass
+
+    class ObjWithToDict:
+        def to_dict(self):
+            return {"foo": "bar"}
+
+    m = DummyModel()
+    m.model_config = {"obj": ObjWithToDict()}
+    attrs = m.create_idata_attrs()
+    import json
+    assert json.loads(attrs["model_config"])["obj"] == {"foo": "bar"}
+
+    m.model_config = {"hsgp": HSGPKwargs(input_dim=1, L=1.0, m=1)}
+    attrs = m.create_idata_attrs()
+    assert "hsgp" in json.loads(attrs["model_config"])
+
+
+def test_load_from_idata_check_false(fitted_regression_model_instance):
+    """Covers line 503: if not check: return model."""
+    idata = fitted_regression_model_instance.idata
+    model = RegressionModelBuilderTest.load_from_idata(idata, check=False)
+    assert isinstance(model, RegressionModelBuilderTest)
+
+
+def test_fit_result_setter_else_branch():
+    """Covers line 707: else branch in fit_result setter."""
+    model = RegressionModelBuilderTest()
+    # Create idata with no 'posterior'
+    import arviz as az
+
+    idata = az.from_dict(prior={"a": np.ones((1, 1, 1))})
+    model.idata = idata
+    model.fit_result = idata
+    assert hasattr(model.idata, "posterior")
+
+
+def test_predict_keyerror_output_var_missing():
+    """Covers line 1009: KeyError in predict if output_var missing."""
+    model = RegressionModelBuilderTest()
     model.build_model(pd.DataFrame({"input": [1, 2, 3]}), pd.Series([1, 2, 3]))
-    assert isinstance(model.table(), Table)
+    # Patch sample_posterior_predictive to return missing output_var
+    model.sample_posterior_predictive = lambda *a, **k: {"not_output": np.ones(3)}
+    with pytest.raises(KeyError):
+        model.predict(pd.DataFrame({"input": [1, 2, 3]}))
+
+
+def test_predict_proba_calls_predict_posterior(monkeypatch):
+    """Covers line 1137: predict_proba calls predict_posterior."""
+    model = RegressionModelBuilderTest()
+    called = {}
+
+    def fake_predict_posterior(*a, **k):
+        called["yes"] = True
+        return "ok"
+
+    model.predict_posterior = fake_predict_posterior
+    result = model.predict_proba(np.array([[1, 2, 3]]))
+    assert called["yes"]
+    assert result == "ok"
+
+
+def test_predict_posterior_keyerror_output_var_missing():
+    """Test KeyError in predict_posterior if output_var missing."""
+    model = RegressionModelBuilderTest()
+    model.build_model(pd.DataFrame({"input": [1, 2, 3]}), pd.Series([1, 2, 3]))
+    # Patch sample_posterior_predictive to return missing output_var
+    model.sample_posterior_predictive = lambda *a, **k: {"not_output": np.ones(3)}
+    with pytest.raises(KeyError):
+        model.predict_posterior(pd.DataFrame({"input": [1, 2, 3]}))