diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index 97c4ecd1..0e72c70f 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -29,7 +29,6 @@ from pymc_marketing.model_builder import ModelBuilder from pymc_marketing.model_config import ModelConfig, parse_model_config -from pymc_marketing.utils import from_netcdf class CLVModel(ModelBuilder): @@ -260,58 +259,13 @@ def _fit_approx( ) @classmethod - def load(cls, fname: str): - """Create a ModelBuilder instance from a file. - - Loads inference data for the model. - - Parameters - ---------- - fname : string - This denotes the name with path from where idata should be loaded from. - - Returns - ------- - Returns an instance of ModelBuilder. - - Raises - ------ - ValueError - If the inference data that is loaded doesn't match with the model. - - Examples - -------- - >>> class MyModel(ModelBuilder): - >>> ... - >>> name = "./mymodel.nc" - >>> imported_model = MyModel.load(name) - - """ - filepath = Path(str(fname)) - idata = from_netcdf(filepath) - return cls._build_with_idata(idata) - - @classmethod - def _build_with_idata(cls, idata: az.InferenceData): - dataset = idata.fit_data.to_dataframe() - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - ) - model = cls( - dataset, - model_config=json.loads(idata.attrs["model_config"]), # type: ignore - sampler_config=json.loads(idata.attrs["sampler_config"]), - ) - - model.idata = idata - model._rename_posterior_variables() - - model.build_model() # type: ignore - if model.id != idata.attrs["id"]: - raise ValueError(f"Inference data not compatible with {cls._model_type}") - return model + def idata_to_init_kwargs(cls, idata: az.InferenceData) -> dict: + """Create the initialization kwargs from an InferenceData object.""" + return { + "data": idata.fit_data.to_dataframe(), + "model_config": json.loads(idata.attrs["model_config"]), + "sampler_config": json.loads(idata.attrs["sampler_config"]), + } def _rename_posterior_variables(self): """Rename variables in the posterior group to remove the _prior suffix. @@ -355,7 +309,7 @@ def thin_fit_result(self, keep_every: int): self.fit_result # noqa: B018 (Raise Error if fit didn't happen yet) assert self.idata is not None # noqa: S101 new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy() - return type(self)._build_with_idata(new_idata) + return self.build_from_idata(new_idata) @property def default_sampler_config(self) -> dict: @@ -378,8 +332,3 @@ def fit_summary(self, **kwargs): return res["mean"].rename("value") else: return az.summary(self.fit_result, **kwargs) - - @property - def output_var(self): - """Output variable of the model.""" - pass diff --git a/pymc_marketing/customer_choice/mv_its.py b/pymc_marketing/customer_choice/mv_its.py index 1e98a2d3..e928fa1f 100644 --- a/pymc_marketing/customer_choice/mv_its.py +++ b/pymc_marketing/customer_choice/mv_its.py @@ -27,13 +27,13 @@ from xarray import DataArray from pymc_marketing.mmm.additive_effect import MuEffect -from pymc_marketing.model_builder import ModelBuilder, create_idata_accessor +from pymc_marketing.model_builder import RegressionModelBuilder, create_idata_accessor from pymc_marketing.model_config import parse_model_config HDI_ALPHA = 0.5 -class MVITS(ModelBuilder): +class MVITS(RegressionModelBuilder): """Multivariate Interrupted Time Series class. Class to perform a multivariate interrupted time series analysis with the diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index 1893152d..71ad8218 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -45,14 +45,14 @@ ValidateDateColumn, ValidateTargetColumn, ) -from pymc_marketing.model_builder import ModelBuilder +from pymc_marketing.model_builder import RegressionModelBuilder __all__ = ["BaseValidateMMM", "MMMModelBuilder"] from pydantic import Field, validate_call -class MMMModelBuilder(ModelBuilder): +class MMMModelBuilder(RegressionModelBuilder): """Base class for Marketing Mix Models (MMM).""" model: pm.Model diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index b212eb56..4254236a 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -164,199 +164,51 @@ class DifferentModelError(Exception): """Error raised when a model loaded is different than one saved.""" -class ModelBuilder(ABC): - """Base class for building models with PyMC-Marketing. +class ModelIO: + """Mixin to handle saving and loading of models.""" - It provides an easy-to-use API (similar to scikit-learn) for models - and help with deployment. - """ - - _model_type = "BaseClass" - version = "None" - - X: pd.DataFrame | None = None - y: pd.Series | np.ndarray | None = None - - def __init__( - self, - model_config: dict | None = None, - sampler_config: dict | None = None, - ): - """Initialize model configuration and sampler configuration for the model. - - Parameters - ---------- - model_config : Dictionary, optional - dictionary of parameters that initialise model configuration. - Class-default defined by the user default_model_config method. - sampler_config : Dictionary, optional - dictionary of parameters that initialise sampler configuration. - Class-default defined by the user default_sampler_config method. - - Examples - -------- - >>> class MyModel(ModelBuilder): - >>> ... - >>> model = MyModel(model_config, sampler_config) - - """ - if sampler_config is None: - sampler_config = {} - if model_config is None: - model_config = {} - self.sampler_config = ( - self.default_sampler_config | sampler_config - ) # Parameters for fit sampling - self.model_config = ( - self.default_model_config | model_config - ) # parameters for priors etc. - self.model: pm.Model - self.idata: az.InferenceData | None = None # idata is generated during fitting - self.is_fitted_ = False - - def _validate_data(self, X, y=None): - if y is not None: - return check_X_y( - X, y, accept_sparse=False, y_numeric=True, multi_output=False - ) - else: - return check_array(X, accept_sparse=False) - - def _data_setter( - self, - X: np.ndarray | pd.DataFrame | xr.Dataset | xr.DataArray, - y: np.ndarray | pd.Series | xr.DataArray | None = None, - ) -> None: - """Set new data in the model. - - Parameters - ---------- - X : array, shape (n_obs, n_features) - The training input samples. - y : array, shape (n_obs,) - The target values (real numbers). - - Examples - -------- - Example logic of data_setter method - - .. code-block:: python - - def _data_setter(self, X, y=None): - data = {"X": X} - if y is None: - y = np.zeros(len(X)) - data["y"] = y - - with self.model: - pm.set_data(data) - - """ - msg = "This model doesn't support setting new data, posterior_predictive, or out of sample methods." - raise NotImplementedError(msg) - - @property - @abstractmethod - def output_var(self) -> str: - """Returns the name of the output variable of the model. - - Returns - ------- - output_var : str - Name of the output variable of the model. - - """ + _model_type: str + version: str + idata: az.InferenceData | None + sampler_config: dict + model_config: dict @property - @abstractmethod - def default_model_config(self) -> dict: - """Return a class default configuration dictionary. - - For model builder if no model_config is provided on class initialization - Useful for understanding structure of required model_config to allow its customization by users + def id(self) -> str: + """Generate a unique hash value for the model. - Examples - -------- - >>> @classmethod - >>> def default_model_config(self): - >>> Return { - >>> 'a' : { - >>> 'loc': 7, - >>> 'scale' : 3 - >>> }, - >>> 'b' : { - >>> 'loc': 3, - >>> 'scale': 5 - >>> } - >>> 'obs_error': 2 - >>> } + The hash value is created using the last 16 characters of the SHA256 hash encoding, + based on the model configuration, version, and model type. Returns ------- - model_config : dict - A set of default parameters for predictor distributions that allow to save and recreate the model. - - """ - - @property - @abstractmethod - def default_sampler_config(self) -> dict: - """Return a class default sampler configuration dictionary. - - For model builder if no sampler_config is provided on class initialization - Useful for understanding structure of required sampler_config to allow its customization by users + str + A string of length 16 characters containing a unique hash of the model. Examples -------- - >>> @classmethod - >>> def default_sampler_config(self): - >>> Return { - >>> 'draws': 1_000, - >>> 'tune': 1_000, - >>> 'chains': 1, - >>> 'target_accept': 0.95, - >>> } - - Returns - ------- - sampler_config : dict - A set of default settings for used by model in fit process. + >>> model = MyModel() + >>> model.id + '0123456789abcdef' """ + hasher = hashlib.sha256() + hasher.update(str(self.model_config.values()).encode()) + hasher.update(self.version.encode()) + hasher.update(self._model_type.encode()) + return hasher.hexdigest()[:16] + @property @abstractmethod - def build_model( - self, - X: pd.DataFrame | xr.Dataset | xr.DataArray, - y: pd.Series | np.ndarray | xr.DataArray, - **kwargs, - ) -> None: - """Create an instance of `pm.Model` based on provided data and model_config. - - It attaches the model to self.model. - - Parameters - ---------- - X : pd.DataFrame - The input data that is going to be used in the model. This should be a DataFrame - containing the features (predictors) for the model. For efficiency reasons, it should - only contain the necessary data columns, not the entire available dataset, as this - will be encoded into the data used to recreate the model. - - y : Union[pd.Series, np.ndarray] - The target data for the model. This should be a Series representing the output - or dependent variable for the model. - - kwargs : dict - Additional keyword arguments that may be used for model configuration. + def _serializable_model_config(self) -> dict[str, int | float | dict]: + """Converts non-serializable values from model_config to their serializable reversable equivalent. - See Also - -------- - default_model_config : returns default model config + Data types like pandas DataFrame, Series or datetime aren't JSON serializable, + so in order to save the model they need to be formatted. Returns ------- - None + model_config: dict """ @@ -541,133 +393,524 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: "sampler_config": json.loads(attrs["sampler_config"]), } + @classmethod + def idata_to_init_kwargs(cls, idata: az.InferenceData) -> dict[str, Any]: + """Convert the model configuration and sampler configuration from the InferenceData to keyword arguments.""" + return cls.attrs_to_init_kwargs(idata.attrs) + + @abstractmethod def build_from_idata(self, idata: az.InferenceData) -> None: - """Build model from the InferenceData object. + """Build the model from the InferenceData object.""" - This is part of the :func:`load` method. See :func:`load` for more larger context. + @classmethod + def load(cls, fname: str, check: bool = True): + """Create a ModelBuilder instance from a file. - Usually a wrapper around the :func:`build_model` method unless the model - has some additional steps to be built. + Loads inference data for the model. + + This class method has a few steps: + + - Load the InferenceData from the file. + - Construct a new instance of the model using the InferenceData attrs + - Build the model from the InferenceData + - Check if the model id matches the id in the InferenceData loaded. + + Parameters + ---------- + fname : string + This denotes the name with path from where idata should be loaded from. + check : bool, optional + Whether to check if the model id matches the id in the InferenceData loaded. + Defaults to True. + + Returns + ------- + Returns an instance of ModelBuilder. + + Raises + ------ + DifferentModelError + If the inference data that is loaded doesn't match with the model. + + Examples + -------- + Load a model from a file + + .. code-block:: python + + file_name: str = "./mymodel.nc" + model = MyModel.load(file_name) + + """ + filepath = Path(str(fname)) + idata = from_netcdf(filepath) + + try: + return cls.load_from_idata(idata, check=check) + except DifferentModelError as e: + error_msg = ( + f"The file '{fname}' does not contain " + "an InferenceData of the same model " + f"or configuration as '{cls._model_type}'" + ) + raise DifferentModelError(error_msg) from e + + @classmethod + def load_from_idata(cls, idata: az.InferenceData, check: bool = True) -> "ModelIO": + """Create a ModelBuilder instance from an InferenceData object. + + This class method has a few steps: + + - Construct a new instance of the model using the InferenceData attrs + - Build the model from the InferenceData + - Check if the model id matches the id in the InferenceData loaded. Parameters ---------- idata : az.InferenceData - The InferenceData object to build the model from. + The InferenceData object to load the model from. + check : bool, optional + Whether to check if the model id matches the id in the InferenceData loaded. + Defaults to True. + + Returns + ------- + ModelBuilder + An instance of the ModelBuilder class. + + Raises + ------ + DifferentModelError + If the model id in the InferenceData does not match the model id built. """ - dataset = idata.fit_data.to_dataframe() # type: ignore - X = dataset.drop(columns=[self.output_var]) - y = dataset[self.output_var] + # needs to be converted, because json.loads was changing tuple to list + init_kwargs = cls.idata_to_init_kwargs(idata) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + model = cls(**init_kwargs) + + model.idata = idata + model.build_from_idata(idata) + + if not check: + return model + + if model.id != idata.attrs["id"]: + msg = ( + "The model id in the InferenceData does not match the model id. " + "There was no error loading the inference data, but the model may " + "be different. " + "Investigate if the model structure or configuration has changed." + ) + raise DifferentModelError(msg) + + return model + + +class BaseModelBuilder(ABC, ModelIO): + """Base class for building models with PyMC Marketing. + + It provides an easy-to-use API (similar to scikit-learn) for models + and help with deployment. + """ + + _model_type = "BaseClass" + version = "None" + + def __init__( + self, + model_config: dict | None = None, + sampler_config: dict | None = None, + ): + """Initialize model configuration and sampler configuration for the model. + + Parameters + ---------- + model_config : Dictionary, optional + dictionary of parameters that initialise model configuration. + Class-default defined by the user default_model_config method. + sampler_config : Dictionary, optional + dictionary of parameters that initialise sampler configuration. + Class-default defined by the user default_sampler_config method. + + Examples + -------- + >>> class MyModel(ModelBuilder): + >>> ... + >>> model = MyModel(model_config, sampler_config) + + """ + if sampler_config is None: + sampler_config = {} + if model_config is None: + model_config = {} + self.sampler_config = ( + self.default_sampler_config | sampler_config + ) # Parameters for fit sampling + self.model_config = ( + self.default_model_config | model_config + ) # parameters for priors etc. + self.model: pm.Model + self.idata: az.InferenceData | None = None # idata is generated during fitting + self.is_fitted_ = False + + @property + @abstractmethod + def default_model_config(self) -> dict: + """Return a class default configuration dictionary. + + For model builder if no model_config is provided on class initialization + Useful for understanding structure of required model_config to allow its customization by users + + Examples + -------- + >>> @classmethod + >>> def default_model_config(self): + >>> Return { + >>> 'a' : { + >>> 'loc': 7, + >>> 'scale' : 3 + >>> }, + >>> 'b' : { + >>> 'loc': 3, + >>> 'scale': 5 + >>> } + >>> 'obs_error': 2 + >>> } + + Returns + ------- + model_config : dict + A set of default parameters for predictor distributions that allow to save and recreate the model. + + """ + + @property + @abstractmethod + def default_sampler_config(self) -> dict: + """Return a class default sampler configuration dictionary. + + For model builder if no sampler_config is provided on class initialization + Useful for understanding structure of required sampler_config to allow its customization by users + + Examples + -------- + >>> @classmethod + >>> def default_sampler_config(self): + >>> Return { + >>> 'draws': 1_000, + >>> 'tune': 1_000, + >>> 'chains': 1, + >>> 'target_accept': 0.95, + >>> } + + Returns + ------- + sampler_config : dict + A set of default settings for used by model in fit process. + + """ + + def graphviz(self, **kwargs): + """Get the graphviz representation of the model. + + Parameters + ---------- + **kwargs + Keyword arguments for the `pm.model_to_graphviz` function + + Returns + ------- + graphviz.Digraph + + """ + return pm.model_to_graphviz(self.model, **kwargs) + + @property + def fit_result(self) -> xr.Dataset: + """Get the posterior fit_result. + + Returns + ------- + InferenceData object. + + """ + return create_idata_accessor( + "posterior", "The model hasn't been fit yet, call .fit() first" + ).__get__(self) + + @fit_result.setter + def fit_result(self, res: az.InferenceData) -> None: + """Create a setter method to overwrite the pre-existing fit_result. + + Parameters + ---------- + res : az.InferenceData + The inferencedata object to be set + + Returns + ------- + property + The property setter for the InferenceData object. + + """ + if self.idata is None: + self.idata = res + elif "posterior" in self.idata: + warnings.warn("Overriding pre-existing fit_result", stacklevel=1) + self.idata.posterior = res + else: + self.idata.posterior = res + + prior = create_idata_accessor( + "prior", + "The model hasn't been sampled yet, call .sample_prior_predictive() first", + ) + prior_predictive = create_idata_accessor( + "prior_predictive", + "The model hasn't been sampled yet, call .sample_prior_predictive() first", + ) + posterior = create_idata_accessor( + "posterior", "The model hasn't been fit yet, call .fit() first" + ) + + posterior_predictive = create_idata_accessor( + "posterior_predictive", + "The model hasn't been fit yet, call .sample_posterior_predictive() first", + ) + predictions = create_idata_accessor( + "predictions", + "Call the 'sample_posterior_predictive' method with predictions=True first.", + ) + + +class ModelBuilder(BaseModelBuilder): + """ModelBuilder that takes data at initialization.""" + + def __init__(self, data, model_config=None, sampler_config=None): + self.data = data + super().__init__(model_config, sampler_config) + + @classmethod + def idata_to_init_kwargs(cls, idata: az.InferenceData) -> dict[str, Any]: + """Convert the model configuration and sampler configuration from the InferenceData to keyword arguments.""" + kwargs = cls.attrs_to_init_kwargs(idata.attrs) + kwargs["data"] = idata.fit_data + + return kwargs + + @abstractmethod + def build_model( + self, + **kwargs, + ) -> None: + """Create an instance of `pm.Model` based on provided data and model_config. + + It attaches the model to self.model. + + Parameters + ---------- + kwargs : dict + Additional keyword arguments that may be used for model configuration. + + See Also + -------- + default_model_config : returns default model config + + Returns + ------- + None + + """ + + def create_fit_data(self) -> xr.Dataset: + """Create the fit_data group based on the input data.""" + if isinstance(self.data, pd.DataFrame): + return self.data.to_xarray() + + return self.data + + def fit( + self, + progressbar: bool | None = None, + random_seed: RandomState | None = None, + **kwargs: Any, + ) -> az.InferenceData: + """Fit a model using the data passed as a parameter. + + Sets attrs to inference data of the model. + + Parameters + ---------- + X : array-like | array, shape (n_obs, n_features) + The training input samples. If scikit-learn is available, array-like, otherwise array. + y : array-like | array, shape (n_obs,) + The target values (real numbers). If scikit-learn is available, array-like, otherwise array. + progressbar : bool, optional + Specifies whether the fit progress bar should be displayed. Defaults to True. + random_seed : Optional[RandomState] + Provides sampler with initial random seed for obtaining reproducible samples. + **kwargs : Any + Custom sampler settings can be provided in form of keyword arguments. + + Returns + ------- + self : az.InferenceData + Returns inference data of the fitted model. + + Examples + -------- + >>> model = MyModel() + >>> idata = model.fit(X,y) + Auto-assigning NUTS sampler... + Initializing NUTS using jitter+adapt_diag... + + """ + if not hasattr(self, "model"): + self.build_model() + + sampler_kwargs = create_sample_kwargs( + self.sampler_config, + progressbar, + random_seed, + **kwargs, + ) + with self.model: + idata = pm.sample(**sampler_kwargs) + + if self.idata: + self.idata = self.idata.copy() + self.idata.extend(idata, join="right") + else: + self.idata = idata + + if "fit_data" in self.idata: + del self.idata.fit_data + + fit_data = self.create_fit_data() + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + self.idata.add_groups(fit_data=fit_data) + self.set_idata_attrs(self.idata) + return self.idata # type: ignore - self.build_model(X, y) - @classmethod - def load_from_idata(cls, idata: az.InferenceData) -> "ModelBuilder": - """Create a ModelBuilder instance from an InferenceData object. +class RegressionModelBuilder(BaseModelBuilder): + """Regression based ModelBuilder class.""" - This class method has a few steps: + def _validate_data(self, X, y=None): + if y is not None: + return check_X_y( + X, y, accept_sparse=False, y_numeric=True, multi_output=False + ) + else: + return check_array(X, accept_sparse=False) - - Construct a new instance of the model using the InferenceData attrs - - Build the model from the InferenceData - - Check if the model id matches the id in the InferenceData loaded. + @abstractmethod + def _data_setter( + self, + X: np.ndarray | pd.DataFrame | xr.Dataset | xr.DataArray, + y: np.ndarray | pd.Series | xr.DataArray | None = None, + ) -> None: + """Set new data in the model. Parameters ---------- - idata : az.InferenceData - The InferenceData object to load the model from. + X : array, shape (n_obs, n_features) + The training input samples. + y : array, shape (n_obs,) + The target values (real numbers). Returns ------- - ModelBuilder - An instance of the ModelBuilder class. + None - Raises - ------ - DifferentModelError - If the model id in the InferenceData does not match the model id built. + Examples + -------- + >>> def _data_setter(self, data : pd.DataFrame): + >>> with self.model: + >>> pm.set_data({'x': X['x'].values}) + >>> try: # if y values in new data + >>> pm.set_data({'y_data': y.values}) + >>> except: # dummies otherwise + >>> pm.set_data({'y_data': np.zeros(len(data))}) """ - # needs to be converted, because json.loads was changing tuple to list - init_kwargs = cls.attrs_to_init_kwargs(idata.attrs) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=DeprecationWarning) - model = cls(**init_kwargs) + @property + @abstractmethod + def output_var(self) -> str: + """Returns the name of the output variable of the model. - model.idata = idata - model.build_from_idata(idata) - model.post_sample_model_transformation() + Returns + ------- + output_var : str + Name of the output variable of the model. - if (model_version := model.version) != ( - loaded_version := idata.attrs["version"] - ): - msg = ( - f"The model version ({loaded_version}) in the InferenceData does not " - f"match the model version ({model_version}). " - "There was no error loading the inference data, but the model structure " - "is different. " - ) - raise DifferentModelError(msg) - elif model.id != idata.attrs["id"]: - msg = ( - "The model id in the InferenceData does not match the model id. " - "There was no error loading the inference data, but the model may " - "be different. " - "Investigate if the model structure or configuration has changed." - ) - raise DifferentModelError(msg) + """ - return model + @abstractmethod + def build_model( + self, + X: pd.DataFrame | xr.Dataset | xr.DataArray, + y: pd.Series | np.ndarray | xr.DataArray, + **kwargs, + ) -> None: + """Create an instance of `pm.Model` based on provided data and model_config. - @classmethod - def load(cls, fname: str): - """Create a ModelBuilder instance from a file. + It attaches the model to self.model. - Loads inference data for the model. + Parameters + ---------- + X : pd.DataFrame | xr.Dataset | xr.DataArray + The input data that is going to be used in the model. This should be a DataFrame + containing the features (predictors) for the model. For efficiency reasons, it should + only contain the necessary data columns, not the entire available dataset, as this + will be encoded into the data used to recreate the model. - This class method has a few steps: + y : pd.Series | np.ndarray | xr.DataArray + The target data for the model. This should be a Series representing the output + or dependent variable for the model. - - Load the InferenceData from the file. - - Construct a new instance of the model using the InferenceData attrs - - Build the model from the InferenceData - - Check if the model id matches the id in the InferenceData loaded. + kwargs : dict + Additional keyword arguments that may be used for model configuration. - Parameters - ---------- - fname : string - This denotes the name with path from where idata should be loaded from. + See Also + -------- + default_model_config : returns default model config Returns ------- - Returns an instance of ModelBuilder. + None - Raises - ------ - DifferentModelError - If the inference data that is loaded doesn't match with the model. + """ - Examples - -------- - Load a model from a file + def build_from_idata(self, idata: az.InferenceData) -> None: + """Build model from the InferenceData object. - .. code-block:: python + This is part of the :func:`load` method. See :func:`load` for more larger context. - file_name: str = "./mymodel.nc" - model = MyModel.load(file_name) + Usually a wrapper around the :func:`build_model` method unless the model + has some additional steps to be built. + + Parameters + ---------- + idata : az.InferenceData + The InferenceData object to build the model from. """ - filepath = Path(str(fname)) - idata = from_netcdf(filepath) + dataset = idata.fit_data.to_dataframe() # type: ignore + X = dataset.drop(columns=[self.output_var]) + y = dataset[self.output_var] - try: - return cls.load_from_idata(idata) - except DifferentModelError as e: - error_msg = ( - f"The file '{fname}' does not contain " - "an InferenceData of the same model " - f"or configuration as '{cls._model_type}'" - ) - raise DifferentModelError(error_msg) from e + self.build_model(X, y) def create_fit_data( self, @@ -791,42 +1034,6 @@ def fit( self.set_idata_attrs(self.idata) return self.idata # type: ignore - @property - def fit_result(self) -> xr.Dataset: - """Get the posterior fit_result. - - Returns - ------- - InferenceData object. - - """ - return create_idata_accessor( - "posterior", "The model hasn't been fit yet, call .fit() first" - ).__get__(self) - - @fit_result.setter - def fit_result(self, res: az.InferenceData) -> None: - """Create a setter method to overwrite the pre-existing fit_result. - - Parameters - ---------- - res : az.InferenceData - The inferencedata object to be set - - Returns - ------- - property - The property setter for the InferenceData object. - - """ - if self.idata is None: - self.idata = res - elif "posterior" in self.idata: - warnings.warn("Overriding pre-existing fit_result", stacklevel=1) - self.idata.posterior = res - else: - self.idata.posterior = res - def predict( self, X: np.ndarray | pd.DataFrame | pd.Series | None = None, @@ -988,20 +1195,6 @@ def sample_posterior_predictive( return az.extract(post_pred, variable_name, combined=combined) - @property - @abstractmethod - def _serializable_model_config(self) -> dict[str, int | float | dict]: - """Converts non-serializable values from model_config to their serializable reversable equivalent. - - Data types like pandas DataFrame, Series or datetime aren't JSON serializable, - so in order to save the model they need to be formatted. - - Returns - ------- - model_config: dict - - """ - def predict_proba( self, X: np.ndarray | pd.DataFrame | pd.Series | None = None, @@ -1052,81 +1245,3 @@ def predict_posterior( ) return posterior_predictive_samples[self.output_var] - - @property - def id(self) -> str: - """Generate a unique hash value for the model. - - The hash value is created using the last 16 characters of the SHA256 hash encoding, - based on the model configuration, version, and model type. - - Returns - ------- - str - A string of length 16 characters containing a unique hash of the model. - - Examples - -------- - >>> model = MyModel() - >>> model.id - '0123456789abcdef' - - """ - hasher = hashlib.sha256() - hasher.update(str(self.model_config.values()).encode()) - hasher.update(self.version.encode()) - hasher.update(self._model_type.encode()) - return hasher.hexdigest()[:16] - - def graphviz(self, **kwargs): - """Get the graphviz representation of the model. - - Parameters - ---------- - **kwargs - Keyword arguments for the `pm.model_to_graphviz` function - - Returns - ------- - graphviz.Digraph - - """ - return pm.model_to_graphviz(self.model, **kwargs) - - @requires_model - def table(self, **model_table_kwargs) -> Table: - """Get the summary table of the model. - - Parameters - ---------- - **model_table_kwargs - Keyword arguments for the `model_table` function - - Returns - ------- - rich.table.Table - A rich table containing the summary of the model. - - """ - return model_table(self.model, **model_table_kwargs) - - prior = create_idata_accessor( - "prior", - "The model hasn't been sampled yet, call .sample_prior_predictive() first", - ) - prior_predictive = create_idata_accessor( - "prior_predictive", - "The model hasn't been sampled yet, call .sample_prior_predictive() first", - ) - posterior = create_idata_accessor( - "posterior", "The model hasn't been fit yet, call .fit() first" - ) - - posterior_predictive = create_idata_accessor( - "posterior_predictive", - "The model hasn't been fit yet, call .sample_posterior_predictive() first", - ) - predictions = create_idata_accessor( - "predictions", - "Call the 'sample_posterior_predictive' method with predictions=True first.", - ) diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 22a4d9e2..61e1681d 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -28,6 +28,7 @@ from rich.table import Table from pymc_marketing.model_builder import ( + RegressionModelBuilder, DifferentModelError, ModelBuilder, create_sample_kwargs, @@ -92,7 +93,7 @@ def not_fitted_model_instance(): ) -class ModelBuilderTest(ModelBuilder): +class ModelBuilderTest(RegressionModelBuilder): def __init__(self, model_config=None, sampler_config=None, test_parameter=None): self.test_parameter = test_parameter super().__init__(model_config=model_config, sampler_config=sampler_config) @@ -194,6 +195,7 @@ def test_save_load(fitted_model_instance): rng = np.random.default_rng(42) temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) fitted_model_instance.save(temp.name) + test_builder2 = ModelBuilderTest.load(temp.name) assert fitted_model_instance.idata.groups() == test_builder2.idata.groups() @@ -209,7 +211,9 @@ def test_save_load(fitted_model_instance): temp.close() -def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder: +def test_initial_build_and_fit( + fitted_model_instance, check_idata=True +) -> RegressionModelBuilder: if check_idata: assert fitted_model_instance.idata is not None assert "posterior" in fitted_model_instance.idata.groups() @@ -495,7 +499,7 @@ def test_second_fit(toy_X, toy_y, mock_pymc_sample): assert id_before != id_after -class InsufficientModel(ModelBuilder): +class InsufficientModel(RegressionModelBuilder): def __init__( self, model_config=None, sampler_config=None, new_parameter=None ) -> None: @@ -735,7 +739,7 @@ def test_X_pred_prior_deprecation(fitted_model_instance, toy_X, toy_y) -> None: assert isinstance(fitted_model_instance.prior_predictive, xr.Dataset) -class XarrayModel(ModelBuilder): +class XarrayModel(RegressionModelBuilder): """Multivariate Regression model.""" def build_model(self, X, y, **kwargs):