3333LEGEND_FONT_SIZE = 12
3434
3535
class HandlerUTT:
    """
    Handle data preprocessing, postprocessing, and plotting steps for models
    with unknown treatment intervention times.

    The treatment time is read from the ``"treatment_time"`` variable of the
    inference data, and the data are split at the lower bound of that
    variable's Highest Density Interval (HDI).
    """

    @staticmethod
    def _index_of_time(frame, time_value, description, frame_name="data"):
        """
        Return the index label of the first row whose ``"t"`` equals ``time_value``.

        Parameters
        ----------
        frame : DataFrame with a ``"t"`` column.
        time_value : value looked up in ``frame["t"]``.
        description : text placed at the start of the error message.
        frame_name : name of the frame used in the error message.

        Raises
        ------
        ValueError
            If no row of ``frame`` has ``t == time_value``.
        """
        matches = frame.index[(frame["t"] == time_value).to_numpy()]
        if len(matches) == 0:
            raise ValueError(
                f"{description} {time_value} not found in {frame_name}['t']."
            )
        return matches[0]

    def data_preprocessing(self, data, treatment_time, formula, model):
        """
        Build the design matrices with patsy and configure the model.

        The matrices are built from the whole of ``data``. The model is given
        the time range via ``set_time_range`` and the position of the ``"t"``
        column of ``X`` via ``set_timeline``.

        Returns
        -------
        tuple
            ``(y, X)`` as returned by ``dmatrices(formula, data)``.
        """
        y, X = dmatrices(formula, data)
        # Restrict model's treatment time inference to given range
        model.set_time_range(treatment_time, data)
        # Needed to track time evolution across model predictions
        model.set_timeline(X.design_info.column_names.index("t"))
        return y, X

    def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
        """
        Split the data according to the inferred treatment time.

        ``treatment_time`` is not used here; the split is driven by ``idata``.

        Returns
        -------
        tuple
            ``(df_pre, df_post, truncated_y, truncated_X, inferred_index)``:
            the rows of ``data`` before / from the HDI lower bound, ``pre_y``
            and ``pre_X`` truncated at that same position, and the index
            label of the row matching the posterior mean treatment time.

        Raises
        ------
        ValueError
            If the inferred time or the HDI lower bound (both truncated to
            ``int``) is not a value of ``data["t"]``.
        """
        # Retrieve posterior mean of inferred treatment time
        posterior_mean = idata.posterior["treatment_time"].mean().item()
        inferred_time = int(posterior_mean)

        # Convert the inferred time to its corresponding DataFrame index
        inferred_index = self._index_of_time(
            data, inferred_time, "Inferred treatment time"
        )

        # Retrieve HDI bounds of treatment time (uncertainty interval)
        hdi_bounds = az.hdi(idata, var_names=["treatment_time"])[
            "treatment_time"
        ].values
        hdi_start_time = int(hdi_bounds[0])

        # Index label for the DataFrame, integer position for the arrays
        hdi_start_label = self._index_of_time(data, hdi_start_time, "HDI start time")
        hdi_start_pos = data.index.get_loc(hdi_start_label)

        # Slice both pandas and numpy objects accordingly
        df_pre = data[data.index < hdi_start_label]
        df_post = data[data.index >= hdi_start_label]
        truncated_y = pre_y[:hdi_start_pos]
        truncated_X = pre_X[:hdi_start_pos]

        return df_pre, df_post, truncated_y, truncated_X, inferred_index

    def plot_intervention_line(self, ax, idata, datapost, treatment_time):
        """
        Plot a vertical line at the inferred treatment time, along with a shaded area
        representing the Highest Density Interval (HDI) of the inferred time.

        The marker is drawn on ``ax[0]``, ``ax[1]`` and ``ax[2]``.

        Raises
        ------
        ValueError
            If an HDI bound (truncated to ``int``) is not a value of
            ``datapost["t"]``.
        """
        # Extract the HDI (uncertainty interval) of the treatment time
        hdi = az.hdi(idata, var_names=["treatment_time"])["treatment_time"].values
        x1 = self._index_of_time(
            datapost, int(hdi[0]), "HDI start time", frame_name="datapost"
        )
        x2 = self._index_of_time(
            datapost, int(hdi[1]), "HDI end time", frame_name="datapost"
        )

        for i in [0, 1, 2]:
            # axvline/axvspan span the full height of the axes and do not add
            # data at the current y-limits, so the y-axis is not rescaled.
            ax[i].axvline(
                x=treatment_time,
                ls="-",
                lw=3,
                color="r",
                solid_capstyle="butt",
            )

            # Shaded region for HDI of treatment time
            ax[i].axvspan(x1, x2, alpha=0.1, color="r")

    def plot_treated_counterfactual(
        self, ax, handles, labels, datapost, post_pred, post_y
    ):
        """
        Plot the inferred post-intervention trajectory (with treatment effect).

        Draws ``post_pred["posterior_predictive"].mu_ts`` against
        ``datapost.index`` on ``ax[0]`` and appends the resulting artists and
        the legend label to ``handles`` and ``labels`` in place.
        ``post_y`` is not used.
        """
        # --- Plot predicted trajectory under treatment (with HDI)
        h_line, h_patch = plot_xY(
            datapost.index,
            post_pred["posterior_predictive"].mu_ts,
            ax=ax[0],
            plot_hdi_kwargs={"color": "yellowgreen"},
        )
        handles.append((h_line, h_patch))
        labels.append("treated counterfactual")
