Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ dmypy.json

# Gallery images
docs/source/gallery/images/
docs/gettext/
5 changes: 5 additions & 0 deletions docs/source/gallery/gallery.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ Welcome to the PyMC-Marketing example gallery! This gallery provides visual navi
:img-top: ../gallery/images/mmm_allocation_assessment.png
:link: ../notebooks/mmm/mmm_allocation_assessment.html
:::

:::{grid-item-card} Multi-Objective Optimization
:img-top: ../gallery/images/mmm_allocation_assessment.png
:link: ../notebooks/mmm/mmm_multi_objective_optimization.html
:::
::::

### Lift Test Calibration
Expand Down
7,719 changes: 7,719 additions & 0 deletions docs/source/notebooks/mmm/mmm_multi_objective_optimization.ipynb

Large diffs are not rendered by default.

105 changes: 72 additions & 33 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,30 +141,7 @@


# 2) Create a minimal wrapper satisfying OptimizerCompatibleModelWrapper
class SimpleWrapper:
def __init__(self, base_model, idata, channels):
# required attributes
self._base_model = base_model
self.idata = idata
self.channel_columns = list(channels) # used if bounds is a dict
self._channel_scales = 1.0 # scalar or array broadcastable to channel dims
self.adstock = type("Adstock", (), {"l_max": 0})() # no carryover

def _set_predictors_for_optimization(self, num_periods: int) -> pm.Model:
coords = {"date": np.arange(num_periods), "channel": self.channel_columns}
# clone model
m = clone_model(self._base_model)

# Set the channel_data for optimization
pm.set_data(
{"channel_data": np.zeros((num_periods, len(self.channel_columns)))},
model=m,
coords=coords,
)
return m


wrapper = SimpleWrapper(base_model=train_model, idata=idata, channels=channels)
wrapper = CustomModelWrapper(base_model=train_model, idata=idata, channels=channels)

# 3) Optimize N future periods with optional bounds and/or masks
optimizer = BudgetOptimizer(model=wrapper, num_periods=8)
Expand Down Expand Up @@ -214,11 +191,13 @@ def _set_predictors_for_optimization(self, num_periods: int) -> pm.Model:
from typing import Any, ClassVar, Protocol, cast, runtime_checkable

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from arviz import InferenceData
from pydantic import BaseModel, ConfigDict, Field, InstanceOf
from pydantic import BaseModel, ConfigDict, Field, InstanceOf, PrivateAttr
from pymc import Model, do
from pymc.model.fgraph import clone_model
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.pytensorf import rewrite_pregrad
from pytensor import function
Expand Down Expand Up @@ -699,15 +678,28 @@ def __init__(self, **data):
self._budget_shape = tuple(len(coord) for coord in self._budget_coords.values())

# 4. Ensure that we only optmize over non-zero channels
# Only perform non-zero channel detection for MMM instances.
# For OptimizerCompatibleModelWrapper, default to optimizing all channels unless a mask is provided.
is_wrapper = (
False if "channel_contribution" in self.mmm_model.idata.posterior else True
)

if self.budgets_to_optimize is None:
# If no mask is provided, we optimize all channels
self.budgets_to_optimize = (
self.mmm_model.idata.posterior.channel_contribution.mean(
("chain", "draw", "date")
).astype(bool)
)
else:
# If a mask is provided, ensure it has the correct shape
if is_wrapper:
# Wrapper path: default to all True over budget dims
ones = np.ones(self._budget_shape, dtype=bool)
self.budgets_to_optimize = xr.DataArray(
ones, coords=self._budget_coords, dims=self._budget_dims
)
else:
# If no mask is provided, optimize all non-zero channels in the model
self.budgets_to_optimize = (
self.mmm_model.idata.posterior.channel_contribution.mean(
("chain", "draw", "date")
).astype(bool)
)
elif not is_wrapper:
# If a mask is provided for MMM instances, ensure it has the correct shape
expected_mask = self.mmm_model.idata.posterior.channel_contribution.mean(
("chain", "draw", "date")
).astype(bool)
Expand Down Expand Up @@ -1215,3 +1207,50 @@ def track_progress(xk):

else:
raise MinimizeException(f"Optimization failed: {result.message}")


class CustomModelWrapper(BaseModel):
"""Wrapper for the BudgetOptimizer to handle custom PyMC models."""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

base_model: Model = Field(
...,
description="Underlying PyMC model to be cloned for optimization.",
)
idata: InferenceData
channel_columns: list[str] = Field(
...,
description="Channel labels used for budget optimization.",
)
adstock: Any = Field(
default_factory=lambda: type("Adstock", (), {"l_max": 0})(),
description="Default adstock placeholder with zero carryover.",
)

