@@ -474,7 +474,8 @@ def test_sc():
474474 result = cp .SyntheticControl (
475475 df ,
476476 treatment_time ,
477- formula = "actual ~ 0 + a + b + c + d + e + f + g" ,
477+ control_units = ["a" , "b" , "c" , "d" , "e" , "f" , "g" ],
478+ treated_units = ["actual" ],
478479 model = cp .pymc_models .WeightedSumFitter (sample_kwargs = sample_kwargs ),
479480 )
480481 assert isinstance (df , pd .DataFrame )
@@ -540,11 +541,11 @@ def test_sc_brexit():
540541 other_countries = all_countries .difference ({target_country })
541542 all_countries = list (all_countries )
542543 other_countries = list (other_countries )
543- formula = target_country + " ~ " + "0 + " + " + " .join (other_countries )
544544 result = cp .SyntheticControl (
545545 df ,
546546 treatment_time ,
547- formula = formula ,
547+ control_units = other_countries ,
548+ treated_units = [target_country ],
548549 model = cp .pymc_models .WeightedSumFitter (sample_kwargs = sample_kwargs ),
549550 )
550551 assert isinstance (df , pd .DataFrame )
@@ -629,8 +630,8 @@ def test_geolift1():
629630 result = cp .SyntheticControl (
630631 df ,
631632 treatment_time ,
632- formula = """Denmark ~ 0 + Austria + Belgium + Bulgaria + Croatia + Cyprus
633- + Czech_Republic""" ,
633+ control_units = [ "Austria" , " Belgium" , " Bulgaria" , " Croatia" , " Cyprus" ],
634+ treated_units = [ "Denmark" ] ,
634635 model = cp .pymc_models .WeightedSumFitter (sample_kwargs = sample_kwargs ),
635636 )
636637 assert isinstance (df , pd .DataFrame )
0 commit comments