@@ -62,16 +62,7 @@ def __init__(
62
62
** kwargs ,
63
63
) -> None :
64
64
super ().__init__ (model = model , ** kwargs )
65
-
66
- # Input validation
67
- if isinstance (data .index , pd .DatetimeIndex ):
68
- assert isinstance (
69
- treatment_time , pd .Timestamp
70
- ), "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
71
- else :
72
- assert (
73
- isinstance (treatment_time , pd .Timestamp ) is False
74
- ), "If treatment_time is pd.Timestamp, this only makese sense if data.index is DatetimeIndex." # noqa: E501
65
+ self ._input_validation (data , treatment_time )
75
66
76
67
self .treatment_time = treatment_time
77
68
# split data in to pre and post intervention
@@ -124,6 +115,17 @@ def __init__(
124
115
# cumulative impact post
125
116
self .post_impact_cumulative = self .post_impact .cumsum (dim = "obs_ind" )
126
117
118
+ def _input_validation (self , data , treatment_time ):
119
+ """Validate the input data for correctness"""
120
+ if isinstance (data .index , pd .DatetimeIndex ):
121
+ assert isinstance (
122
+ treatment_time , pd .Timestamp
123
+ ), "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
124
+ else :
125
+ assert (
126
+ isinstance (treatment_time , pd .Timestamp ) is False
127
+ ), "If treatment_time is pd.Timestamp, this only makese sense if data.index is DatetimeIndex." # noqa: E501
128
+
127
129
def plot (self ):
128
130
129
131
"""Plot the results"""
@@ -276,36 +278,15 @@ def __init__(
276
278
self .formula = formula
277
279
self .time_variable_name = time_variable_name
278
280
self .group_variable_name = group_variable_name
281
+ self ._input_validation ()
282
+
279
283
y , X = dmatrices (formula , self .data )
280
284
self ._y_design_info = y .design_info
281
285
self ._x_design_info = X .design_info
282
286
self .labels = X .design_info .column_names
283
287
self .y , self .X = np .asarray (y ), np .asarray (X )
284
288
self .outcome_variable_name = y .design_info .column_names [0 ]
285
289
286
- # Input validation ----------------------------------------------------
287
- assert (
288
- "post_treatment" in formula
289
- ), "A predictor called `post_treatment` should be in the dataframe"
290
- assert (
291
- "post_treatment" in self .data .columns
292
- ), "Require a boolean column labelling observations which are `treated`"
293
- # Check for `unit` in the incoming dataframe.
294
- # *This is only used for plotting purposes*
295
- assert (
296
- "unit" in self .data .columns
297
- ), """
298
- Require a `unit` column to label unique units.
299
- This is used for plotting purposes
300
- """
301
- # Check that `group_variable_name` is dummy coded. It should be 0 or 1
302
- assert not set (self .data [self .group_variable_name ]).difference (
303
- set ([0 , 1 ])
304
- ), f"""
305
- The grouping variable { self .group_variable_name } should be dummy coded.
306
- Consisting of 0's and 1's only.
307
- """
308
-
309
290
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
310
291
self .model .fit (X = self .X , y = self .y , coords = COORDS )
311
292
@@ -374,6 +355,30 @@ def __init__(
374
355
if "post_treatment" in label and self .group_variable_name in label :
375
356
self .causal_impact = self .idata .posterior ["beta" ].isel ({"coeffs" : i })
376
357
358
+ def _input_validation (self ):
359
+ """Validate the input data for correctness"""
360
+ assert (
361
+ "post_treatment" in self .formula
362
+ ), "A predictor called `post_treatment` should be in the dataframe"
363
+ assert (
364
+ "post_treatment" in self .data .columns
365
+ ), "Require a boolean column labelling observations which are `treated`"
366
+ # Check for `unit` in the incoming dataframe.
367
+ # *This is only used for plotting purposes*
368
+ assert (
369
+ "unit" in self .data .columns
370
+ ), """
371
+ Require a `unit` column to label unique units.
372
+ This is used for plotting purposes
373
+ """
374
+ # Check that `group_variable_name` is dummy coded. It should be 0 or 1
375
+ assert not set (self .data [self .group_variable_name ]).difference (
376
+ set ([0 , 1 ])
377
+ ), f"""
378
+ The grouping variable { self .group_variable_name } should be dummy coded.
379
+ Consisting of 0's and 1's only.
380
+ """
381
+
377
382
def plot (self ):
378
383
"""Plot the results.
379
384
Creating the combined mean + HDI legend entries is a bit involved.
@@ -686,6 +691,7 @@ def __init__(
686
691
self .formula = formula
687
692
self .group_variable_name = group_variable_name
688
693
self .pretreatment_variable_name = pretreatment_variable_name
694
+ self ._input_validation ()
689
695
690
696
y , X = dmatrices (formula , self .data )
691
697
self ._y_design_info = y .design_info
@@ -694,17 +700,6 @@ def __init__(
694
700
self .y , self .X = np .asarray (y ), np .asarray (X )
695
701
self .outcome_variable_name = y .design_info .column_names [0 ]
696
702
697
- # Input validation ----------------------------------------------------
698
- # Check that `group_variable_name` has TWO levels, representing the
699
- # treated/untreated. But it does not matter what the actual names of
700
- # the levels are.
701
- assert (
702
- len (pd .Categorical (self .data [self .group_variable_name ]).categories ) == 2
703
- ), f"""
704
- There must be 2 levels of the grouping variable { self .group_variable_name }
705
- .I.e. the treated and untreated.
706
- """
707
-
708
703
# fit the model to the observed (pre-intervention) data
709
704
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
710
705
self .model .fit (X = self .X , y = self .y , coords = COORDS )
@@ -743,6 +738,18 @@ def __init__(
743
738
744
739
# ================================================================
745
740
741
+ def _input_validation (self ):
742
+ """Validate the input data for correctness"""
743
+ # Check that `group_variable_name` has TWO levels, representing the
744
+ # treated/untreated. But it does not matter what the actual names of
745
+ # the levels are.
746
+ assert (
747
+ len (pd .Categorical (self .data [self .group_variable_name ]).categories ) == 2
748
+ ), f"""
749
+ There must be 2 levels of the grouping variable { self .group_variable_name }
750
+ .I.e. the treated and untreated.
751
+ """
752
+
746
753
def plot (self ):
747
754
"""Plot the results"""
748
755
fig , ax = plt .subplots (
0 commit comments