From d1f26ca49b9707e9e4922d047316c18a4df317f5 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 15 Jan 2025 18:18:53 +0100 Subject: [PATCH 1/2] add clv wrapper --- pymc_marketing/mlflow.py | 115 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 111 insertions(+), 4 deletions(-) diff --git a/pymc_marketing/mlflow.py b/pymc_marketing/mlflow.py index e63857594..b13f55d9e 100644 --- a/pymc_marketing/mlflow.py +++ b/pymc_marketing/mlflow.py @@ -116,6 +116,8 @@ try: import mlflow + import mlflow.artifacts + from mlflow.pyfunc.model import PythonModel except ImportError: # pragma: no cover msg = "This module requires mlflow. Install using `pip install mlflow`" raise ImportError(msg) @@ -471,7 +473,7 @@ def log_mmm_evaluation_metrics( mlflow.log_metric(f"{metric}_{stat.replace('%', '')}", value) -class MMMWrapper(mlflow.pyfunc.PythonModel): +class MMMWrapper(PythonModel): """A class to prepare a PyMC Marketing Mix Model (MMM) for logging and registering in MLflow. This class extends MLflow's PythonModel to handle prediction tasks using a PyMC-based MMM. @@ -706,8 +708,47 @@ def log_mmm( mlflow.register_model(model_uri, registered_model_name) -def load_mmm( +class CLVWrapper(PythonModel): + """Wrapper class for logging with MLflow.""" + + def __init__(self, model: CLVModel, method: str): + self.model = model + self.method = method + + def predict( + self, + context: Any, + model_input, + params: dict[str, Any] | None = None, + ) -> Any: + """Perform predictions or sampling using the specified prediction method.""" + return getattr(self.model, self.method)(model_input, **params) + + +def log_clv( + model: CLVModel, + method: str, + artifact_path: str = "model", + registered_model_name: str | None = None, +) -> None: + """Log a PyMC-Marketing CLV model as a native MLflow model for the current run.""" + mlflow_clv = CLVWrapper(model=model, method=method) + + mlflow.pyfunc.log_model( + artifact_path=artifact_path, + python_model=mlflow_clv, + ) + + run_id = mlflow.active_run().info.run_id + model_uri = f"runs:/{run_id}/{artifact_path}" + + if registered_model_name: + mlflow.register_model(model_uri, registered_model_name) + + +def _load_model( run_id: str, + cls, full_model: bool = False, keep_idata: bool = False, artifact_path: str = "model", @@ -737,7 +778,6 @@ def load_mmm( model : mlflow.pyfunc.PyFuncModel | MMM The loaded MLflow PyFuncModel or MMM model. - Examples -------- .. code-block:: python @@ -759,7 +799,7 @@ def load_mmm( run_id=run_id, artifact_path="idata.nc", dst_path=dst_path ) - model = MMM.load(idata_path) + model = cls.load(idata_path) if not keep_idata: _force_load_idata_groups(model.idata) @@ -777,6 +817,73 @@ def load_mmm( return model +def load_mmm( + run_id: str, + full_model: bool = False, + keep_idata: bool = False, + artifact_path: str = "model", + dst_path: str | None = None, +) -> mlflow.pyfunc.PyFuncModel | MMM: + """ + Load a PyMC-Marketing MMM model from MLflow. + + Can either load the full model including the InferenceData, or just the lighter PyFuncModel version. + + Parameters + ---------- + run_id : str + The MLflow run ID from which to load the model. + full_model : bool, default=True + If True, load the full MMM model including the InferenceData. + keep_idata : bool, default=False + If True, keep the downloaded InferenceData saved locally. + artifact_path : str, default="model" + The artifact path within the run where the model is stored. + dst_path : str | None, default=None + The local destination path where the InferenceData will be downloaded. + If None, defaults to "idata_{run_id}" to avoid conflicts when loading multiple models. + + Returns + ------- + model : mlflow.pyfunc.PyFuncModel | MMM + The loaded MLflow PyFuncModel or MMM model. + + + Examples + -------- + .. code-block:: python + + # Load model using run_id + model = load_mmm(run_id="your_run_id", full_model=True, keep_idata=True) + """ + return _load_model( + run_id=run_id, + full_model=full_model, + keep_idata=keep_idata, + artifact_path=artifact_path, + dst_path=dst_path, + cls=MMM, + ) + + +def load_clv( + run_id: str, + full_model: bool = False, + keep_idata: bool = False, + artifact_path: str = "model", + dst_path: str | None = None, +) -> mlflow.pyfunc.PyFuncModel | CLVModel: + """Load a PyMC-Marketing CLV model from MLflow.""" + return _load_model( + run_id=run_id, + full_model=full_model, + keep_idata=keep_idata, + artifact_path=artifact_path, + dst_path=dst_path, + cls=CLVModel, + ) + + @autologging_integration(FLAVOR_NAME) def autolog( log_sampler_info: bool = True, From 308ec373942e2e5381d107072404f42333f9d9f5 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 15 Jan 2025 18:22:02 +0100 Subject: [PATCH 2/2] change location of imports --- pymc_marketing/mlflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_marketing/mlflow.py b/pymc_marketing/mlflow.py index b13f55d9e..2a8d06cdf 100644 --- a/pymc_marketing/mlflow.py +++ b/pymc_marketing/mlflow.py @@ -116,12 +116,12 @@ try: import mlflow - import mlflow.artifacts - from mlflow.pyfunc.model import PythonModel except ImportError: # pragma: no cover msg = "This module requires mlflow. Install using `pip install mlflow`" raise ImportError(msg) +import mlflow.artifacts +from mlflow.pyfunc.model import PythonModel from mlflow.utils.autologging_utils import autologging_integration from pymc_marketing.clv.models.basic import CLVModel