Skip to content

Commit 44d3870

Browse files
committed
Added dynamic naming for HDI columns in _get_plot_data_bayesian, updated the tests accordingly, and updated the tests' docstrings
1 parent 0edca77 commit 44d3870

File tree

3 files changed

+89
-22
lines changed

3 files changed

+89
-22
lines changed

causalpy/experiments/prepostfit.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,13 @@ def _get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
308308
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309309
"""
310310
if isinstance(self.model, PyMCModel):
311+
hdi_pct = int(round(hdi_prob * 100))
312+
313+
pred_lower_col = f"pred_hdi_lower_{hdi_pct}"
314+
pred_upper_col = f"pred_hdi_upper_{hdi_pct}"
315+
impact_lower_col = f"impact_hdi_lower_{hdi_pct}"
316+
impact_upper_col = f"impact_hdi_upper_{hdi_pct}"
317+
311318
pre_data = self.datapre.copy()
312319
post_data = self.datapost.copy()
313320

@@ -321,19 +328,19 @@ def _get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
321328
.mean("sample")
322329
.values
323330
)
324-
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(
331+
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
325332
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
326333
).set_index(pre_data.index)
327-
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(
334+
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
328335
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
329336
).set_index(post_data.index)
330337

331338
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
332339
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
333-
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(
340+
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
334341
self.pre_impact, hdi_prob=hdi_prob
335342
).set_index(pre_data.index)
336-
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(
343+
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
337344
self.post_impact, hdi_prob=hdi_prob
338345
).set_index(post_data.index)
339346

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def test_its():
353353
2. causalpy.InterruptedTimeSeries returns correct type
354354
3. the correct number of MCMC chains exists in the posterior inference data
355355
4. the correct number of MCMC draws exists in the posterior inference data
356+
5. the method get_plot_data returns a DataFrame with expected columns
356357
"""
357358
df = (
358359
cp.load_data("its")
@@ -378,9 +379,21 @@ def test_its():
378379
isinstance(item, plt.Axes) for item in ax
379380
), "ax must be a numpy.ndarray of plt.Axes"
380381
plot_data = result.get_plot_data()
381-
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
382-
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
383-
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
382+
assert isinstance(plot_data, pd.DataFrame), (
383+
"The returned object is not a pandas DataFrame"
384+
)
385+
expected_columns = [
386+
"prediction",
387+
"pred_hdi_lower_94",
388+
"pred_hdi_upper_94",
389+
"impact",
390+
"impact_hdi_lower_94",
391+
"impact_hdi_upper_94",
392+
]
393+
assert set(expected_columns).issubset(set(plot_data.columns)), (
394+
f"DataFrame is missing expected columns {expected_columns}"
395+
)
396+
384397

385398
@pytest.mark.integration
386399
def test_its_covid():
@@ -392,6 +405,7 @@ def test_its_covid():
392405
2. causalpy.InterruptedTimeSeries returns correct type
393406
3. the correct number of MCMC chains exists in the posterior inference data
394407
4. the correct number of MCMC draws exists in the posterior inference data
408+
5. the method get_plot_data returns a DataFrame with expected columns
395409
"""
396410

397411
df = (
@@ -418,9 +432,20 @@ def test_its_covid():
418432
isinstance(item, plt.Axes) for item in ax
419433
), "ax must be a numpy.ndarray of plt.Axes"
420434
plot_data = result.get_plot_data()
421-
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
422-
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
423-
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
435+
assert isinstance(plot_data, pd.DataFrame), (
436+
"The returned object is not a pandas DataFrame"
437+
)
438+
expected_columns = [
439+
"prediction",
440+
"pred_hdi_lower_94",
441+
"pred_hdi_upper_94",
442+
"impact",
443+
"impact_hdi_lower_94",
444+
"impact_hdi_upper_94",
445+
]
446+
assert set(expected_columns).issubset(set(plot_data.columns)), (
447+
f"DataFrame is missing expected columns {expected_columns}"
448+
)
424449

425450

426451
@pytest.mark.integration
@@ -433,6 +458,7 @@ def test_sc():
433458
2. causalpy.SyntheticControl returns correct type
434459
3. the correct number of MCMC chains exists in the posterior inference data
435460
4. the correct number of MCMC draws exists in the posterior inference data
461+
5. the method get_plot_data returns a DataFrame with expected columns
436462
"""
437463

