Skip to content

Commit 4631ba1

Browse files
authored
Get marginal_effects hdi estimates (#1885)
* Improve marginal_effects calculation
* Update sensitivity analysis notebook
* Improve MockMMM.idata
1 parent be685bd commit 4631ba1

File tree

3 files changed

+2445
-1092
lines changed

3 files changed

+2445
-1092
lines changed

docs/source/notebooks/mmm/mmm_sensitivity_analysis.ipynb

Lines changed: 2413 additions & 1073 deletions
Large diffs are not rendered by default.

pymc_marketing/mmm/sensitivity_analysis.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def run_sweep(
8080
predictions = []
8181
for sweep_value in self.sweep_values:
8282
X_new = self.create_intervention(sweep_value)
83-
counterfac = self.mmm.predict(X_new, extend_idata=False, progressbar=False)
83+
counterfac = self.mmm.sample_posterior_predictive(
84+
X_new, extend_idata=False, combined=False, progressbar=False
85+
)
8486
uplift = counterfac - actual
8587
predictions.append(uplift)
8688

@@ -92,12 +94,13 @@ def run_sweep(
9294

9395
marginal_effects = self.compute_marginal_effects(results, self.sweep_values)
9496

95-
results = xr.Dataset(
96-
{
97-
"y": results,
98-
"marginal_effects": marginal_effects,
99-
}
100-
)
97+
results = xr.merge(
98+
[
99+
results,
100+
marginal_effects.rename({"y": "marginal_effects"}),
101+
]
102+
).transpose(..., "sweep")
103+
101104
# Add metadata to the results
102105
results.attrs["sweep_type"] = self.sweep_type
103106
results.attrs["var_names"] = self.var_names
@@ -129,9 +132,5 @@ def create_intervention(self, sweep_value: float) -> pd.DataFrame:
129132
def compute_marginal_effects(results, sweep_values) -> xr.DataArray:
130133
"""Compute marginal effects via finite differences from the sweep results."""
131134
marginal_effects = results.differentiate(coord="sweep")
132-
marginal_effects = xr.DataArray(
133-
marginal_effects,
134-
dims=results.dims,
135-
coords=results.coords,
136-
)
135+
137136
return marginal_effects

tests/mmm/test_sensitivity_analysis.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,21 @@ def __init__(self):
4545
n_chains, n_draws, n_dates = 2, 10, 3
4646

4747
# This is what the sensitivity analysis expects to find
48-
posterior_predictive_data = xr.DataArray(
49-
np.random.normal(size=(n_chains, n_draws, n_dates)),
50-
dims=["chain", "draw", "date"],
48+
posterior_predictive_data = xr.Dataset(
49+
dict(
50+
y=(
51+
["chain", "draw", "date"],
52+
np.random.normal(size=(n_chains, n_draws, n_dates)),
53+
),
54+
),
5155
coords={
5256
"chain": np.arange(n_chains),
5357
"draw": np.arange(n_draws),
5458
"date": dates, # Use the same dates as the DataFrame
5559
},
5660
)
5761

58-
self.posterior_predictive = {"y": posterior_predictive_data}
62+
self.posterior_predictive = posterior_predictive_data
5963

6064
def __getitem__(self, key):
6165
if key == "posterior_predictive":
@@ -81,16 +85,26 @@ def predict(self, X_new, extend_idata=False, progressbar=False):
8185
size=(n_chains, n_draws, n_dates),
8286
)
8387

84-
return xr.DataArray(
85-
data,
86-
dims=["chain", "draw", "date"],
88+
return xr.Dataset(
89+
dict(
90+
y=(
91+
["chain", "draw", "date"],
92+
data,
93+
) # Use the same dimensions as expected
94+
),
8795
coords={
8896
"chain": np.arange(n_chains),
8997
"draw": np.arange(n_draws),
9098
"date": X_new.index, # Use the DataFrame index as dates
9199
},
92100
)
93101

102+
def sample_posterior_predictive(
103+
self, X_new, extend_idata=False, combined=False, progressbar=False
104+
):
105+
# Mock implementation of sample_posterior_predictive
106+
return self.predict(X_new, extend_idata, progressbar)
107+
94108
return MockMMM()
95109

96110

0 commit comments

Comments (0)