@@ -48,40 +48,50 @@ def data_preprocessing(self, data, treatment_time, model):
4848 model .set_time_range (treatment_time , data )
4949 return data
5050
51- def data_postprocessing (self , data , idata , treatment_time , pre_y , pre_X ):
51+ def data_postprocessing (self , model , data , idata , treatment_time , pre_y , pre_X ):
5252 """
5353 Postprocess data based on the inferred treatment time for further analysis and plotting.
5454 """
55+ # --- Getting the time_variable_name ---
56+ time_variable_name = model .get_time_variable_name ()
57+
5558 # --- Inferred treatment time ---
5659 treatment_time_mean = idata .posterior ["treatment_time" ].mean ().item ()
5760 inferred_treatment_time = int (treatment_time_mean )
58- idx_treatment_time = data .index [data ["t" ] == inferred_treatment_time ][0 ]
61+ idx_treatment_time = data .index [
62+ data [time_variable_name ] == inferred_treatment_time
63+ ][0 ]
5964
6065 # --- HDI bounds (credible interval) ---
6166 hdi_bounds = az .hdi (idata , var_names = ["treatment_time" ])[
6267 "treatment_time"
6368 ].values
6469 hdi_start_time = int (hdi_bounds [0 ])
65- indice = data .index .get_loc (data .index [data ["t" ] == hdi_start_time ][0 ])
70+ indice = data .index .get_loc (
71+ data .index [data [time_variable_name ] == hdi_start_time ][0 ]
72+ )
6673
6774 # --- Slicing ---
68- datapre = data [data ["t" ] < hdi_start_time ]
69- datapost = data [data ["t" ] >= hdi_start_time ]
75+ datapre = data [data [time_variable_name ] < hdi_start_time ]
76+ datapost = data [data [time_variable_name ] >= hdi_start_time ]
7077
7178 truncated_y = pre_y .isel (obs_ind = slice (0 , indice ))
7279 truncated_X = pre_X .isel (obs_ind = slice (0 , indice ))
7380
7481 return datapre , datapost , truncated_y , truncated_X , idx_treatment_time
7582
76- def plot_intervention_line (self , ax , idata , datapost , treatment_time ):
83+ def plot_intervention_line (self , ax , model , idata , datapost , treatment_time ):
7784 """
7885 Plot a vertical line at the inferred treatment time, along with a shaded area
7986 representing the Highest Density Interval (HDI) of the inferred time.
8087 """
88+ # --- Getting the time_variable_name ---
89+ time_variable_name = model .get_time_variable_name ()
90+
8191 # Extract the HDI (uncertainty interval) of the treatment time
8292 hdi = az .hdi (idata , var_names = ["treatment_time" ])["treatment_time" ].values
83- x1 = datapost .index [datapost ["t" ] == int (hdi [0 ])][0 ]
84- x2 = datapost .index [datapost ["t" ] == int (hdi [1 ])][0 ]
93+ x1 = datapost .index [datapost [time_variable_name ] == int (hdi [0 ])][0 ]
94+ x2 = datapost .index [datapost [time_variable_name ] == int (hdi [1 ])][0 ]
8595
8696 for i in [0 , 1 , 2 ]:
8797 ymin , ymax = ax [i ].get_ylim ()
@@ -119,7 +129,7 @@ def plot_treated_counterfactual(
119129 plot_hdi_kwargs = {"color" : "yellowgreen" },
120130 )
121131 handles .append ((h_line , h_patch ))
122- labels .append ("treated counterfactual" )
132+ labels .append ("Treated counterfactual" )
123133
124134
125135class HandlerKTT :
@@ -135,7 +145,7 @@ def data_preprocessing(self, data, treatment_time, model):
135145 # Use only data before treatment for training the model
136146 return data [data .index < treatment_time ]
137147
138- def data_postprocessing (self , data , idata , treatment_time , pre_y , pre_X ):
148+ def data_postprocessing (self , model , data , idata , treatment_time , pre_y , pre_X ):
139149 """
140150 Split data into pre- and post-treatment periods using the known treatment time.
141151 """
@@ -147,7 +157,7 @@ def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
147157 treatment_time ,
148158 )
149159
150- def plot_intervention_line (self , ax , idata , datapost , treatment_time ):
160+ def plot_intervention_line (self , model , ax , idata , datapost , treatment_time ):
151161 """
152162 Plot a vertical line at the known treatment time on provided axes.
153163 """
@@ -276,7 +286,7 @@ def __init__(
276286 # Postprocessing with handler
277287 self .datapre , self .datapost , self .pre_y , self .pre_X , self .treatment_time = (
278288 self .handler .data_postprocessing (
279- data , idata , treatment_time , self .pre_y , self .pre_X
289+ self . model , data , idata , treatment_time , self .pre_y , self .pre_X
280290 )
281291 )
282292
@@ -443,7 +453,7 @@ def _bayesian_plot(
443453
444454 # Plot vertical line marking treatment time (with HDI if it's inferred)
445455 self .handler .plot_intervention_line (
446- ax , self .idata , self .datapost , self .treatment_time
456+ ax , self .model , self . idata , self .datapost , self .treatment_time
447457 )
448458
449459 ax [0 ].legend (
0 commit comments