@@ -265,11 +265,11 @@ def __init__(
265
265
# Input validation ----------------------------------------------------
266
266
# Check that `treated` appears in the module formula
267
267
assert (
268
- "treated " in formula
269
- ), "A predictor column called `treated ` should be in the provided dataframe"
268
+ "post_treatment " in formula
269
+ ), "A predictor called `post_treatment ` should be in the dataframe"
270
270
# Check that we have `treated` in the incoming dataframe
271
271
assert (
272
- "treated " in self .data .columns
272
+ "post_treatment " in self .data .columns
273
273
), "Require a boolean column labelling observations which are `treated`"
274
274
# Check for `unit` in the incoming dataframe.
275
275
# *This is only used for plotting purposes*
@@ -289,46 +289,45 @@ def __init__(
289
289
.I.e. the treated and untreated.
290
290
"""
291
291
292
- # TODO: `treated` is a deterministic function of group and time, so this could
293
- # be a function rather than supplied data
294
-
295
292
# DEVIATION FROM SKL EXPERIMENT CODE =============================
296
- # fit the model to the observed (pre-intervention) data
297
293
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
298
294
self .prediction_model .fit (X = self .X , y = self .y , coords = COORDS )
299
295
# ================================================================
300
296
301
- time_levels = self .data [self .time_variable_name ].unique ()
302
-
303
297
# predicted outcome for control group
304
- self .x_pred_control = pd . DataFrame (
305
- {
306
- self . group_variable_name : [ self . untreated , self . untreated ],
307
- self .time_variable_name : time_levels ,
308
- "treated" : [ 0 , 0 ],
309
- }
298
+ self .x_pred_control = (
299
+ self . data
300
+ # just the untreated group
301
+ . query ( f"district == ' { self .untreated } '" )
302
+ # drop the outcome variable
303
+ . drop ( self . outcome_variable_name , axis = 1 )
310
304
)
311
305
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_control )
312
306
self .y_pred_control = self .prediction_model .predict (np .asarray (new_x ))
313
307
314
308
# predicted outcome for treatment group
315
- self .x_pred_treatment = pd . DataFrame (
316
- {
317
- self . group_variable_name : [ self . treated , self . treated ],
318
- self .time_variable_name : time_levels ,
319
- "treated" : [ 0 , 1 ],
320
- }
309
+ self .x_pred_treatment = (
310
+ self . data
311
+ # just the treated group
312
+ . query ( f"district == ' { self .treated } '" )
313
+ # drop the outcome variable
314
+ . drop ( self . outcome_variable_name , axis = 1 )
321
315
)
322
316
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
323
317
self .y_pred_treatment = self .prediction_model .predict (np .asarray (new_x ))
324
318
325
319
# predicted outcome for counterfactual
326
- self .x_pred_counterfactual = pd .DataFrame (
327
- {
328
- self .group_variable_name : [self .treated ],
329
- self .time_variable_name : time_levels [1 ],
330
- "treated" : [0 ],
331
- }
320
+ self .x_pred_counterfactual = (
321
+ self .data
322
+ # just the treated group
323
+ .query (f"district == '{ self .treated } '" )
324
+ # just the treatment period(s)
325
+ # TODO: the line below might need some work to be more robust
326
+ .query ("post_treatment == True" )
327
+ # drop the outcome variable
328
+ .drop (self .outcome_variable_name , axis = 1 )
329
+ # DO AN INTERVENTION. Set the post_treatment variable to False
330
+ .assign (post_treatment = False )
332
331
)
333
332
(new_x ,) = build_design_matrices (
334
333
[self ._x_design_info ], self .x_pred_counterfactual
@@ -340,14 +339,6 @@ def __init__(
340
339
self .y_pred_treatment ["posterior_predictive" ].mu .isel ({"obs_ind" : 1 })
341
340
- self .y_pred_counterfactual ["posterior_predictive" ].mu .squeeze ()
342
341
)
343
- # self.causal_impact = (
344
- # self.y_pred_treatment["posterior_predictive"]
345
- # .mu.isel({"obs_ind": 1})
346
- # .stack(samples=["chain", "draw"])
347
- # - self.y_pred_counterfactual["posterior_predictive"]
348
- # .mu.stack(samples=["chain", "draw"])
349
- # .squeeze()
350
- # )
351
342
352
343
def plot (self ):
353
344
"""Plot the results"""
0 commit comments