438464
df = cp.load_data("sc")
@@ -463,9 +489,21 @@ def test_sc():
463489
isinstance(item, plt.Axes) for item in ax
464490
), "ax must be a numpy.ndarray of plt.Axes"
465491
plot_data = result.get_plot_data()
466-
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
467-
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
468-
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
492+
assert isinstance(plot_data, pd.DataFrame), (
493+
"The returned object is not a pandas DataFrame"
494+
)
495+
expected_columns = [
496+
"prediction",
497+
"pred_hdi_lower_94",
498+
"pred_hdi_upper_94",
499+
"impact",
500+
"impact_hdi_lower_94",
501+
"impact_hdi_upper_94",
502+
]
503+
assert set(expected_columns).issubset(set(plot_data.columns)), (
504+
f"DataFrame is missing expected columns {expected_columns}"
505+
)
506+
469507

470508
@pytest.mark.integration
471509
def test_sc_brexit():
@@ -477,6 +515,7 @@ def test_sc_brexit():
477515
2. causalpy.SyntheticControl returns correct type
478516
3. the correct number of MCMC chains exists in the posterior inference data
479517
4. the correct number of MCMC draws exists in the posterior inference data
518+
5. the method get_plot_data returns a DataFrame with expected columns
480519
"""
481520

482521
df = (
@@ -512,9 +551,20 @@ def test_sc_brexit():
512551
isinstance(item, plt.Axes) for item in ax
513552
), "ax must be a numpy.ndarray of plt.Axes"
514553
plot_data = result.get_plot_data()
515-
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
516-
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
517-
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
554+
assert isinstance(plot_data, pd.DataFrame), (
555+
"The returned object is not a pandas DataFrame"
556+
)
557+
expected_columns = [
558+
"prediction",
559+
"pred_hdi_lower_94",
560+
"pred_hdi_upper_94",
561+
"impact",
562+
"impact_hdi_lower_94",
563+
"impact_hdi_upper_94",
564+
]
565+
assert set(expected_columns).issubset(set(plot_data.columns)), (
566+
f"DataFrame is missing expected columns {expected_columns}"
567+
)
518568

519569

520570
@pytest.mark.integration

causalpy/tests/test_integration_skl_examples.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_its():
8888
Loads data and checks:
8989
1. data is a dataframe
9090
2. skl_experiments.InterruptedTimeSeries returns correct type
91+
3. the method get_plot_data returns a DataFrame with expected columns
9192
"""
9293

9394
df = (
@@ -112,9 +113,13 @@ def test_its():
112113
isinstance(item, plt.Axes) for item in ax
113114
), "ax must be a numpy.ndarray of plt.Axes"
114115
plot_data = result.get_plot_data()
115-
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
116-
expected_columns = ['prediction', 'impact']
117-
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
116+
assert isinstance(plot_data, pd.DataFrame), (
117+
"The returned object is not a pandas DataFrame"
118+
)
119+
expected_columns = ["prediction", "impact"]
120+
assert set(expected_columns).issubset(set(plot_data.columns)), (
121+
f"DataFrame is missing expected columns {expected_columns}"
122+
)
118123

119124

120125
@pytest.mark.integration
@@ -125,6 +130,7 @@ def test_sc():
125130
Loads data and checks:
126131
1. data is a dataframe
127132
2. skl_experiments.SyntheticControl returns correct type
133+
3. the method get_plot_data returns a DataFrame with expected columns
128134
"""
129135
df = cp.load_data("sc")
130136
treatment_time = 70
@@ -152,9 +158,13 @@ def test_sc():
152158
isinstance(item, plt.Axes) for item in ax
153159
), "ax must be a numpy.ndarray of plt.Axes"
154160
plot_data = result.get_plot_data()
155-
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
156-
expected_columns = ['prediction', 'impact']
157-
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
161+
assert isinstance(plot_data, pd.DataFrame), (
162+
"The returned object is not a pandas DataFrame"
163+
)
164+
expected_columns = ["prediction", "impact"]
165+
assert set(expected_columns).issubset(set(plot_data.columns)), (
166+
f"DataFrame is missing expected columns {expected_columns}"
167+
)
158168

159169

160170
@pytest.mark.integration

0 commit comments

Comments
 (0)