Skip to content

Commit 51c75c1

Browse files
cetagostini, ricardoV94, williambdean
authored
Create wrapper to make budget optimizer compatible with multidimensional class (#1652)
* Saving progress
* Changes
* Closing remarks
* Adding plots
  Not working yet
* Updating
* Adding quick test
* Solving plot issues
* Adding tests
* Solve issue
  Ricky always helping!
  Co-Authored-By: Ricardo Vieira <[email protected]>
* Solving test issues
* Modifying example
* Adding more explanations
* Removing comment
* Requested changes
* Replacing cell.
* Other small change.

---------

Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Will Dean <[email protected]>
1 parent aa22482 commit 51c75c1

File tree

7 files changed

+2714
-1129
lines changed

7 files changed

+2714
-1129
lines changed

docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb

Lines changed: 1699 additions & 1042 deletions
Large diffs are not rendered by default.

pymc_marketing/mmm/budget_optimizer.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ class BudgetOptimizer(BaseModel):
165165
def __init__(self, **data):
166166
super().__init__(**data)
167167
# 1. Prepare model with time dimension for optimization
168-
pymc_model = self.mmm_model._set_predictors_for_optimization(self.num_periods)
168+
pymc_model = self.mmm_model._set_predictors_for_optimization(
169+
self.num_periods
170+
) # TODO: Once multidimensional class becomes the main class.
169171

170172
# 2. Shared variable for total_budget: Use annotation to avoid type checking
171173
self._total_budget: SharedVariable = shared(
@@ -270,13 +272,20 @@ def _replace_channel_data_by_optimization_variable(self, model: Model) -> Model:
270272
repeated_budgets_with_carry_over_shape.insert(
271273
date_dim_idx, num_periods + max_lag
272274
)
275+
276+
# Get the dtype from the model's channel_data to ensure type compatibility
277+
channel_data_dtype = model["channel_data"].dtype
278+
273279
repeated_budgets_with_carry_over = pt.zeros(
274-
repeated_budgets_with_carry_over_shape
280+
repeated_budgets_with_carry_over_shape,
281+
dtype=channel_data_dtype, # Use the same dtype as channel_data
275282
)
276283
set_idxs = (*((slice(None),) * date_dim_idx), slice(None, num_periods))
277284
repeated_budgets_with_carry_over = repeated_budgets_with_carry_over[
278285
set_idxs
279-
].set(repeated_budgets)
286+
].set(
287+
pt.cast(repeated_budgets, channel_data_dtype)
288+
) # Cast to ensure type compatibility
280289
repeated_budgets_with_carry_over.name = "repeated_budgets_with_carry_over"
281290

282291
# Freeze dims & data in the underlying PyMC model

pymc_marketing/mmm/multidimensional.py

Lines changed: 137 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import json
1919
import warnings
20+
from collections.abc import Sequence
2021
from copy import deepcopy
2122
from typing import Any, Literal
2223

@@ -29,9 +30,11 @@
2930
import xarray as xr
3031
from pymc.model.fgraph import clone_model as cm
3132
from pymc.util import RandomState
33+
from scipy.optimize import OptimizeResult
3234

3335
from pymc_marketing.mmm import SoftPlusHSGP
3436
from pymc_marketing.mmm.additive_effect import MuEffect, create_event_mu_effect
37+
from pymc_marketing.mmm.budget_optimizer import OptimizerCompatibleModelWrapper
3538
from pymc_marketing.mmm.components.adstock import (
3639
AdstockTransformation,
3740
adstock_from_dict,
@@ -45,6 +48,11 @@
4548
from pymc_marketing.mmm.plot import MMMPlotSuite
4649
from pymc_marketing.mmm.scaling import Scaling, VariableScaling
4750
from pymc_marketing.mmm.tvp import infer_time_index
51+
from pymc_marketing.mmm.utility import UtilityFunctionType, average_response
52+
from pymc_marketing.mmm.utils import (
53+
add_noise_to_channel_allocation,
54+
create_zero_dataset,
55+
)
4856
from pymc_marketing.model_builder import ModelBuilder, _handle_deprecate_pred_argument
4957
from pymc_marketing.model_config import parse_model_config
5058
from pymc_marketing.model_graph import deterministics_to_flat
@@ -945,14 +953,15 @@ def build_model(
945953
channel_data_.name = "channel_data_scaled"
946954
channel_data_.dims = ("date", *self.dims, "channel")
947955

948-
## Hot fix for target data meanwhile pymc allows for internal scaling `https://github.com/pymc-devs/pymc/pull/7656`
949956
target_dim_handler = create_dim_handler(("date", *self.dims))
950-
target_data_scaled = pm.Deterministic(
951-
name="target_scaled",
952-
var=_target
953-
/ target_dim_handler(_target_scale, self.scalers._target.dims),
954-
dims=("date", *self.dims),
957+
958+
target_data_scaled = _target / target_dim_handler(
959+
_target_scale, self.scalers._target.dims
955960
)
961+
target_data_scaled.name = "target_scaled"
962+
target_data_scaled.dims = ("date", *self.dims)
963+
## TODO: Find a better way to save it or access it in the pytensor graph.
964+
self.target_data_scaled = target_data_scaled
956965

957966
for mu_effect in self.mu_effects:
958967
mu_effect.create_data(self)
@@ -1417,3 +1426,125 @@ def create_sample_kwargs(
14171426
# Update with additional keyword arguments
14181427
sampler_config.update(kwargs)
14191428
return sampler_config
1429+
1430+
1431+
class MultiDimensionalBudgetOptimizerWrapper(OptimizerCompatibleModelWrapper):
1432+
"""Wrapper for the BudgetOptimizer to handle multi-dimensional model."""
1433+
1434+
def __init__(self, model: MMM, start_date: str, end_date: str):
1435+
self.model_class = model
1436+
self.start_date = start_date
1437+
self.end_date = end_date
1438+
# Compute the number of periods to allocate budget for
1439+
self.zero_data = create_zero_dataset(
1440+
model=self.model_class, start_date=start_date, end_date=end_date
1441+
)
1442+
self.num_periods = len(self.zero_data[self.model_class.date_column].unique())
1443+
# Adding missing dependencies for compatibility with BudgetOptimizer
1444+
self._channel_scales = 1.0
1445+
1446+
def __getattr__(self, name):
1447+
"""Delegate attribute access to the wrapped MMM model."""
1448+
try:
1449+
# First, try to get the attribute from the wrapper itself
1450+
return object.__getattribute__(self, name)
1451+
except AttributeError:
1452+
# If not found, delegate to the wrapped model
1453+
try:
1454+
return getattr(self.model_class, name)
1455+
except AttributeError as e:
1456+
# Raise an AttributeError if the attribute is not found in either
1457+
raise AttributeError(
1458+
f"'{type(self).__name__}' object and its wrapped 'MMM' object have no attribute '{name}'"
1459+
) from e
1460+
1461+
def _set_predictors_for_optimization(self, num_periods: int) -> pm.Model:
1462+
"""Return the respective PyMC model with any predictors set for optimization."""
1463+
# Use the model's method for transformation
1464+
dataset_xarray = self._posterior_predictive_data_transformation(
1465+
X=self.zero_data,
1466+
include_last_observations=False,
1467+
)
1468+
1469+
# Use the model's method to set data
1470+
pymc_model = self._set_xarray_data(
1471+
dataset_xarray=dataset_xarray,
1472+
clone_model=True, # Ensure we work on a clone
1473+
)
1474+
1475+
# Use the model's mu_effects and set data using the model instance
1476+
for mu_effect in self.mu_effects:
1477+
mu_effect.set_data(self, pymc_model, dataset_xarray)
1478+
1479+
return pymc_model
1480+
1481+
def optimize_budget(
1482+
self,
1483+
budget: float | int,
1484+
budget_bounds: xr.DataArray | dict[str, tuple[float, float]] | None = None,
1485+
response_variable: str = "total_media_contribution_original_scale",
1486+
utility_function: UtilityFunctionType = average_response,
1487+
constraints: Sequence[dict[str, Any]] = (),
1488+
default_constraints: bool = True,
1489+
**minimize_kwargs,
1490+
) -> tuple[xr.DataArray, OptimizeResult]:
1491+
"""Optimize the budget allocation for the model."""
1492+
from pymc_marketing.mmm.budget_optimizer import BudgetOptimizer
1493+
1494+
allocator = BudgetOptimizer(
1495+
num_periods=self.num_periods,
1496+
utility_function=utility_function,
1497+
response_variable=response_variable,
1498+
custom_constraints=constraints,
1499+
default_constraints=default_constraints,
1500+
model=self, # Pass the wrapper instance itself to the BudgetOptimizer
1501+
)
1502+
1503+
return allocator.allocate_budget(
1504+
total_budget=budget,
1505+
budget_bounds=budget_bounds,
1506+
**minimize_kwargs,
1507+
)
1508+
1509+
def sample_response_distribution(
1510+
self,
1511+
allocation_strategy: xr.DataArray,
1512+
noise_level: float = 0.001,
1513+
) -> az.InferenceData:
1514+
"""Generate synthetic dataset and sample posterior predictive based on allocation.
1515+
1516+
Parameters
1517+
----------
1518+
allocation_strategy : DataArray
1519+
The allocation strategy for the channels.
1520+
noise_level : float
1521+
The relative level of noise to add to the data allocation.
1522+
1523+
Returns
1524+
-------
1525+
az.InferenceData
1526+
The posterior predictive samples based on the synthetic dataset.
1527+
"""
1528+
data = create_zero_dataset(
1529+
model=self,
1530+
start_date=self.start_date,
1531+
end_date=self.end_date,
1532+
channel_xr=allocation_strategy.to_dataset(dim="channel"),
1533+
)
1534+
1535+
data_with_noise = add_noise_to_channel_allocation(
1536+
df=data,
1537+
channels=self.channel_columns,
1538+
rel_std=noise_level,
1539+
seed=42,
1540+
)
1541+
1542+
constant_data = allocation_strategy.to_dataset(name="allocation")
1543+
1544+
return self.sample_posterior_predictive(
1545+
X=data_with_noise,
1546+
extend_idata=False,
1547+
include_last_observations=True,
1548+
var_names=["y", "channel_contribution_original_scale"],
1549+
progressbar=False,
1550+
).merge(constant_data)

0 commit comments

Comments
 (0)