_channel_scales: int = PrivateAttr(default=1.0)

def __init__(
self,
base_model: Model,
idata: InferenceData,
channels: Sequence[str],
) -> None:
super().__init__(
base_model=base_model,
idata=idata,
channel_columns=list(channels),
)

def _set_predictors_for_optimization(self, num_periods: int) -> pm.Model:
coords = {"date": np.arange(num_periods), "channel": self.channel_columns}
model_clone = clone_model(self.base_model)
pm.set_data(
{"channel_data": np.zeros((num_periods, len(self.channel_columns)))},
model=model_clone,
coords=coords,
)
return model_clone


OptimizerCompatibleModelWrapper.register(CustomModelWrapper)
109 changes: 106 additions & 3 deletions pymc_marketing/pytensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""PyTensor utility functions."""

from collections import Counter

import arviz as az
import pandas as pd
import pytensor
Expand Down Expand Up @@ -54,17 +56,39 @@ def _prefix_model(f2, prefix: str, exclude_vars: set | None = None):
for dim in v_dims:
exclude_dims.add(dim.data)

# Track dims and build a mapping from base variable names to prefixed names
dims = set()
base_to_prefixed: dict[str, str] = {}
for v in f2.outputs:
# Only prefix if not in exclude_vars
if v.name not in exclude_vars:
new_name = f"{prefix}_{v.name}"
# Only prefix if not in exclude_vars and has a valid name
old_name = getattr(v, "name", None)
if old_name and (old_name not in exclude_vars):
new_name = f"{prefix}_{old_name}"
v.name = new_name
if isinstance(v.owner.op, ModelVar):
rv = v.owner.inputs[0]
rv.name = new_name
# Record base to prefixed mapping for subsequent value-var renaming
base_to_prefixed[old_name] = new_name
dims.update(extract_dims(v))

# Also collect ModelVar outputs that may not be listed among f2.outputs
# (e.g., observed RVs or deterministics created internally)
for var in list(f2.variables):
if (
(owner := getattr(var, "owner", None)) is not None
and isinstance(owner.op, ModelVar)
and isinstance(name := getattr(var, "name", None), str)
and name
and name not in exclude_vars
and name not in base_to_prefixed
and not name.startswith(prefix + "_")
):
new_name = f"{prefix}_{name}"
var.name = new_name
owner.inputs[0].name = new_name
base_to_prefixed[name] = new_name

# Don't rename dimensions that belong to excluded variables
dims_rename = {
dim: pytensor.as_symbolic(f"{prefix}_{dim.data}")
Expand All @@ -83,6 +107,35 @@ def _prefix_model(f2, prefix: str, exclude_vars: set | None = None):
new_coords[k] = v
f2._coords = new_coords # type: ignore[attr-defined]

# Also rename associated transformed/value variables to keep names unique across merged graphs.
# Example patterns include: "<base>", "<base>_log__", "<base>_logodds__", etc.
# We only attempt renames for bases we actually prefixed above.
if base_to_prefixed:
for var in list(f2.variables):
if (
isinstance(name := getattr(var, "name", None), str)
and name
and name not in exclude_vars
and (
match := next(
(
(base, prefixed)
for base, prefixed in base_to_prefixed.items()
if isinstance(base, str)
and base
and (
name == base
or name.startswith(base + "_")
or name.startswith(base + "__")
)
),
None,
)
)
):
base, prefixed = match
var.name = name.replace(base, prefixed, 1)

return f2


Expand Down Expand Up @@ -162,6 +215,56 @@ def merge_models(
return model_from_fgraph(f, mutate_fgraph=True)


def validate_unique_value_vars(model: Model) -> None:
"""Validate that a model has unique, non-null value var names and 1:1 mappings.

