@@ -85,17 +85,15 @@ def __init__(
85
85
self .input_validation (data , treatment_time )
86
86
self .treatment_time = treatment_time
87
87
self .control_units = control_units
88
+ self .labels = control_units
88
89
self .treated_units = treated_units
89
90
self .expt_type = "SyntheticControl"
90
91
# split data into pre and post intervention
91
92
self .datapre = data [data .index < self .treatment_time ]
92
93
self .datapost = data [data .index >= self .treatment_time ]
93
94
94
- # split data into the 4 quadrants (pre/post, control/treated) and store as xarray dataarray
95
- # self.datapre_control = self.datapre[self.control_units]
96
- # self.datapre_treated = self.datapre[self.treated_units]
97
- # self.datapost_control = self.datapost[self.control_units]
98
- # self.datapost_treated = self.datapost[self.treated_units]
95
+ # split data into the 4 quadrants (pre/post, control/treated) and store as
96
+ # xarray DataArray objects
99
97
self .datapre_control = xr .DataArray (
100
98
self .datapre [self .control_units ],
101
99
dims = ["obs_ind" , "control_units" ],
@@ -137,14 +135,12 @@ def __init__(
137
135
"obs_ind" : np .arange (self .datapre .shape [0 ]),
138
136
}
139
137
self .model .fit (
140
- X = self .datapre_control . to_numpy () ,
141
- y = self .datapre_treated . to_numpy () ,
138
+ X = self .datapre_control ,
139
+ y = self .datapre_treated ,
142
140
coords = COORDS ,
143
141
)
144
142
elif isinstance (self .model , RegressorMixin ):
145
- self .model .fit (
146
- X = self .datapre_control .to_numpy (), y = self .datapre_treated .to_numpy ()
147
- )
143
+ self .model .fit (X = self .datapre_control , y = self .datapre_treated )
148
144
else :
149
145
raise ValueError ("Model type not recognized" )
150
146
@@ -154,20 +150,10 @@ def __init__(
154
150
)
155
151
156
152
# get the model predictions of the observed (pre-intervention) data
157
- self .pre_pred = self .model .predict (X = self .datapre_control . to_numpy () )
153
+ self .pre_pred = self .model .predict (X = self .datapre_control )
158
154
159
155
# calculate the counterfactual
160
- self .post_pred = self .model .predict (X = self .datapost_control .to_numpy ())
161
- # TODO: Remove the need for this 'hack' by properly updating the coords when we
162
- # run model.predict
163
- # TEMPORARY HACK: --------------------------------------------------------------
164
- # : set the coords (obs_ind) for self.post_pred to be the same as the datapost
165
- # index. This is needed for xarray to properly do the comparison (-) between
166
- # datapre_treated and self.post_pred
167
- # self.post_pred["posterior_predictive"] = self.post_pred[
168
- # "posterior_predictive"
169
- # ].assign_coords(obs_ind=self.datapost.index)
170
- # ------------------------------------------------------------------------------
156
+ self .post_pred = self .model .predict (X = self .datapost_control )
171
157
self .pre_impact = self .model .calculate_impact (
172
158
self .datapre_treated , self .pre_pred
173
159
)
0 commit comments