Skip to content

Commit fad78d8

Browse files
committed
add additional asserts to integration tests to detect shape problems
1 parent c89a147 commit fad78d8

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,12 @@ def test_its():
375375
)
376376
assert isinstance(df, pd.DataFrame)
377377
assert isinstance(result, cp.InterruptedTimeSeries)
378+
assert result.pre_impact.shape[1] == len(result.treated_units), (
379+
"Mismatch between pre_impact shape and number of treated units"
380+
)
381+
assert result.post_impact.shape[1] == len(result.treated_units), (
382+
"Mismatch between post_impact shape and number of treated units"
383+
)
378384
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
379385
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
380386
result.summary()

causalpy/tests/test_integration_skl_examples.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ def test_its():
109109
)
110110
assert isinstance(df, pd.DataFrame)
111111
assert isinstance(result, cp.InterruptedTimeSeries)
112+
assert result.pre_impact.shape[1] == len(result.treated_units), (
113+
"Mismatch between pre_impact shape and number of treated units"
114+
)
115+
assert result.post_impact.shape[1] == len(result.treated_units), (
116+
"Mismatch between post_impact shape and number of treated units"
117+
)
112118
result.summary()
113119
fig, ax = result.plot()
114120
assert isinstance(fig, plt.Figure)

0 commit comments

Comments
 (0)