139+
class HandlerKTT:
    """
    Handles data preprocessing, postprocessing, and plotting logic for models
    where the treatment time is known in advance.
    """

    def data_preprocessing(self, data, treatment_time, formula, model):
        """
        Build the design matrices with patsy from the pre-intervention rows.

        Only rows whose index is strictly below ``treatment_time`` are used.
        ``model`` is not used.

        Returns
        -------
        tuple
            ``(y, X)`` as returned by ``dmatrices``.
        """
        # Use only data before treatment for training the model
        pre_intervention = data[data.index < treatment_time]
        return dmatrices(formula, pre_intervention)

    def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
        """
        Split the data into pre- and post-intervention periods, using the
        known treatment time.

        ``idata`` is not used; ``pre_y``, ``pre_X`` and ``treatment_time`` are
        passed through unchanged.

        Returns
        -------
        tuple
            ``(datapre, datapost, pre_y, pre_X, treatment_time)`` where
            ``datapre`` holds the rows with index ``< treatment_time`` and
            ``datapost`` the rows with index ``>= treatment_time``.
        """
        datapre = data[data.index < treatment_time]
        datapost = data[data.index >= treatment_time]
        return datapre, datapost, pre_y, pre_X, treatment_time

    def plot_intervention_line(self, ax, idata, datapost, treatment_time):
        """
        Plot a vertical line at the known treatment time.

        The line is drawn on ``ax[0]``, ``ax[1]`` and ``ax[2]``.
        ``idata`` and ``datapost`` are not used.
        """
        for i in [0, 1, 2]:
            ax[i].axvline(
                x=treatment_time, ls="-", lw=3, color="r", solid_capstyle="butt"
            )

    def plot_treated_counterfactual(
        self, ax, handles, labels, datapost, post_pred, post_y
    ):
        """
        Placeholder method to maintain interface compatibility.

        Does nothing and returns ``None``. The first parameter is named
        ``ax`` so that the signature matches ``HandlerUTT``.
        """
        return None
