Skip to content

Add CLVWrapper #1377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 110 additions & 3 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@
msg = "This module requires mlflow. Install using `pip install mlflow`"
raise ImportError(msg)

import mlflow.artifacts
from mlflow.pyfunc.model import PythonModel
from mlflow.utils.autologging_utils import autologging_integration

from pymc_marketing.clv.models.basic import CLVModel
Expand Down Expand Up @@ -897,8 +899,47 @@ def log_mmm(
mlflow.register_model(model_uri, registered_model_name)


def load_mmm(
class CLVWrapper(PythonModel):
    """Wrapper class for logging a CLV model with MLflow.

    The wrapper stores a model together with the name of one of its
    methods, and exposes that method through ``predict``.

    Parameters
    ----------
    model : CLVModel
        The CLV model whose method is called by ``predict``.
    method : str
        Name of the attribute on ``model`` that ``predict`` looks up and calls.
    """

    def __init__(self, model: CLVModel, method: str):
        self.model = model
        self.method = method

    def predict(
        self,
        context: Any,
        model_input,
        params: dict[str, Any] | None = None,
    ) -> Any:
        """Perform predictions or sampling using the specified prediction method.

        Parameters
        ----------
        context : Any
            Accepted for signature compatibility; not used by this wrapper.
        model_input
            Passed as the single positional argument to the wrapped method.
        params : dict[str, Any] | None, default=None
            Extra keyword arguments for the wrapped method. ``None`` means
            no extra keyword arguments.

        Returns
        -------
        Any
            Whatever the wrapped method returns.
        """
        # ``params`` defaults to None, and unpacking None with ``**`` raises
        # TypeError, so substitute an empty mapping in that case.
        extra_kwargs = params if params is not None else {}
        return getattr(self.model, self.method)(model_input, **extra_kwargs)


def log_clv(
    model: CLVModel,
    method: str,
    artifact_path: str = "model",
    registered_model_name: str | None = None,
) -> None:
    """Log a PyMC-Marketing CLV model as a native MLflow model for the current run.

    Parameters
    ----------
    model : CLVModel
        The CLV model to wrap and log.
    method : str
        Name of the model method that the logged wrapper calls in ``predict``.
    artifact_path : str, default="model"
        Artifact path passed to ``mlflow.pyfunc.log_model`` and used to
        build the ``runs:/`` model URI.
    registered_model_name : str | None, default=None
        If given (and non-empty), the logged model is also registered
        under this name.
    """
    mlflow.pyfunc.log_model(
        artifact_path=artifact_path,
        python_model=CLVWrapper(model=model, method=method),
    )

    # The URI is resolved from the active run before the registration check,
    # so the active run is looked up whether or not registration is requested.
    active_run_id = mlflow.active_run().info.run_id
    logged_model_uri = f"runs:/{active_run_id}/{artifact_path}"

    if not registered_model_name:
        return

    mlflow.register_model(logged_model_uri, registered_model_name)


def _load_model(
run_id: str,
cls,
full_model: bool = False,
keep_idata: bool = False,
artifact_path: str = "model",
Expand Down Expand Up @@ -928,7 +969,6 @@ def load_mmm(
model : mlflow.pyfunc.PyFuncModel | MMM
The loaded MLflow PyFuncModel or MMM model.


Examples
--------
.. code-block:: python
Expand All @@ -950,7 +990,7 @@ def load_mmm(
run_id=run_id, artifact_path="idata.nc", dst_path=dst_path
)

model = MMM.load(idata_path)
model = cls.load(idata_path)

if not keep_idata:
_force_load_idata_groups(model.idata)
Expand All @@ -968,6 +1008,73 @@ def load_mmm(
return model


def load_mmm(
    run_id: str,
    full_model: bool = False,
    keep_idata: bool = False,
    artifact_path: str = "model",
    dst_path: str | None = None,
) -> mlflow.pyfunc.PyFuncModel | MMM:
    """
    Load a PyMC-Marketing MMM model from MLflow.

    Can either load the full model including the InferenceData, or just the lighter PyFuncModel version.

    Parameters
    ----------
    run_id : str
        The MLflow run ID from which to load the model.
    full_model : bool, default=False
        If True, load the full MMM model including the InferenceData.
    keep_idata : bool, default=False
        If True, keep the downloaded InferenceData saved locally.
    artifact_path : str, default="model"
        The artifact path within the run where the model is stored.
    dst_path : str | None, default=None
        The local destination path where the InferenceData will be downloaded.
        If None, defaults to "idata_{run_id}" to avoid conflicts when loading multiple models.

    Returns
    -------
    model : mlflow.pyfunc.PyFuncModel | MMM
        The loaded MLflow PyFuncModel or MMM model.

    Examples
    --------
    .. code-block:: python

        # Load model using run_id
        model = load_mmm(run_id="your_run_id", full_model=True, keep_idata=True)
    """
    # All loading logic lives in the shared helper; this entry point only
    # fixes the model class to MMM.
    return _load_model(
        run_id=run_id,
        full_model=full_model,
        keep_idata=keep_idata,
        artifact_path=artifact_path,
        dst_path=dst_path,
        cls=MMM,
    )


def load_clv(
    run_id: str,
    full_model: bool = False,
    keep_idata: bool = False,
    artifact_path: str = "model",
    dst_path: str | None = None,
) -> mlflow.pyfunc.PyFuncModel | CLVModel:
    """Load a PyMC-Marketing CLV model from MLflow.

    Thin entry point over ``_load_model`` with the model class fixed to
    ``CLVModel``; every argument is forwarded unchanged.

    Parameters
    ----------
    run_id : str
        The MLflow run ID from which to load the model.
    full_model : bool, default=False
        Forwarded to ``_load_model`` as ``full_model``.
    keep_idata : bool, default=False
        Forwarded to ``_load_model`` as ``keep_idata``.
    artifact_path : str, default="model"
        Forwarded to ``_load_model`` as ``artifact_path``.
    dst_path : str | None, default=None
        Forwarded to ``_load_model`` as ``dst_path``.

    Returns
    -------
    model : mlflow.pyfunc.PyFuncModel | CLVModel
        Whatever ``_load_model`` returns for the given arguments.
    """
    loaded = _load_model(
        cls=CLVModel,
        run_id=run_id,
        full_model=full_model,
        keep_idata=keep_idata,
        artifact_path=artifact_path,
        dst_path=dst_path,
    )
    return loaded


def log_versions() -> None:
"""Log the versions of PyMC-Marketing, PyMC, and ArviZ to MLflow."""
mlflow.log_param("pymc_marketing_version", __version__)
Expand Down
Loading