@@ -244,7 +244,6 @@ def __init__(
244
244
self .y , self .X = np .asarray (y ), np .asarray (X )
245
245
self .outcome_variable_name = y .design_info .column_names [0 ]
246
246
247
-
248
247
# Input validation ----------------------------------------------------
249
248
# Check that `treated` appears in the module formula
250
249
assert (
@@ -254,17 +253,26 @@ def __init__(
254
253
assert (
255
254
"treated" in self .data .columns
256
255
), "Require a boolean column labelling observations which are `treated`"
257
- # Check for `unit` in the incoming dataframe. *This is only used for plotting purposes*
256
+ # Check for `unit` in the incoming dataframe.
257
+ # *This is only used for plotting purposes*
258
258
assert (
259
259
"unit" in self .data .columns
260
- ), "Require a `unit` column to label unique units. This is used for plotting purposes"
261
- # Check that `group_variable_name` has TWO levels, representing the treated/untreated.
262
- # But it does not matter what the actual names of the levels are.
260
+ ), """
261
+ Require a `unit` column to label unique units.
262
+ This is used for plotting purposes
263
+ """
264
+ # Check that `group_variable_name` has TWO levels, representing the
265
+ # treated/untreated. But it does not matter what the actual names of
266
+ # the levels are.
263
267
assert (
264
- len (pd .Categorical (self .data [self .group_variable_name ]).categories ) is 2
265
- ), f"There must be 2 levels of the grouping variable { self .group_variable_name } .I.e. the treated and untreated."
268
+ len (pd .Categorical (self .data [self .group_variable_name ]).categories ) == 2
269
+ ), f"""
270
+ There must be 2 levels of the grouping variable { self .group_variable_name }
271
+ .I.e. the treated and untreated.
272
+ """
266
273
267
- # TODO: `treated` is a deterministic function of group and time, so this could be a function rather than supplied data
274
+ # TODO: `treated` is a deterministic function of group and time, so this could
275
+ # be a function rather than supplied data
268
276
269
277
# DEVIATION FROM SKL EXPERIMENT CODE =============================
270
278
# fit the model to the observed (pre-intervention) data
@@ -369,7 +377,8 @@ def plot(self):
369
377
pc .set_facecolor ("C1" )
370
378
pc .set_edgecolor ("None" )
371
379
pc .set_alpha (0.5 )
372
- # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
380
+ # Plot counterfactual - post-test for treatment group IF no treatment
381
+ # had occurred.
373
382
parts = ax .violinplot (
374
383
az .extract (
375
384
self .y_pred_counterfactual ,
@@ -397,7 +406,8 @@ def plot(self):
397
406
398
407
def _plot_causal_impact_arrow (self , ax ):
399
408
"""
400
- draw a vertical arrow between `y_pred_counterfactual` and `y_pred_counterfactual`
409
+ draw a vertical arrow between `y_pred_counterfactual` and
410
+ `y_pred_counterfactual`
401
411
"""
402
412
# Calculate y values to plot the arrow between
403
413
y_pred_treatment = (
0 commit comments