1
+ from typing import Union
2
+
1
3
import arviz as az
2
4
import matplotlib .pyplot as plt
3
5
import numpy as np
6
8
import xarray as xr
7
9
from patsy import build_design_matrices , dmatrices
8
10
11
+ from causalpy .custom_exceptions import BadIndexException # NOQA
12
+ from causalpy .custom_exceptions import DataException , FormulaException
9
13
from causalpy .plot_utils import plot_xY
14
+ from causalpy .utils import _is_variable_dummy_coded , _series_has_2_levels
10
15
11
16
LEGEND_FONT_SIZE = 12
12
17
az .style .use ("arviz-darkgrid" )
@@ -54,12 +59,14 @@ class TimeSeriesExperiment(ExperimentalDesign):
54
59
def __init__ (
55
60
self ,
56
61
data : pd .DataFrame ,
57
- treatment_time : int ,
62
+ treatment_time : Union [ int , float , pd . Timestamp ] ,
58
63
formula : str ,
59
64
model = None ,
60
65
** kwargs ,
61
66
) -> None :
62
67
super ().__init__ (model = model , ** kwargs )
68
+ self ._input_validation (data , treatment_time )
69
+
63
70
self .treatment_time = treatment_time
64
71
# split data in to pre and post intervention
65
72
self .datapre = data [data .index <= self .treatment_time ]
@@ -111,6 +118,21 @@ def __init__(
111
118
# cumulative impact post
112
119
self .post_impact_cumulative = self .post_impact .cumsum (dim = "obs_ind" )
113
120
121
+ def _input_validation (self , data , treatment_time ):
122
+ """Validate the input data and model formula for correctness"""
123
+ if isinstance (data .index , pd .DatetimeIndex ) and not isinstance (
124
+ treatment_time , pd .Timestamp
125
+ ):
126
+ raise BadIndexException (
127
+ "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
128
+ )
129
+ if not isinstance (data .index , pd .DatetimeIndex ) and isinstance (
130
+ treatment_time , pd .Timestamp
131
+ ):
132
+ raise BadIndexException (
133
+ "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
134
+ )
135
+
114
136
def plot (self ):
115
137
116
138
"""Plot the results"""
@@ -263,36 +285,15 @@ def __init__(
263
285
self .formula = formula
264
286
self .time_variable_name = time_variable_name
265
287
self .group_variable_name = group_variable_name
288
+ self ._input_validation ()
289
+
266
290
y , X = dmatrices (formula , self .data )
267
291
self ._y_design_info = y .design_info
268
292
self ._x_design_info = X .design_info
269
293
self .labels = X .design_info .column_names
270
294
self .y , self .X = np .asarray (y ), np .asarray (X )
271
295
self .outcome_variable_name = y .design_info .column_names [0 ]
272
296
273
- # Input validation ----------------------------------------------------
274
- assert (
275
- "post_treatment" in formula
276
- ), "A predictor called `post_treatment` should be in the dataframe"
277
- assert (
278
- "post_treatment" in self .data .columns
279
- ), "Require a boolean column labelling observations which are `treated`"
280
- # Check for `unit` in the incoming dataframe.
281
- # *This is only used for plotting purposes*
282
- assert (
283
- "unit" in self .data .columns
284
- ), """
285
- Require a `unit` column to label unique units.
286
- This is used for plotting purposes
287
- """
288
- # Check that `group_variable_name` is dummy coded. It should be 0 or 1
289
- assert not set (self .data [self .group_variable_name ]).difference (
290
- set ([0 , 1 ])
291
- ), f"""
292
- The grouping variable { self .group_variable_name } should be dummy coded.
293
- Consisting of 0's and 1's only.
294
- """
295
-
296
297
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
297
298
self .model .fit (X = self .X , y = self .y , coords = COORDS )
298
299
@@ -361,6 +362,29 @@ def __init__(
361
362
if "post_treatment" in label and self .group_variable_name in label :
362
363
self .causal_impact = self .idata .posterior ["beta" ].isel ({"coeffs" : i })
363
364
365
+ def _input_validation (self ):
366
+ """Validate the input data and model formula for correctness"""
367
+ if "post_treatment" not in self .formula :
368
+ raise FormulaException (
369
+ "A predictor called `post_treatment` should be in the formula"
370
+ )
371
+
372
+ if "post_treatment" not in self .data .columns :
373
+ raise DataException (
374
+ "Require a boolean column labelling observations which are `treated`"
375
+ )
376
+
377
+ if "unit" not in self .data .columns :
378
+ raise DataException (
379
+ "Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501
380
+ )
381
+
382
+ if _is_variable_dummy_coded (self .data [self .group_variable_name ]) is False :
383
+ raise DataException (
384
+ f"""The grouping variable { self .group_variable_name } should be dummy
385
+ coded. Consisting of 0's and 1's only."""
386
+ )
387
+
364
388
def plot (self ):
365
389
"""Plot the results.
366
390
Creating the combined mean + HDI legend entries is a bit involved.
@@ -536,16 +560,15 @@ def __init__(
536
560
self .formula = formula
537
561
self .running_variable_name = running_variable_name
538
562
self .treatment_threshold = treatment_threshold
563
+ self ._input_validation ()
564
+
539
565
y , X = dmatrices (formula , self .data )
540
566
self ._y_design_info = y .design_info
541
567
self ._x_design_info = X .design_info
542
568
self .labels = X .design_info .column_names
543
569
self .y , self .X = np .asarray (y ), np .asarray (X )
544
570
self .outcome_variable_name = y .design_info .column_names [0 ]
545
571
546
- # TODO: `treated` is a deterministic function of x and treatment_threshold, so
547
- # this could be a function rather than supplied data
548
-
549
572
# DEVIATION FROM SKL EXPERIMENT CODE =============================
550
573
# fit the model to the observed (pre-intervention) data
551
574
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
@@ -586,6 +609,18 @@ def __init__(
586
609
- self .pred_discon ["posterior_predictive" ].sel (obs_ind = 0 )["mu" ]
587
610
)
588
611
612
+ def _input_validation (self ):
613
+ """Validate the input data and model formula for correctness"""
614
+ if "treated" not in self .formula :
615
+ raise FormulaException (
616
+ "A predictor called `treated` should be in the formula"
617
+ )
618
+
619
+ if _is_variable_dummy_coded (self .data ["treated" ]) is False :
620
+ raise DataException (
621
+ """The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
622
+ )
623
+
589
624
def _is_treated (self , x ):
590
625
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
591
626
@@ -673,6 +708,7 @@ def __init__(
673
708
self .formula = formula
674
709
self .group_variable_name = group_variable_name
675
710
self .pretreatment_variable_name = pretreatment_variable_name
711
+ self ._input_validation ()
676
712
677
713
y , X = dmatrices (formula , self .data )
678
714
self ._y_design_info = y .design_info
@@ -681,17 +717,6 @@ def __init__(
681
717
self .y , self .X = np .asarray (y ), np .asarray (X )
682
718
self .outcome_variable_name = y .design_info .column_names [0 ]
683
719
684
- # Input validation ----------------------------------------------------
685
- # Check that `group_variable_name` has TWO levels, representing the
686
- # treated/untreated. But it does not matter what the actual names of
687
- # the levels are.
688
- assert (
689
- len (pd .Categorical (self .data [self .group_variable_name ]).categories ) == 2
690
- ), f"""
691
- There must be 2 levels of the grouping variable { self .group_variable_name }
692
- .I.e. the treated and untreated.
693
- """
694
-
695
720
# fit the model to the observed (pre-intervention) data
696
721
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
697
722
self .model .fit (X = self .X , y = self .y , coords = COORDS )
@@ -730,6 +755,16 @@ def __init__(
730
755
731
756
# ================================================================
732
757
758
+ def _input_validation (self ):
759
+ """Validate the input data and model formula for correctness"""
760
+ if not _series_has_2_levels (self .data [self .group_variable_name ]):
761
+ raise DataException (
762
+ f"""
763
+ There must be 2 levels of the grouping variable
764
+ { self .group_variable_name } . I.e. the treated and untreated.
765
+ """
766
+ )
767
+
733
768
def plot (self ):
734
769
"""Plot the results"""
735
770
fig , ax = plt .subplots (
0 commit comments