Skip to content

Commit eac1ef3

Browse files
committed
code simplification in _ols_plot
1 parent ebddbb5 commit eac1ef3

File tree

1 file changed

+2
-16
lines changed

1 file changed

+2
-16
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -403,22 +403,8 @@ def _ols_plot(
403403
c="k",
404404
)
405405
ax[0].set(title=f"{self._get_score_title(treated_unit, round_to)}")
406-
# Shaded causal effect - handle different prediction formats
407-
try:
408-
# For OLS, predictions might be simple arrays
409-
post_pred_values = np.squeeze(self.post_pred)
410-
except (TypeError, AttributeError):
411-
# TODO: WILL THIS PATH EVERY BIT HIT?
412-
# For PyMC predictions (InferenceData)
413-
post_pred_values = (
414-
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
415-
.mean("sample")
416-
.values
417-
)
418-
if len(post_pred_values.shape) > 1:
419-
post_pred_values = post_pred_values[
420-
:, 0
421-
] # Take first treated unit for OLS
406+
# Shaded causal effect
407+
post_pred_values = np.squeeze(self.post_pred)
422408

423409
ax[0].fill_between(
424410
self.datapost.index,

0 commit comments

Comments
 (0)