Skip to content

Commit 3d29fef

Browse files
committed
fix bug with SyntheticControl.get_plot_data_bayesian
1 parent 876c154 commit 3d29fef

File tree

1 file changed

+45
-40
lines changed

1 file changed

+45
-40
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def __init__(
135135
if isinstance(self.model, PyMCModel):
136136
COORDS = {
137137
# key must stay as "coeffs" unless we can find a way to auto identify
138-
# the predictor dimension name
138+
# the predictor dimension name. "coeffs" is assumed by
139+
# PyMCModel.print_coefficients for example.
139140
"coeffs": self.control_units,
140141
"treated_units": self.treated_units,
141142
"obs_ind": np.arange(self.datapre.shape[0]),
@@ -423,45 +424,49 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
423424
:param hdi_prob:
424425
Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
425426
"""
426-
if isinstance(self.model, PyMCModel):
427-
hdi_pct = int(round(hdi_prob * 100))
427+
if not isinstance(self.model, PyMCModel):
428+
raise ValueError("Unsupported model type")
428429

429-
pred_lower_col = f"pred_hdi_lower_{hdi_pct}"
430-
pred_upper_col = f"pred_hdi_upper_{hdi_pct}"
431-
impact_lower_col = f"impact_hdi_lower_{hdi_pct}"
432-
impact_upper_col = f"impact_hdi_upper_{hdi_pct}"
430+
hdi_pct = int(round(hdi_prob * 100))
433431

434-
pre_data = self.datapre.copy()
435-
post_data = self.datapost.copy()
432+
pred_lower_col = f"pred_hdi_lower_{hdi_pct}"
433+
pred_upper_col = f"pred_hdi_upper_{hdi_pct}"
434+
impact_lower_col = f"impact_hdi_lower_{hdi_pct}"
435+
impact_upper_col = f"impact_hdi_upper_{hdi_pct}"
436436

437-
pre_data["prediction"] = (
438-
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
439-
.mean("sample")
440-
.values
441-
)
442-
post_data["prediction"] = (
443-
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
444-
.mean("sample")
445-
.values
446-
)
447-
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
448-
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
449-
).set_index(pre_data.index)
450-
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
451-
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
452-
).set_index(post_data.index)
453-
454-
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
455-
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
456-
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
457-
self.pre_impact, hdi_prob=hdi_prob
458-
).set_index(pre_data.index)
459-
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
460-
self.post_impact, hdi_prob=hdi_prob
461-
).set_index(post_data.index)
462-
463-
self.plot_data = pd.concat([pre_data, post_data])
464-
465-
return self.plot_data
466-
else:
467-
raise ValueError("Unsupported model type")
437+
pre_data = self.datapre.copy()
438+
post_data = self.datapost.copy()
439+
440+
pre_data["prediction"] = (
441+
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
442+
.mean("sample")
443+
.values
444+
)
445+
post_data["prediction"] = (
446+
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
447+
.mean("sample")
448+
.values
449+
)
450+
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
451+
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
452+
).set_index(pre_data.index)
453+
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
454+
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
455+
).set_index(post_data.index)
456+
457+
pre_data["impact"] = (
458+
self.pre_impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values
459+
)
460+
post_data["impact"] = (
461+
self.post_impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values
462+
)
463+
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
464+
self.pre_impact, hdi_prob=hdi_prob
465+
).set_index(pre_data.index)
466+
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
467+
self.post_impact, hdi_prob=hdi_prob
468+
).set_index(post_data.index)
469+
470+
self.plot_data = pd.concat([pre_data, post_data])
471+
472+
return self.plot_data

0 commit comments

Comments
 (0)