Skip to content

Commit c77de98

Browse files
committed
fix plotting of control units for synthetic control + add test coverage
1 parent 41bf080 commit c77de98

File tree

10 files changed

+130
-164
lines changed

10 files changed

+130
-164
lines changed

causalpy/expt_prepostfit.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,25 @@ class SyntheticControl(PrePostFit):
182182
"""
183183

184184
expt_type = "SyntheticControl"
185+
186+
def plot(self, round_to=None, plot_predictors: bool = False):
187+
"""
188+
Plot the results
189+
190+
:param round_to:
191+
Number of decimals used to round results. Defaults to 2. Use "None" to
192+
return raw numbers.
193+
:param plot_predictors:
194+
Whether to plot the control units as well. Defaults to False.
195+
"""
196+
# Get a BayesianPlotComponent or OLSPlotComponent depending on the model
197+
plot_component = self.model.get_plot_component()
198+
fig, ax = plot_component.plot_pre_post(self)
199+
if plot_predictors:
200+
# plot control units as well
201+
ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)
202+
ax[0].plot(
203+
self.datapost.index, self.post_X, "-", c=[0.8, 0.8, 0.8], zorder=1
204+
)
205+
206+
return fig, ax

causalpy/plotting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class BayesianPlotComponent(PlotComponent):
3535
"""Plotting component for Bayesian models."""
3636

3737
@staticmethod
38-
def plot_pre_post(results, round_to=None):
38+
def plot_pre_post(results, round_to=None, counterfactual_label=None):
3939
"""Generate plot for pre-post experiment types, such as Interrupted Time Series
4040
and Synthetic Control."""
4141
datapre = results.datapre

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,13 +440,21 @@ def test_sc():
440440
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
441441
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
442442
result.summary()
443+
443444
fig, ax = result.plot()
444445
assert isinstance(fig, plt.Figure)
445446
# For multi-panel plots, ax should be an array of axes
446447
assert isinstance(ax, np.ndarray) and all(
447448
isinstance(item, plt.Axes) for item in ax
448449
), "ax must be a numpy.ndarray of plt.Axes"
449450

451+
fig, ax = result.plot(plot_predictors=True)
452+
assert isinstance(fig, plt.Figure)
453+
# For multi-panel plots, ax should be an array of axes
454+
assert isinstance(ax, np.ndarray) and all(
455+
isinstance(item, plt.Axes) for item in ax
456+
), "ax must be a numpy.ndarray of plt.Axes"
457+
450458

451459
@pytest.mark.integration
452460
def test_sc_brexit():
@@ -485,6 +493,7 @@ def test_sc_brexit():
485493
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
486494
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
487495
result.summary()
496+
488497
fig, ax = result.plot()
489498
assert isinstance(fig, plt.Figure)
490499
# For multi-panel plots, ax should be an array of axes

causalpy/tests/test_integration_skl_examples.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ def test_sc():
133133
assert isinstance(df, pd.DataFrame)
134134
assert isinstance(result, cp.SyntheticControl)
135135
result.summary()
136+
137+
fig, ax = result.plot()
138+
assert isinstance(fig, plt.Figure)
139+
# For multi-panel plots, ax should be an array of axes
140+
assert isinstance(ax, np.ndarray) and all(
141+
isinstance(item, plt.Axes) for item in ax
142+
), "ax must be a numpy.ndarray of plt.Axes"
143+
136144
fig, ax = result.plot()
137145
assert isinstance(fig, plt.Figure)
138146
# For multi-panel plots, ax should be an array of axes

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/geolift1.ipynb

Lines changed: 30 additions & 37 deletions
Large diffs are not rendered by default.

docs/source/notebooks/multi_cell_geolift.ipynb

Lines changed: 14 additions & 14 deletions
Large diffs are not rendered by default.

docs/source/notebooks/sc_pymc.ipynb

Lines changed: 11 additions & 11 deletions
Large diffs are not rendered by default.

docs/source/notebooks/sc_pymc_brexit.ipynb

Lines changed: 19 additions & 19 deletions
Large diffs are not rendered by default.

docs/source/notebooks/sc_skl.ipynb

Lines changed: 13 additions & 79 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)