Skip to content

Commit 97f0d79

Browse files
committed
added tests to get_plot_data for pymc and skl experiments
1 parent 6a6face commit 97f0d79

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,10 @@ def test_its():
377377
assert isinstance(ax, np.ndarray) and all(
378378
isinstance(item, plt.Axes) for item in ax
379379
), "ax must be a numpy.ndarray of plt.Axes"
380-
380+
plot_data = result.get_plot_data()
381+
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
382+
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
383+
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
381384

382385
@pytest.mark.integration
383386
def test_its_covid():
@@ -414,6 +417,10 @@ def test_its_covid():
414417
assert isinstance(ax, np.ndarray) and all(
415418
isinstance(item, plt.Axes) for item in ax
416419
), "ax must be a numpy.ndarray of plt.Axes"
420+
plot_data = result.get_plot_data()
421+
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
422+
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
423+
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
417424

418425

419426
@pytest.mark.integration
@@ -455,7 +462,10 @@ def test_sc():
455462
assert isinstance(ax, np.ndarray) and all(
456463
isinstance(item, plt.Axes) for item in ax
457464
), "ax must be a numpy.ndarray of plt.Axes"
458-
465+
plot_data = result.get_plot_data()
466+
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
467+
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
468+
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
459469

460470
@pytest.mark.integration
461471
def test_sc_brexit():
@@ -501,6 +511,10 @@ def test_sc_brexit():
501511
assert isinstance(ax, np.ndarray) and all(
502512
isinstance(item, plt.Axes) for item in ax
503513
), "ax must be a numpy.ndarray of plt.Axes"
514+
plot_data = result.get_plot_data()
515+
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
516+
expected_columns = ['prediction', 'pred_hdi_lower', 'pred_hdi_upper', 'impact', 'impact_hdi_lower', 'impact_hdi_upper']
517+
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
504518

505519

506520
@pytest.mark.integration

causalpy/tests/test_integration_skl_examples.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def test_its():
111111
assert isinstance(ax, np.ndarray) and all(
112112
isinstance(item, plt.Axes) for item in ax
113113
), "ax must be a numpy.ndarray of plt.Axes"
114+
plot_data = result.get_plot_data()
115+
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
116+
expected_columns = ['prediction', 'impact']
117+
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
114118

115119

116120
@pytest.mark.integration
@@ -147,6 +151,10 @@ def test_sc():
147151
assert isinstance(ax, np.ndarray) and all(
148152
isinstance(item, plt.Axes) for item in ax
149153
), "ax must be a numpy.ndarray of plt.Axes"
154+
plot_data = result.get_plot_data()
155+
assert isinstance(plot_data, pd.DataFrame), "The returned object is not a pandas DataFrame"
156+
expected_columns = ['prediction', 'impact']
157+
assert set(expected_columns).issubset(set(plot_data.columns)), f"DataFrame is missing expected columns {expected_columns}"
150158

151159

152160
@pytest.mark.integration

0 commit comments

Comments
 (0)