Skip to content

Commit 0247dae

Browse files
committed
treatment_time input validation + associated test
1 parent 16f925b commit 0247dae

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

causalpy/pymc_experiments.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class TimeSeriesExperiment(ExperimentalDesign):
5454
def __init__(
5555
self,
5656
data: pd.DataFrame,
57-
treatment_time: int,
57+
treatment_time: int | float | pd.Timestamp,
5858
formula: str,
5959
model=None,
6060
**kwargs,
@@ -81,6 +81,12 @@ def __init__(
8181
self.post_X = np.asarray(new_x)
8282
self.post_y = np.asarray(new_y)
8383

84+
# Input validation
85+
if isinstance(data.index, pd.DatetimeIndex):
86+
assert isinstance(
87+
treatment_time, pd.Timestamp
88+
), "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
89+
8490
# DEVIATION FROM SKL EXPERIMENT CODE =============================
8591
# fit the model to the observed (pre-intervention) data
8692
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,31 @@ def test_sc_brexit():
217217
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
218218

219219

220+
@pytest.mark.integration
221+
def test_sc_brexit_input_error():
222+
"""Test that an error is raised if the data index is datetime and the treatment time
223+
is not pd.Timestamp."""
224+
with pytest.raises(AssertionError):
225+
df = cp.load_data("brexit")
226+
df["Time"] = pd.to_datetime(df["Time"])
227+
df.set_index("Time", inplace=True)
228+
df = df.iloc[df.index > "2009", :]
229+
treatment_time = "2016 June 24" # NOTE This is not of type pd.Timestamp
230+
df = df.drop(["Japan", "Italy", "US", "Spain"], axis=1)
231+
target_country = "UK"
232+
all_countries = df.columns
233+
other_countries = all_countries.difference({target_country})
234+
all_countries = list(all_countries)
235+
other_countries = list(other_countries)
236+
formula = target_country + " ~ " + "0 + " + " + ".join(other_countries)
237+
_ = cp.pymc_experiments.SyntheticControl(
238+
df,
239+
treatment_time,
240+
formula=formula,
241+
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
242+
)
243+
244+
220245
@pytest.mark.integration
221246
def test_ancova():
222247
df = cp.load_data("anova1")

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)