@@ -40,54 +40,38 @@ class HandlerUTT:
4040 with unknown treatment intervention times.
4141 """
4242
43- def data_preprocessing (self , data , treatment_time , formula , model ):
43+ def data_preprocessing (self , data , treatment_time , model ):
4444 """
45- Preprocess the data using patsy for fittng into the model and update the model with required infos
45+ Preprocess the input data and update the model's treatment time constraints.
4646 """
47- y , X = dmatrices (formula , data )
4847 # Restrict model's treatment time inference to given range
4948 model .set_time_range (treatment_time , data )
50- # Needed to track time evolution across model predictions
51- model .set_timeline (X .design_info .column_names .index ("t" ))
52- return y , X
49+ return data
5350
5451 def data_postprocessing (self , data , idata , treatment_time , pre_y , pre_X ):
5552 """
56- Postprocess the data accordingly to the inferred treatment time for calculation and plot purpose
53+ Postprocess data based on the inferred treatment time for further analysis and plotting.
5754 """
58- # Retrieve posterior mean of inferred treatment time
55+ # --- Inferred treatment time ---
5956 treatment_time_mean = idata .posterior ["treatment_time" ].mean ().item ()
60- inferred_time = int (treatment_time_mean )
57+ inferred_treatment_time = int (treatment_time_mean )
58+ idx_treatment_time = data .index [data ["t" ] == inferred_treatment_time ][0 ]
6159
62- # Safety check: ensure the inferred time is present in the dataset
63- if inferred_time not in data ["t" ].values :
64- raise ValueError (
65- f"Inferred treatment time { inferred_time } not found in data['t']."
66- )
67-
68- # Convert the inferred time to its corresponding DataFrame index
69- inferred_index = data [data ["t" ] == inferred_time ].index [0 ]
70-
71- # Retrieve HDI bounds of treatment time (uncertainty interval)
60+ # --- HDI bounds (credible interval) ---
7261 hdi_bounds = az .hdi (idata , var_names = ["treatment_time" ])[
7362 "treatment_time"
7463 ].values
7564 hdi_start_time = int (hdi_bounds [0 ])
65+ indice = data .index .get_loc (data .index [data ["t" ] == hdi_start_time ][0 ])
7666
77- # Convert HDI lower bound to DataFrame index for slicing
78- if hdi_start_time not in data ["t" ].values :
79- raise ValueError (f"HDI start time { hdi_start_time } not found in data['t']." )
80-
81- hdi_start_idx_df = data [data ["t" ] == hdi_start_time ].index [0 ]
82- hdi_start_idx_np = data .index .get_loc (hdi_start_idx_df )
67+ # --- Slicing ---
68+ datapre = data [data ["t" ] < hdi_start_time ]
69+ datapost = data [data ["t" ] >= hdi_start_time ]
8370
84- # Slice both pandas and numpy objects accordingly
85- df_pre = data [data .index < hdi_start_idx_df ]
86- df_post = data [data .index >= hdi_start_idx_df ]
87- truncated_y = pre_y [:hdi_start_idx_np ]
88- truncated_X = pre_X [:hdi_start_idx_np ]
71+ truncated_y = pre_y .isel (obs_ind = slice (0 , indice ))
72+ truncated_X = pre_X .isel (obs_ind = slice (0 , indice ))
8973
90- return df_pre , df_post , truncated_y , truncated_X , inferred_index
74+ return datapre , datapost , truncated_y , truncated_X , idx_treatment_time
9175
9276 def plot_intervention_line (self , ax , idata , datapost , treatment_time ):
9377 """
@@ -144,16 +128,16 @@ class HandlerKTT:
144128 where the treatment time is known in advance.
145129 """
146130
147- def data_preprocessing (self , data , treatment_time , formula , model ):
131+ def data_preprocessing (self , data , treatment_time , model ):
148132 """
149- Preprocess the data using patsy for fitting into the model
133+ Preprocess the data by selecting only the pre-treatment period for model fitting.
150134 """
151135 # Use only data before treatment for training the model
152- return dmatrices ( formula , data [data .index < treatment_time ])
136+ return data [data .index < treatment_time ]
153137
154138 def data_postprocessing (self , data , idata , treatment_time , pre_y , pre_X ):
155139 """
156- Postprocess data by splitting it into pre- and post-intervention periods, using the known treatment time.
140+ Split data into pre- and post-treatment periods using the known treatment time.
157141 """
158142 return (
159143 data [data .index < treatment_time ],
@@ -165,7 +149,7 @@ def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
165149
166150 def plot_intervention_line (self , ax , idata , datapost , treatment_time ):
167151 """
168- Plot a vertical line at the known treatment time.
152+ Plot a vertical line at the known treatment time on provided axes .
169153 """
170154 # --- Plot a vertical line at the known treatment time
171155 for i in [0 , 1 , 2 ]:
@@ -177,7 +161,7 @@ def plot_treated_counterfactual(
177161 self , sax , handles , labels , datapost , post_pred , post_y
178162 ):
179163 """
180- Placeholder method to maintain interface compatibility.
164+ Placeholder method to maintain interface compatibility with HandlerUTT .
181165 """
182166 pass
183167
@@ -236,7 +220,6 @@ def __init__(
236220 # rename the index to "obs_ind"
237221 data .index .name = "obs_ind"
238222 self .input_validation (data , treatment_time , model )
239- self .treatment_time = treatment_time
240223 # set experiment type - usually done in subclasses
241224 self .expt_type = "Pre-Post Fit"
242225
@@ -249,27 +232,41 @@ def __init__(
249232 else :
250233 self .handler = HandlerKTT ()
251234
252- # set experiment type - usually done in subclasses
253- self .expt_type = "Pre-Post Fit"
254-
255235 # Preprocessing based on handler type
256- y , X = self .handler .data_preprocessing (
257- data , self .treatment_time , formula , self .model
236+ self . datapre = self .handler .data_preprocessing (
237+ data , self .treatment_time , self .model
258238 )
259239
240+ y , X = dmatrices (formula , self .datapre )
260241 # set things up with pre-intervention data
261242 self .outcome_variable_name = y .design_info .column_names [0 ]
262243 self ._y_design_info = y .design_info
263244 self ._x_design_info = X .design_info
264245 self .labels = X .design_info .column_names
265246 self .pre_y , self .pre_X = np .asarray (y ), np .asarray (X )
266247
248+ # turn into xarray.DataArray's
249+ self .pre_X = xr .DataArray (
250+ self .pre_X ,
251+ dims = ["obs_ind" , "coeffs" ],
252+ coords = {
253+ "obs_ind" : self .datapre .index ,
254+ "coeffs" : self .labels ,
255+ },
256+ )
257+ self .pre_y = xr .DataArray (
258+ self .pre_y [:, 0 ],
259+ dims = ["obs_ind" ],
260+ coords = {"obs_ind" : self .datapre .index },
261+ )
262+
267263 # fit the model to the observed (pre-intervention) data
268264 if isinstance (self .model , PyMCModel ):
269- COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self . pre_X .shape [0 ])}
270- self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
265+ COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (X .shape [0 ])}
266+ idata = self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
271267 elif isinstance (self .model , RegressorMixin ):
272268 self .model .fit (X = self .pre_X , y = self .pre_y )
269+ idata = None
273270 else :
274271 raise ValueError ("Model type not recognized" )
275272
@@ -279,7 +276,7 @@ def __init__(
279276 # Postprocessing with handler
280277 self .datapre , self .datapost , self .pre_y , self .pre_X , self .treatment_time = (
281278 self .handler .data_postprocessing (
282- data , self . idata , treatment_time , self .pre_y , self .pre_X
279+ data , idata , treatment_time , self .pre_y , self .pre_X
283280 )
284281 )
285282
@@ -292,20 +289,6 @@ def __init__(
292289 )
293290 self .post_X = np .asarray (new_x )
294291 self .post_y = np .asarray (new_y )
295- # turn into xarray.DataArray's
296- self .pre_X = xr .DataArray (
297- self .pre_X ,
298- dims = ["obs_ind" , "coeffs" ],
299- coords = {
300- "obs_ind" : self .datapre .index ,
301- "coeffs" : self .labels ,
302- },
303- )
304- self .pre_y = xr .DataArray (
305- self .pre_y [:, 0 ],
306- dims = ["obs_ind" ],
307- coords = {"obs_ind" : self .datapre .index },
308- )
309292 self .post_X = xr .DataArray (
310293 self .post_X ,
311294 dims = ["obs_ind" , "coeffs" ],
0 commit comments