diff --git a/pymc_marketing/mlflow.py b/pymc_marketing/mlflow.py index 6ef0d6d8..098ab93a 100644 --- a/pymc_marketing/mlflow.py +++ b/pymc_marketing/mlflow.py @@ -166,6 +166,8 @@ msg = "This module requires mlflow. Install using `pip install mlflow`" raise ImportError(msg) +import mlflow.artifacts +from mlflow.pyfunc.model import PythonModel from mlflow.utils.autologging_utils import autologging_integration from pymc_marketing.clv.models.basic import CLVModel @@ -897,8 +899,47 @@ def log_mmm( mlflow.register_model(model_uri, registered_model_name) -def load_mmm( +class CLVWrapper(PythonModel): + """Wrapper class for logging with MLflow.""" + + def __init__(self, model: CLVModel, method: str): + self.model = model + self.method = method + + def predict( + self, + context: Any, + model_input, + params: dict[str, Any] | None = None, + ) -> Any: + """Perform predictions or sampling using the specified prediction method.""" + return getattr(self.model, self.method)(model_input, **params) + + +def log_clv( + model: CLVModel, + method: str, + artifact_path: str = "model", + registered_model_name: str | None = None, +) -> None: + """Log a PyMC-Marketing CLV model as a native MLflow model for the current run.""" + mlflow_clv = CLVWrapper(model=model, method=method) + + mlflow.pyfunc.log_model( + artifact_path=artifact_path, + python_model=mlflow_clv, + ) + + run_id = mlflow.active_run().info.run_id + model_uri = f"runs:/{run_id}/{artifact_path}" + + if registered_model_name: + mlflow.register_model(model_uri, registered_model_name) + + +def _load_model( run_id: str, + cls, full_model: bool = False, keep_idata: bool = False, artifact_path: str = "model", @@ -928,7 +969,6 @@ def load_mmm( model : mlflow.pyfunc.PyFuncModel | MMM The loaded MLflow PyFuncModel or MMM model. - Examples -------- .. code-block:: python @@ -950,7 +990,7 @@ def load_mmm( run_id=run_id, artifact_path="idata.nc", dst_path=dst_path ) - model = MMM.load(idata_path) + model = cls.load(idata_path) if not keep_idata: _force_load_idata_groups(model.idata) @@ -968,6 +1008,73 @@ def load_mmm( return model +def load_mmm( + run_id: str, + full_model: bool = False, + keep_idata: bool = False, + artifact_path: str = "model", + dst_path: str | None = None, +) -> mlflow.pyfunc.PyFuncModel | MMM: + """ + Load a PyMC-Marketing MMM model from MLflow. + + Can either load the full model including the InferenceData, or just the lighter PyFuncModel version. + + Parameters + ---------- + run_id : str + The MLflow run ID from which to load the model. + full_model : bool, default=True + If True, load the full MMM model including the InferenceData. + keep_idata : bool, default=False + If True, keep the downloaded InferenceData saved locally. + artifact_path : str, default="model" + The artifact path within the run where the model is stored. + dst_path : str | None, default=None + The local destination path where the InferenceData will be downloaded. + If None, defaults to "idata_{run_id}" to avoid conflicts when loading multiple models. + + Returns + ------- + model : mlflow.pyfunc.PyFuncModel | MMM + The loaded MLflow PyFuncModel or MMM model. + + + Examples + -------- + .. code-block:: python + + # Load model using run_id + model = load_mmm(run_id="your_run_id", full_model=True, keep_idata=True) + """ + return _load_model( + run_id=run_id, + full_model=full_model, + keep_idata=keep_idata, + artifact_path=artifact_path, + dst_path=dst_path, + cls=MMM, + ) + + +def load_clv( + run_id: str, + full_model: bool = False, + keep_idata: bool = False, + artifact_path: str = "model", + dst_path: str | None = None, +) -> mlflow.pyfunc.PyFuncModel | CLVModel: + """Load a PyMC-Marketing CLV model from MLflow.""" + return _load_model( + run_id=run_id, + full_model=full_model, + keep_idata=keep_idata, + artifact_path=artifact_path, + dst_path=dst_path, + cls=CLVModel, + ) + + def log_versions() -> None: """Log the versions of PyMC-Marketing, PyMC, and ArviZ to MLflow.""" mlflow.log_param("pymc_marketing_version", __version__)