Skip to content

Commit a85ccfb

Browse files
committed
fix errors in tests + make some tests pass
1 parent ae7c405 commit a85ccfb

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

causalpy/plotting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def plot_pre_post(results, round_to=None):
5353
counterfactual_label = "Counterfactual"
5454

5555
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
56-
5756
# TOP PLOT --------------------------------------------------
5857
# pre-intervention period
5958
h_line, h_patch = plot_xY(
@@ -147,6 +146,7 @@ def plot_pre_post(results, round_to=None):
147146
labels=labels,
148147
fontsize=LEGEND_FONT_SIZE,
149148
)
149+
150150
return fig, ax
151151

152152
@staticmethod

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def test_its():
349349
350350
Loads data and checks:
351351
1. data is a dataframe
352-
2. pymc_experiments.SyntheticControl returns correct type
352+
2. pymc_experiments.InterruptedTimeSeries returns correct type
353353
3. the correct number of MCMC chains exists in the posterior inference data
354354
4. the correct number of MCMC draws exists in the posterior inference data
355355
"""
@@ -359,20 +359,23 @@ def test_its():
359359
.set_index("date")
360360
)
361361
treatment_time = pd.to_datetime("2017-01-01")
362-
result = cp.SyntheticControl(
362+
result = cp.InterruptedTimeSeries(
363363
df,
364364
treatment_time,
365365
formula="y ~ 1 + t + C(month)",
366366
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
367367
)
368368
assert isinstance(df, pd.DataFrame)
369-
assert isinstance(result, cp.SyntheticControl)
369+
assert isinstance(result, cp.InterruptedTimeSeries)
370370
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
371371
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
372372
result.summary()
373373
fig, ax = result.plot()
374374
assert isinstance(fig, plt.Figure)
375-
assert isinstance(ax, plt.Axes)
375+
# For multi-panel plots, ax should be an array of axes
376+
assert isinstance(ax, np.ndarray) and all(
377+
isinstance(item, plt.Axes) for item in ax
378+
), "ax must be a numpy.ndarray of plt.Axes"
376379

377380

378381
@pytest.mark.integration
@@ -406,7 +409,10 @@ def test_its_covid():
406409
result.summary()
407410
fig, ax = result.plot()
408411
assert isinstance(fig, plt.Figure)
409-
assert isinstance(ax, plt.Axes)
412+
# For multi-panel plots, ax should be an array of axes
413+
assert isinstance(ax, np.ndarray) and all(
414+
isinstance(item, plt.Axes) for item in ax
415+
), "ax must be a numpy.ndarray of plt.Axes"
410416

411417

412418
@pytest.mark.integration
@@ -436,7 +442,10 @@ def test_sc():
436442
result.summary()
437443
fig, ax = result.plot()
438444
assert isinstance(fig, plt.Figure)
439-
assert isinstance(ax, plt.Axes)
445+
# For multi-panel plots, ax should be an array of axes
446+
assert isinstance(ax, np.ndarray) and all(
447+
isinstance(item, plt.Axes) for item in ax
448+
), "ax must be a numpy.ndarray of plt.Axes"
440449

441450

442451
@pytest.mark.integration
@@ -478,7 +487,10 @@ def test_sc_brexit():
478487
result.summary()
479488
fig, ax = result.plot()
480489
assert isinstance(fig, plt.Figure)
481-
assert isinstance(ax, plt.Axes)
490+
# For multi-panel plots, ax should be an array of axes
491+
assert isinstance(ax, np.ndarray) and all(
492+
isinstance(item, plt.Axes) for item in ax
493+
), "ax must be a numpy.ndarray of plt.Axes"
482494

483495

484496
@pytest.mark.integration

causalpy/tests/test_integration_skl_examples.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
1415
import pandas as pd
1516
import pytest
1617
from matplotlib import pyplot as plt
@@ -86,7 +87,7 @@ def test_its():
8687
8788
Loads data and checks:
8889
1. data is a dataframe
89-
2. skl_experiements.SyntheticControl returns correct type
90+
2. skl_experiements.InterruptedTimeSeries returns correct type
9091
"""
9192

9293
df = (
@@ -95,18 +96,21 @@ def test_its():
9596
.set_index("date")
9697
)
9798
treatment_time = pd.to_datetime("2017-01-01")
98-
result = cp.SyntheticControl(
99+
result = cp.InterruptedTimeSeries(
99100
df,
100101
treatment_time,
101102
formula="y ~ 1 + t + C(month)",
102-
model=LinearRegression(),
103+
model=cp.skl_models.LinearRegression(),
103104
)
104105
assert isinstance(df, pd.DataFrame)
105-
assert isinstance(result, cp.SyntheticControl)
106+
assert isinstance(result, cp.InterruptedTimeSeries)
106107
result.summary()
107108
fig, ax = result.plot()
108109
assert isinstance(fig, plt.Figure)
109-
assert isinstance(ax, plt.Axes)
110+
# For multi-panel plots, ax should be an array of axes
111+
assert isinstance(ax, np.ndarray) and all(
112+
isinstance(item, plt.Axes) for item in ax
113+
), "ax must be a numpy.ndarray of plt.Axes"
110114

111115

112116
@pytest.mark.integration
@@ -131,7 +135,10 @@ def test_sc():
131135
result.summary()
132136
fig, ax = result.plot()
133137
assert isinstance(fig, plt.Figure)
134-
assert isinstance(ax, plt.Axes)
138+
# For multi-panel plots, ax should be an array of axes
139+
assert isinstance(ax, np.ndarray) and all(
140+
isinstance(item, plt.Axes) for item in ax
141+
), "ax must be a numpy.ndarray of plt.Axes"
135142

136143

137144
@pytest.mark.integration

0 commit comments

Comments
 (0)