This checks that:
- All entries in ``model.value_vars`` have unique, non-empty names
- Keys of ``model.values_to_rvs`` (value vars) also have unique names
- ``model.rvs_to_values`` mapping is consistent (bijection by names)
"""
# Check value_vars names are unique and non-empty
value_vars = list(getattr(model, "value_vars", []))
value_var_names = [getattr(v, "name", None) for v in value_vars]
if any(n is None or n == "" for n in value_var_names):
raise ValueError("Found unnamed value variables in model.value_vars")
dup_vnames = [n for n, c in Counter(value_var_names).items() if c > 1]
if dup_vnames:
raise ValueError(f"Duplicate value variable names: {dup_vnames}")

# Check values_to_rvs keys are unique by name
v2r = getattr(model, "values_to_rvs", {})
v2r_value_vars = list(v2r.keys())
v2r_value_names = [getattr(v, "name", None) for v in v2r_value_vars]
if any(n is None or n == "" for n in v2r_value_names):
raise ValueError("Found unnamed value variables in values_to_rvs")
# Some observed/deterministic value-vars may legitimately share names across merged models
# if they were intentionally merged on (e.g., merge_on) or are non-free and identical.
# Only enforce uniqueness among value vars that correspond to free RVs.
_ = {
getattr(v2r[v], "name", None)
for v in v2r_value_vars
if v in getattr(model, "value_vars", [])
and v2r.get(v) in getattr(model, "free_RVs", [])
}
# Map back to the value-var names for those free RVs
free_value_var_names = [
getattr(model.rvs_to_values[rv], "name", None) for rv in model.free_RVs
]
dup_map_names = [n for n, c in Counter(free_value_var_names).items() if n and c > 1]
if dup_map_names:
raise ValueError("Duplicate value variable names for free RVs: {dup_map_names}")

# Check consistency of reverse mapping by names
r2v = getattr(model, "rvs_to_values", {})
# Names on the value side of both dicts should align set-wise
r2v_value_names = [getattr(v, "name", None) for v in r2v.values()]
if set(r2v_value_names) != set(v2r_value_names):
raise ValueError(
"Mismatch between values_to_rvs and rvs_to_values by value var names"
)


def extract_response_distribution(
pymc_model: Model,
idata: InferenceData,
Expand Down
52 changes: 16 additions & 36 deletions tests/mmm/test_budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import pytensor.tensor as pt
import pytest
import xarray as xr
from pymc.model.fgraph import clone_model as cm

from pymc_marketing.mmm import MMM
from pymc_marketing.mmm.budget_optimizer import (
BudgetOptimizer,
CustomModelWrapper,
MinimizeException,
optimizer_xarray_builder,
)
Expand Down Expand Up @@ -910,15 +910,10 @@ def test_budget_distribution_over_period_integration(dummy_df, dummy_idata):


def test_custom_protocol_model_budget_optimizer_works():
"""Validate the optimizer works with a custom model that follows the protocol.

This serves as an example for users wanting to plug in their own PyMC models.
Requirements implemented here:
- The model has a variable named 'channel_data' with dims ("date", "channel").
- Deterministics 'channel_contribution' ("date", "channel") and 'total_contribution' ("date").
- A wrapper object exposes: idata, channel_columns, _channel_scales, adstock.l_max, and
a method `_set_predictors_for_optimization(num_periods) -> pm.Model` that returns a PyMC model
where 'channel_data' is set for the optimization horizon.
"""Validate the optimizer works with the built-in CustomModelWrapper.

This serves as an example for users wanting to plug in their own PyMC models via
``CustomModelWrapper``, which satisfies the OptimizerCompatibleModelWrapper protocol.
"""
# 1) Build and fit a tiny custom PyMC model
rng = np.random.default_rng(0)
Expand All @@ -944,33 +939,18 @@ def test_custom_protocol_model_budget_optimizer_works():

idata = pm.sample(50, tune=50, chains=1, progressbar=False, random_seed=1)

# 2) Minimal wrapper satisfying the optimizer protocol
class SimpleWrapper:
def __init__(self, base_model: pm.Model, idata, channels):
self._base_model = base_model
self.idata = idata
self.channel_columns = list(channels)
self._channel_scales = 1.0
self.adstock = type("Adstock", (), {"l_max": 0})() # no carryover

def _set_predictors_for_optimization(self, num_periods: int) -> pm.Model:
m = cm(self._base_model)
pm.set_data(
{
"channel_data": np.zeros(
(num_periods, len(self.channel_columns)),
dtype=m["channel_data"].dtype,
)
},
coords={
"date": np.arange(num_periods),
"channel": self.channel_columns,
},
model=m,
)
return m
# 2) Wrap the model with CustomModelWrapper
wrapper = CustomModelWrapper(
base_model=train_model,
idata=idata,
channels=channels,
)

wrapper = SimpleWrapper(base_model=train_model, idata=idata, channels=channels)
# Ensure the wrapper produces correctly shaped optimization models
opt_model = wrapper._set_predictors_for_optimization(num_periods=6)
assert tuple(opt_model.named_vars_to_dims["channel_data"]) == ("date", "channel")
assert list(opt_model.coords["channel"]) == channels
assert len(opt_model.coords["date"]) == 6

# 3) Optimize budgets over a small future horizon
optimizer = BudgetOptimizer(model=wrapper, num_periods=6)
Expand Down
Loading
Loading