Skip to content

Commit e2263ea

Browse files
committed
fix diverging branch
2 parents 5521a07 + 3f813ac commit e2263ea

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

causalpy/experiments/prepostfit.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 The PyMC Labs Developers
1+
# Copyright 2025 The PyMC Labs Developers
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@
2525
from sklearn.base import RegressorMixin
2626

2727
from causalpy.custom_exceptions import BadIndexException
28-
from causalpy.plot_utils import plot_xY, get_hdi_to_df
28+
from causalpy.plot_utils import get_hdi_to_df, plot_xY
2929
from causalpy.pymc_models import PyMCModel
3030
from causalpy.utils import round_num
3131

@@ -320,13 +320,21 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
320320
.mean("sample")
321321
.values
322322
)
323-
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob).set_index(pre_data.index)
324-
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob).set_index(post_data.index)
323+
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(
324+
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
325+
).set_index(pre_data.index)
326+
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(
327+
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
328+
).set_index(post_data.index)
325329

326330
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
327331
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
328-
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob).set_index(pre_data.index)
329-
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob).set_index(post_data.index)
332+
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(
333+
self.pre_impact, hdi_prob=hdi_prob
334+
).set_index(pre_data.index)
335+
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(
336+
self.post_impact, hdi_prob=hdi_prob
337+
).set_index(post_data.index)
330338

331339
self.plot_data = pd.concat([pre_data, post_data])
332340

0 commit comments

Comments
 (0)