@@ -85,17 +85,15 @@ def __init__(
8585        self .input_validation (data , treatment_time )
8686        self .treatment_time  =  treatment_time 
8787        self .control_units  =  control_units 
88+         self .labels  =  control_units 
8889        self .treated_units  =  treated_units 
8990        self .expt_type  =  "SyntheticControl" 
9091        # split data into pre and post intervention 
9192        self .datapre  =  data [data .index  <  self .treatment_time ]
9293        self .datapost  =  data [data .index  >=  self .treatment_time ]
9394
94-         # split data into the 4 quadrants (pre/post, control/treated) and store as xarray dataarray 
95-         # self.datapre_control = self.datapre[self.control_units] 
96-         # self.datapre_treated = self.datapre[self.treated_units] 
97-         # self.datapost_control = self.datapost[self.control_units] 
98-         # self.datapost_treated = self.datapost[self.treated_units] 
95+         # split data into the 4 quadrants (pre/post, control/treated) and store as 
96+         # xarray DataArray objects 
9997        self .datapre_control  =  xr .DataArray (
10098            self .datapre [self .control_units ],
10199            dims = ["obs_ind" , "control_units" ],
@@ -137,14 +135,12 @@ def __init__(
137135                "obs_ind" : np .arange (self .datapre .shape [0 ]),
138136            }
139137            self .model .fit (
140-                 X = self .datapre_control . to_numpy () ,
141-                 y = self .datapre_treated . to_numpy () ,
138+                 X = self .datapre_control ,
139+                 y = self .datapre_treated ,
142140                coords = COORDS ,
143141            )
144142        elif  isinstance (self .model , RegressorMixin ):
145-             self .model .fit (
146-                 X = self .datapre_control .to_numpy (), y = self .datapre_treated .to_numpy ()
147-             )
143+             self .model .fit (X = self .datapre_control , y = self .datapre_treated )
148144        else :
149145            raise  ValueError ("Model type not recognized" )
150146
@@ -154,20 +150,10 @@ def __init__(
154150        )
155151
156152        # get the model predictions of the observed (pre-intervention) data 
157-         self .pre_pred  =  self .model .predict (X = self .datapre_control . to_numpy () )
153+         self .pre_pred  =  self .model .predict (X = self .datapre_control )
158154
159155        # calculate the counterfactual 
160-         self .post_pred  =  self .model .predict (X = self .datapost_control .to_numpy ())
161-         # TODO: Remove the need for this 'hack' by properly updating the coords when we 
162-         # run model.predict 
163-         # TEMPORARY HACK: -------------------------------------------------------------- 
164-         # : set the coords (obs_ind) for self.post_pred to be the same as the datapost 
165-         # index. This is needed for xarray to properly do the comparison (-) between 
166-         # datapre_treated and self.post_pred 
167-         # self.post_pred["posterior_predictive"] = self.post_pred[ 
168-         #     "posterior_predictive" 
169-         # ].assign_coords(obs_ind=self.datapost.index) 
170-         # ------------------------------------------------------------------------------ 
156+         self .post_pred  =  self .model .predict (X = self .datapost_control )
171157        self .pre_impact  =  self .model .calculate_impact (
172158            self .datapre_treated , self .pre_pred 
173159        )
0 commit comments