|
1 |
| -# Copyright 2024 The PyMC Labs Developers |
| 1 | +# Copyright 2025 The PyMC Labs Developers |
2 | 2 | #
|
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License");
|
4 | 4 | # you may not use this file except in compliance with the License.
|
|
25 | 25 | from sklearn.base import RegressorMixin
|
26 | 26 |
|
27 | 27 | from causalpy.custom_exceptions import BadIndexException
|
28 |
| -from causalpy.plot_utils import plot_xY, get_hdi_to_df |
| 28 | +from causalpy.plot_utils import get_hdi_to_df, plot_xY |
29 | 29 | from causalpy.pymc_models import PyMCModel
|
30 | 30 | from causalpy.utils import round_num
|
31 | 31 |
|
@@ -320,13 +320,21 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
|
320 | 320 | .mean("sample")
|
321 | 321 | .values
|
322 | 322 | )
|
323 |
| - pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob).set_index(pre_data.index) |
324 |
| - post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob).set_index(post_data.index) |
| 323 | + pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df( |
| 324 | + self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob |
| 325 | + ).set_index(pre_data.index) |
| 326 | + post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df( |
| 327 | + self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob |
| 328 | + ).set_index(post_data.index) |
325 | 329 |
|
326 | 330 | pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
|
327 | 331 | post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
|
328 |
| - pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob).set_index(pre_data.index) |
329 |
| - post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob).set_index(post_data.index) |
| 332 | + pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df( |
| 333 | + self.pre_impact, hdi_prob=hdi_prob |
| 334 | + ).set_index(pre_data.index) |
| 335 | + post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df( |
| 336 | + self.post_impact, hdi_prob=hdi_prob |
| 337 | + ).set_index(post_data.index) |
330 | 338 |
|
331 | 339 | self.plot_data = pd.concat([pre_data, post_data])
|
332 | 340 |
|
|
0 commit comments