182+
183+
36184class InterruptedTimeSeries (BaseExperiment ):
37185 """
38186 The class for interrupted time series analysis.
@@ -86,38 +234,33 @@ def __init__(
86234 self .input_validation (data , treatment_time , model )
87235
88236 self .treatment_time = treatment_time
89- # set experiment type - usually done in subclasses
90- self .expt_type = "Pre-Post Fit"
91- # set if the model is supposed to infer the treatment_time
92- self .infer_treatment_time = isinstance (self .treatment_time , (type (None ), tuple ))
237+ self .formula = formula
93238
94- # Set the data according to if the model is fitted on the whole bunch or not
95- if self . infer_treatment_time :
96- self .datapre = data
239+ # Getting the right handler
240+ if treatment_time is None or isinstance ( treatment_time , tuple ) :
241+ self .handler = HandlerUTT ()
97242 else :
98- # split data in to pre and post intervention
99- self .datapre = data [data .index < self .treatment_time ]
243+ self .handler = HandlerKTT ()
100244
101- self .formula = formula
245+ # set experiment type - usually done in subclasses
246+ self .expt_type = "Pre-Post Fit"
247+
248+ # Preprocessing based on handler type
249+ y , X = self .handler .data_preprocessing (
250+ data , self .treatment_time , formula , self .model
251+ )
102252
103253 # set things up with pre-intervention data
104- y , X = dmatrices (formula , self .datapre )
105254 self .outcome_variable_name = y .design_info .column_names [0 ]
106255 self ._y_design_info = y .design_info
107256 self ._x_design_info = X .design_info
108257 self .labels = X .design_info .column_names
109258 self .pre_y , self .pre_X = np .asarray (y ), np .asarray (X )
110259
111- # Setting the time range in which the model infers treatment_time
112- # Setting the timeline index so that the model can keep of time track between predicts
113- if self .infer_treatment_time :
114- self .model .set_time_range (self .treatment_time , self .datapre )
115- self .model .set_timeline (self .labels .index ("t" ))
116-
117260 # fit the model to the observed (pre-intervention) data
118261 if isinstance (self .model , PyMCModel ):
119262 COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self .pre_X .shape [0 ])}
120- idata = self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
263+ self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
121264 elif isinstance (self .model , RegressorMixin ):
122265 self .model .fit (X = self .pre_X , y = self .pre_y )
123266 else :
@@ -126,29 +269,17 @@ def __init__(
126269 # score the goodness of fit to the pre-intervention data
127270 self .score = self .model .score (X = self .pre_X , y = self .pre_y )
128271
129- if self .infer_treatment_time :
130- # We're getting the inferred switchpoint as one of the values of the timeline, from the last column
131- switchpoint = int (
132- az .extract (idata , group = "posterior" , var_names = "switchpoint" )
133- .mean ("sample" )
134- .values
135- )
136- # we're getting the associated index of that switchpoint
137- self .treatment_time = data [data ["t" ] == switchpoint ].index [0 ]
138-
139- # We're getting datapre as intended for prediction
140- self .datapre = data [data .index < self .treatment_time ]
141- (new_y , new_x ) = build_design_matrices (
142- [self ._y_design_info , self ._x_design_info ], self .datapre
272+ # Postprocessing with handler
273+ self .datapre , self .datapost , self .pre_y , self .pre_X , self .treatment_time = (
274+ self .handler .data_postprocessing (
275+ data , self .idata , treatment_time , self .pre_y , self .pre_X
143276 )
144- self .pre_X = np .asarray (new_x )
145- self .pre_y = np .asarray (new_y )
277+ )
146278
147279 # get the model predictions of the observed (pre-intervention) data
148280 self .pre_pred = self .model .predict (X = self .pre_X )
149- # process post-intervention data
150- self .datapost = data [data .index >= self .treatment_time ]
151281
282+ # process post-intervention data
152283 (new_y , new_x ) = build_design_matrices (
153284 [self ._y_design_info , self ._x_design_info ], self .datapost
154285 )
@@ -211,6 +342,7 @@ def _bayesian_plot(
211342
212343 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
213344 # TOP PLOT --------------------------------------------------
345+
214346 # pre-intervention period
215347 h_line , h_patch = plot_xY (
216348 self .datapre .index ,
@@ -225,6 +357,11 @@ def _bayesian_plot(
225357 handles .append (h )
226358 labels .append ("Observations" )
227359
360+ # Green line for treated counterfactual (if unknown treatment time)
361+ self .handler .plot_treated_counterfactual (
362+ ax , handles , labels , self .datapost , self .post_pred , self .post_y
363+ )
364+
228365 # post intervention period
229366 h_line , h_patch = plot_xY (
230367 self .datapost .index ,
@@ -289,14 +426,10 @@ def _bayesian_plot(
289426 )
290427 ax [2 ].axhline (y = 0 , c = "k" )
291428
292- # Intervention line
293- for i in [0 , 1 , 2 ]:
294- ax [i ].axvline (
295- x = self .treatment_time ,
296- ls = "-" ,
297- lw = 3 ,
298- color = "r" ,
299- )
429+ # Plot vertical line marking treatment time (with HDI if it's inferred)
430+ self .handler .plot_intervention_line (
431+ ax , self .idata , self .datapost , self .treatment_time
432+ )
300433
301434 ax [0 ].legend (
302435 handles = (h_tuple for h_tuple in handles ),
@@ -441,3 +574,14 @@ def get_plot_data_ols(self) -> pd.DataFrame:
441574 self .plot_data = pd .concat ([pre_data , post_data ])
442575
443576 return self .plot_data
577+
def plot_treatment_time(self):
    """
    Display the posterior estimates of the treatment time.

    Returns
    -------
    The value returned by ``az.plot_trace`` for the ``"treatment_time"``
    variable of ``self.idata``.

    Raises
    ------
    ValueError
        If ``self.idata.posterior`` has no ``"treatment_time"`` variable.
    """
    if "treatment_time" not in self.idata.posterior.data_vars:
        raise ValueError(
            "Variable 'treatment_time' not found in inference data (idata)."
        )

    # Hand the axes back so callers can customise or save the figure.
    return az.plot_trace(self.idata, var_names="treatment_time")
0 commit comments