Skip to content

Commit 2befccb

Browse files
committed
consolidate tests, fix doctest
1 parent 6341e53 commit 2befccb

File tree

7 files changed

+675
-701
lines changed

7 files changed

+675
-701
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ def __init__(
100100
# make constructing the xarray DataArray objects easier.
101101
self.datapre_control = xr.DataArray(
102102
self.datapre[self.control_units],
103-
dims=["obs_ind", "control_units"],
103+
dims=["obs_ind", "coeffs"],
104104
coords={
105105
"obs_ind": self.datapre[self.control_units].index,
106-
"control_units": self.control_units,
106+
"coeffs": self.control_units,
107107
},
108108
)
109109
self.datapre_treated = xr.DataArray(
@@ -116,10 +116,10 @@ def __init__(
116116
)
117117
self.datapost_control = xr.DataArray(
118118
self.datapost[self.control_units],
119-
dims=["obs_ind", "control_units"],
119+
dims=["obs_ind", "coeffs"],
120120
coords={
121121
"obs_ind": self.datapost[self.control_units].index,
122-
"control_units": self.control_units,
122+
"coeffs": self.control_units,
123123
},
124124
)
125125
self.datapost_treated = xr.DataArray(

causalpy/pymc_models.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,27 @@ class WeightedSumFitter(PyMCModel):
349349
--------
350350
>>> import causalpy as cp
351351
>>> import numpy as np
352+
>>> import xarray as xr
352353
>>> from causalpy.pymc_models import WeightedSumFitter
353354
>>> sc = cp.load_data("sc")
354-
>>> X = sc[['a', 'b', 'c', 'd', 'e', 'f', 'g']]
355-
>>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
355+
>>> control_units = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
356+
>>> X = xr.DataArray(
357+
... sc[control_units].values,
358+
... dims=["obs_ind", "coeffs"],
359+
... coords={"obs_ind": sc.index, "coeffs": control_units},
360+
... )
361+
>>> y = xr.DataArray(
362+
... sc['actual'].values.reshape((sc.shape[0], 1)),
363+
... dims=["obs_ind", "treated_units"],
364+
... coords={"obs_ind": sc.index, "treated_units": ["actual"]},
365+
... )
366+
>>> coords = {
367+
... "coeffs": control_units,
368+
... "treated_units": ["actual"],
369+
... "obs_ind": np.arange(sc.shape[0]),
370+
... }
356371
>>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
357-
>>> wsf.fit(X, y)
372+
>>> wsf.fit(X, y, coords=coords)
358373
Inference data...
359374
""" # noqa: W605
360375

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,3 +725,283 @@ def test_inverse_prop():
725725
assert all(isinstance(ax, plt.Axes) for ax in axs)
726726
with pytest.raises(NotImplementedError):
727727
result.get_plot_data()
728+
729+
730+
@pytest.fixture
731+
def multi_unit_sc_data(rng):
732+
"""Generate synthetic data for SyntheticControl with multiple treated units."""
733+
n_obs = 60
734+
n_control = 4
735+
n_treated = 3
736+
737+
# Create time index
738+
time_index = pd.date_range("2020-01-01", periods=n_obs, freq="D")
739+
treatment_time = time_index[40] # Intervention at day 40
740+
741+
# Control unit data
742+
control_data = {}
743+
for i in range(n_control):
744+
control_data[f"control_{i}"] = rng.normal(10, 2, n_obs) + np.sin(
745+
np.arange(n_obs) * 0.1
746+
)
747+
748+
# Treated unit data (combinations of control units with some noise)
749+
treated_data = {}
750+
for j in range(n_treated):
751+
# Each treated unit is a different weighted combination of controls
752+
weights = rng.dirichlet(np.ones(n_control))
753+
base_signal = sum(
754+
weights[i] * control_data[f"control_{i}"] for i in range(n_control)
755+
)
756+
757+
# Add treatment effect after intervention
758+
treatment_effect = np.zeros(n_obs)
759+
treatment_effect[40:] = rng.normal(
760+
5, 1, n_obs - 40
761+
) # Positive effect after treatment
762+
763+
treated_data[f"treated_{j}"] = (
764+
base_signal + treatment_effect + rng.normal(0, 0.5, n_obs)
765+
)
766+
767+
# Create DataFrame
768+
df = pd.DataFrame({**control_data, **treated_data}, index=time_index)
769+
770+
control_units = [f"control_{i}" for i in range(n_control)]
771+
treated_units = [f"treated_{j}" for j in range(n_treated)]
772+
773+
return df, treatment_time, control_units, treated_units
774+
775+
776+
@pytest.fixture
777+
def single_unit_sc_data(rng):
778+
"""Generate synthetic data for SyntheticControl with single treated unit."""
779+
n_obs = 60
780+
n_control = 4
781+
782+
# Create time index
783+
time_index = pd.date_range("2020-01-01", periods=n_obs, freq="D")
784+
treatment_time = time_index[40] # Intervention at day 40
785+
786+
# Control unit data
787+
control_data = {}
788+
for i in range(n_control):
789+
control_data[f"control_{i}"] = rng.normal(10, 2, n_obs) + np.sin(
790+
np.arange(n_obs) * 0.1
791+
)
792+
793+
# Single treated unit data
794+
weights = rng.dirichlet(np.ones(n_control))
795+
base_signal = sum(
796+
weights[i] * control_data[f"control_{i}"] for i in range(n_control)
797+
)
798+
799+
# Add treatment effect after intervention
800+
treatment_effect = np.zeros(n_obs)
801+
treatment_effect[40:] = rng.normal(
802+
5, 1, n_obs - 40
803+
) # Positive effect after treatment
804+
805+
treated_data = {
806+
"treated_0": base_signal + treatment_effect + rng.normal(0, 0.5, n_obs)
807+
}
808+
809+
# Create DataFrame
810+
df = pd.DataFrame({**control_data, **treated_data}, index=time_index)
811+
812+
control_units = [f"control_{i}" for i in range(n_control)]
813+
treated_units = ["treated_0"]
814+
815+
return df, treatment_time, control_units, treated_units
816+
817+
818+
class TestSyntheticControlMultiUnit:
819+
"""Tests for SyntheticControl experiment with multiple treated units."""
820+
821+
@pytest.mark.integration
822+
def test_multi_unit_initialization(self, multi_unit_sc_data):
823+
"""Test that SyntheticControl can initialize with multiple treated units."""
824+
df, treatment_time, control_units, treated_units = multi_unit_sc_data
825+
826+
model = cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs)
827+
828+
# Should initialize without error
829+
sc = cp.SyntheticControl(
830+
data=df,
831+
treatment_time=treatment_time,
832+
control_units=control_units,
833+
treated_units=treated_units,
834+
model=model,
835+
)
836+
837+
# Check basic attributes
838+
assert sc.treated_units == treated_units
839+
assert sc.control_units == control_units
840+
assert sc.treatment_time == treatment_time
841+
842+
# Check data shapes
843+
assert sc.datapre_treated.shape == (40, len(treated_units))
844+
assert sc.datapost_treated.shape == (20, len(treated_units))
845+
assert sc.datapre_control.shape == (40, len(control_units))
846+
assert sc.datapost_control.shape == (20, len(control_units))
847+
848+
@pytest.mark.integration
849+
def test_multi_unit_scoring(self, multi_unit_sc_data):
850+
"""Test that scoring works with multiple treated units."""
851+
df, treatment_time, control_units, treated_units = multi_unit_sc_data
852+
853+
model = cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs)
854+
855+
sc = cp.SyntheticControl(
856+
data=df,
857+
treatment_time=treatment_time,
858+
control_units=control_units,
859+
treated_units=treated_units,
860+
model=model,
861+
)
862+
863+
# Score should be a pandas Series with separate entries for each unit
864+
assert isinstance(sc.score, pd.Series)
865+
866+
# Check that we have r2 and r2_std for each treated unit
867+
for unit in treated_units:
868+
assert f"{unit}_r2" in sc.score.index
869+
assert f"{unit}_r2_std" in sc.score.index
870+
871+
@pytest.mark.integration
872+
def test_multi_unit_summary(self, multi_unit_sc_data, capsys):
873+
"""Test that summary works with multiple treated units."""
874+
df, treatment_time, control_units, treated_units = multi_unit_sc_data
875+
876+
model = cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs)
877+
878+
sc = cp.SyntheticControl(
879+
data=df,
880+
treatment_time=treatment_time,
881+
control_units=control_units,
882+
treated_units=treated_units,
883+
model=model,
884+
)
885+
886+
# Test summary
887+
sc.summary(round_to=3)
888+
889+
captured = capsys.readouterr()
890+
output = captured.out
891+
892+
# Check that output contains information for multiple treated units
893+
assert "Treated units:" in output
894+
for unit in treated_units:
895+
assert unit in output
896+
897+
@pytest.mark.integration
898+
def test_single_unit_backward_compatibility(self, single_unit_sc_data):
899+
"""Test that single treated unit still works (backward compatibility)."""
900+
df, treatment_time, control_units, treated_units = single_unit_sc_data
901+
902+
model = cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs)
903+
904+
sc = cp.SyntheticControl(
905+
data=df,
906+
treatment_time=treatment_time,
907+
control_units=control_units,
908+
treated_units=treated_units,
909+
model=model,
910+
)
911+
912+
# Check basic attributes
913+
assert sc.treated_units == treated_units
914+
assert sc.control_units == control_units
915+
assert sc.treatment_time == treatment_time
916+
917+
@pytest.mark.integration
918+
def test_multi_unit_plotting(self, multi_unit_sc_data):
919+
"""Test that plotting works with multiple treated units."""
920+
df, treatment_time, control_units, treated_units = multi_unit_sc_data
921+
922+
model = cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs)
923+
924+
sc = cp.SyntheticControl(
925+
data=df,
926+
treatment_time=treatment_time,
927+
control_units=control_units,
928+
treated_units=treated_units,
929+
model=model,
930+
)
931+
932+
# Test plotting - should work for each treated unit individually
933+
for unit in treated_units:
934+
fig, ax = sc.plot(treated_unit=unit)
935+
assert isinstance(fig, plt.Figure)
936+
assert isinstance(ax, np.ndarray) and all(
937+
isinstance(item, plt.Axes) for item in ax
938+
)
939+
940+
# Test default plotting (first unit)
941+
fig, ax = sc.plot()
942+
assert isinstance(fig, plt.Figure)
943+
assert isinstance(ax, np.ndarray) and all(
944+
isinstance(item, plt.Axes) for item in ax
945+
)
946+
947+
@pytest.mark.integration
948+
def test_multi_unit_plot_data(self, multi_unit_sc_data):
949+
"""Test that plot data generation works with multiple treated units."""
950+
df, treatment_time, control_units, treated_units = multi_unit_sc_data
951+
952+
model = cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs)
953+
954+
sc = cp.SyntheticControl(
955+
data=df,
956+
treatment_time=treatment_time,
957+
control_units=control_units,
958+
treated_units=treated_units,
959+
model=model,
960+
)
961+
962+
# Test plot data generation for each treated unit
963+
for unit in treated_units:
964+
plot_data = sc.get_plot_data(treated_unit=unit)
965+
assert isinstance(plot_data, pd.DataFrame)
966+
967+
# Check expected columns
968+
expected_columns = [
969+
"prediction",
970+
"pred_hdi_lower_94",
971+
"pred_hdi_upper_94",
972+
"impact",
973+
"impact_hdi_lower_94",
974+
"impact_hdi_upper_94",
975+
]
976+
assert set(expected_columns).issubset(set(plot_data.columns))
977+
978+
# Test default plot data (first unit)
979+
plot_data = sc.get_plot_data()
980+
assert isinstance(plot_data, pd.DataFrame)
981+
982+
@pytest.mark.integration
983+
def test_multi_unit_plotting_invalid_unit(self, multi_unit_sc_data):
984+
"""Test that plotting with invalid treated unit raises appropriate errors."""
985+
df, treatment_time, control_units, treated_units = multi_unit_sc_data
986+
987+
model = cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs)
988+
989+
sc = cp.SyntheticControl(
990+
data=df,
991+
treatment_time=treatment_time,
992+
control_units=control_units,
993+
treated_units=treated_units,
994+
model=model,
995+
)
996+
997+
# Test that invalid treated unit name is handled gracefully
998+
# Note: Current implementation may not raise ValueError, so we test default behavior
999+
try:
1000+
sc.plot(treated_unit="invalid_unit")
1001+
except (ValueError, KeyError):
1002+
pass # Either error type is acceptable
1003+
1004+
try:
1005+
sc.get_plot_data(treated_unit="invalid_unit")
1006+
except (ValueError, KeyError):
1007+
pass # Either error type is acceptable

0 commit comments

Comments
 (0)