Skip to content

Commit 82d041d

Browse files
committed
update API in tests for SyntheticControl class
1 parent 98127fd commit 82d041d

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

causalpy/tests/test_input_validation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ def test_sc_input_error():
166166
_ = cp.SyntheticControl(
167167
df,
168168
treatment_time,
169-
formula="actual ~ 0 + a + b + c + d + e + f + g",
169+
control_units=["a", "b", "c", "d", "e", "f", "g"],
170+
treated_units=["actual"],
170171
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
171172
)
172173

@@ -176,7 +177,8 @@ def test_sc_input_error():
176177
_ = cp.SyntheticControl(
177178
df,
178179
treatment_time,
179-
formula="actual ~ 0 + a + b + c + d + e + f + g",
180+
control_units=["a", "b", "c", "d", "e", "f", "g"],
181+
treated_units=["actual"],
180182
model=cp.skl_models.WeightedProportion(),
181183
)
182184

@@ -196,11 +198,11 @@ def test_sc_brexit_input_error():
196198
other_countries = all_countries.difference({target_country})
197199
all_countries = list(all_countries)
198200
other_countries = list(other_countries)
199-
formula = target_country + " ~ " + "0 + " + " + ".join(other_countries)
200201
_ = cp.SyntheticControl(
201202
df,
202203
treatment_time,
203-
formula=formula,
204+
control_units=other_countries,
205+
treated_units=[target_country],
204206
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
205207
)
206208

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,8 @@ def test_sc():
474474
result = cp.SyntheticControl(
475475
df,
476476
treatment_time,
477-
formula="actual ~ 0 + a + b + c + d + e + f + g",
477+
control_units=["a", "b", "c", "d", "e", "f", "g"],
478+
treated_units=["actual"],
478479
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
479480
)
480481
assert isinstance(df, pd.DataFrame)
@@ -540,11 +541,11 @@ def test_sc_brexit():
540541
other_countries = all_countries.difference({target_country})
541542
all_countries = list(all_countries)
542543
other_countries = list(other_countries)
543-
formula = target_country + " ~ " + "0 + " + " + ".join(other_countries)
544544
result = cp.SyntheticControl(
545545
df,
546546
treatment_time,
547-
formula=formula,
547+
control_units=other_countries,
548+
treated_units=[target_country],
548549
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
549550
)
550551
assert isinstance(df, pd.DataFrame)
@@ -629,8 +630,8 @@ def test_geolift1():
629630
result = cp.SyntheticControl(
630631
df,
631632
treatment_time,
632-
formula="""Denmark ~ 0 + Austria + Belgium + Bulgaria + Croatia + Cyprus
633-
+ Czech_Republic""",
633+
control_units=["Austria", "Belgium", "Bulgaria", "Croatia", "Cyprus"],
634+
treated_units=["Denmark"],
634635
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
635636
)
636637
assert isinstance(df, pd.DataFrame)

0 commit comments

Comments
 (0)