@@ -92,7 +92,7 @@ def __init__(
9292 # Set the data according to if the model is
9393 if treatment_time is None or isinstance (treatment_time , tuple ):
9494 self .datapre = data
95- self .model .set_time_range (self .treatment_time )
95+ self .model .set_time_range (self .treatment_time , self . datapre )
9696 else :
9797 # split data in to pre and post intervention
9898 self .datapre = data [data .index < self .treatment_time ]
@@ -120,11 +120,18 @@ def __init__(
120120 self .score = self .model .score (X = self .pre_X , y = self .pre_y )
121121
122122 if treatment_time is None or isinstance (treatment_time , tuple ):
123- self .treatment_time = int (
123+ # We're getting the inferred switchpoint as one of the values of the timeline, from the last column
124+ switchpoint = int (
124125 az .extract (idata , group = "posterior" , var_names = "switchpoint" )
125126 .mean ("sample" )
126127 .values
127128 )
129+
130+ # we're getting the associated index of that switchpoint
131+ last_column = data .columns [- 1 ]
132+ self .treatment_time = data [data [last_column ] == switchpoint ].index [0 ]
133+
134+ # We're getting datapre as intended for prediction
128135 self .datapre = data [data .index < self .treatment_time ]
129136 (new_y , new_x ) = build_design_matrices (
130137 [self ._y_design_info , self ._x_design_info ], self .datapre
@@ -155,22 +162,20 @@ def __init__(
155162
156163 def input_validation (self , data , treatment_time , model ):
157164 """Validate the input data and model formula for correctness"""
158- if treatment_time is None and not hasattr (model , "set_time_range" ):
159- raise ModelException (
160- "If treatment_time is None, provided model must have a 'set_time_range' method"
161- )
162- elif isinstance (treatment_time , tuple ) and not hasattr (model , "set_time_range" ):
165+ if isinstance (treatment_time , (type (None ), tuple )) and not hasattr (
166+ model , "set_time_range"
167+ ):
163168 raise ModelException (
164- "If treatment_time is a tuple, provided model must have a 'set_time_range' method"
169+ "If treatment_time is None or a tuple, provided model must have a 'set_time_range' method"
165170 )
166- elif isinstance (data .index , pd .DatetimeIndex ) and not isinstance (
167- treatment_time , pd .Timestamp
171+ if isinstance (data .index , pd .DatetimeIndex ) and not isinstance (
172+ treatment_time , ( pd .Timestamp , tuple , type ( None ))
168173 ):
169174 raise BadIndexException (
170175 "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
171176 )
172- elif not isinstance (data .index , pd .DatetimeIndex ) and isinstance (
173- treatment_time , pd .Timestamp
177+ if not isinstance (data .index , pd .DatetimeIndex ) and isinstance (
178+ treatment_time , ( pd .Timestamp )
174179 ):
175180 raise BadIndexException (
176181 "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
0 commit comments