8686--------------
8787Known treatment time (traditional approach):
8888
89- >>> result = cp.InterruptedTimeSeries (
89+ >>> result = cp.ChangePointDetection (
9090... data=df,
91- ... treatment_time=pd.to_datetime("2017-01-01"), # Known intervention
91+ ... time_range=None,
9292... formula="y ~ 1 + t + C(month)",
93- ... model=cp.pymc_models.LinearRegression(),
94- ... )
95-
96- Unknown treatment time (inference approach):
97-
98- >>> model = cp.pymc_models.InterventionTimeEstimator(treatment_effect_type="level")
99- >>> result = cp.InterruptedTimeSeries(
100- ... data=df,
101- ... treatment_time=None, # Let model infer the time
102- ... formula="y ~ 1 + t + C(month)",
103- ... model=model,
93+ ... model=cp.pymc_models.LinearChangePointDetection(),
10494... )
10595
10696The module automatically selects the appropriate handler based on the treatment_time
@@ -164,35 +154,35 @@ class ChangePointDetection(BaseExperiment):
164154 ... )
165155 """
166156
167- expt_type = "Interrupted Time Series "
157+ expt_type = "Change Point Detection "
168158 supports_ols = False
169159 supports_bayes = True
170160
171161 def __init__ (
172162 self ,
173163 data : pd .DataFrame ,
174164 formula : str ,
175- treatment_time_range : Union [Iterable , None ] = None ,
165+ time_range : Union [Iterable , None ] = None ,
176166 model = None ,
177167 ** kwargs ,
178168 ) -> None :
179169 super ().__init__ (model = model )
180170
181171 # rename the index to "obs_ind"
182172 data .index .name = "obs_ind"
183- self .input_validation (data , treatment_time_range , model )
173+ self .input_validation (data , time_range , model )
184174
185175 # set experiment type - usually done in subclasses
186176 self .expt_type = "Pre-Post Fit"
187177
188- self .treatment_time_range = treatment_time_range
178+ self .time_range = time_range
189179 self .formula = formula
190180
191181 # Define the time interval over which the model will perform inference
192- model .set_time_range (self .treatment_time_range , data )
182+ model .set_time_range (self .time_range , data )
193183
194184 # Preprocess the data according to the given formula
195- y , X = dmatrices (formula , self . datapre )
185+ y , X = dmatrices (formula , data )
196186
197187 self .outcome_variable_name = y .design_info .column_names [0 ]
198188 self ._y_design_info = y .design_info
@@ -205,14 +195,14 @@ def __init__(
205195 self .X ,
206196 dims = ["obs_ind" , "coeffs" ],
207197 coords = {
208- "obs_ind" : self . datapre .index ,
198+ "obs_ind" : data .index ,
209199 "coeffs" : self .labels ,
210200 },
211201 )
212202 self .y = xr .DataArray (
213203 self .y , # Keep 2D shape
214204 dims = ["obs_ind" , "treated_units" ],
215- coords = {"obs_ind" : self . datapre .index , "treated_units" : ["unit_0" ]},
205+ coords = {"obs_ind" : data .index , "treated_units" : ["unit_0" ]},
216206 )
217207
218208 # fit the model to the observed data
@@ -266,34 +256,40 @@ def __init__(
266256 timeline_broadcast = np .array (timeline )
267257 tt_broadcast = cp_samples [:, :, None ].astype (int )
268258 mask = (timeline_broadcast >= tt_broadcast ).astype (int )
259+ mask = mask [:, :, np .newaxis , :]
260+ post_impact_masked = impact * mask
269261
270- # --- Compute cumulative post-treatment impact ---
262+ # --- Compute cumulative post-change point impact ---
271263 post_impact_masked = impact * mask
272264 self .post_impact_cumulative = model .calculate_cumulative_impact (
273265 post_impact_masked
274266 )
275267
276- def input_validation (self , data , treatment_time_range , model ):
268+ def input_validation (self , data , time_range , model ):
277269 """Validate the input data and model formula for correctness"""
278270 if not hasattr (model , "set_time_range" ):
279271 raise ModelException ("Provided model must have a 'set_time_range' method" )
280- if treatment_time_range is not None and len (treatment_time_range ) != 2 :
272+ if time_range is not None and len (time_range ) != 2 :
281273 raise BadIndexException (
282- "Provided treatment_time_range must be of length 2 : (start, end)"
274+ "Provided time_range must be of length 2 : (start, end)"
283275 )
284276 if isinstance (data .index , pd .DatetimeIndex ) and not (
285- treatment_time_range is None
286- or isinstance (treatment_time_range , Iterable [pd .Timestamp ])
277+ time_range is None
278+ or (
279+ isinstance (time_range , Iterable )
280+ and all (isinstance (t , pd .Timestamp ) for t in time_range )
281+ )
287282 ):
288283 raise BadIndexException (
289- "If data.index is DatetimeIndex, treatment_time_range must "
284+ "If data.index is DatetimeIndex, time_range must "
290285 "be of type Iterable[pd.Timestamp]."
291286 )
292- if not isinstance (data .index , pd .DatetimeIndex ) and isinstance (
293- treatment_time_range , Iterable [pd .Timestamp ]
287+ if not isinstance (data .index , pd .DatetimeIndex ) and (
288+ isinstance (time_range , Iterable )
289+ and all (isinstance (t , pd .Timestamp ) for t in time_range )
294290 ):
295291 raise BadIndexException (
296- "If data.index is not DatetimeIndex, treatment_time_range must"
292+ "If data.index is not DatetimeIndex, time_range must"
297293 "not be of type Iterable[pd.Timestamp]." # noqa: E501
298294 )
299295
@@ -324,7 +320,7 @@ def _bayesian_plot(
324320 labels = []
325321
326322 # Treated counterfactual
327- # Plot predicted values under treatment (with HDI)
323+ # Plot predicted values after change point (with HDI)
328324 h_line , h_patch = plot_xY (
329325 self .datapre .index ,
330326 self .pre_pred ["posterior_predictive" ].mu_ts .isel (treated_units = 0 ),
@@ -440,17 +436,17 @@ def _bayesian_plot(
440436 )
441437 ax [2 ].axhline (y = 0 , c = "k" )
442438
443- # Plot vertical line marking treatment time (with HDI if it's inferred)
439+ # Plot vertical line marking change point (with HDI if it's inferred)
444440 data = pd .concat ([self .datapre , self .datapost ])
445- # Extract the HDI (uncertainty interval) of the treatment time
446- hdi = az .hdi (self .idata , var_names = ["treatment_time " ])["treatment_time " ].values
441+ # Extract the HDI (uncertainty interval) of the change point
442+ hdi = az .hdi (self .idata , var_names = ["change_point " ])["change_point " ].values
447443 x1 = data .index [int (hdi [0 ])]
448444 x2 = data .index [int (hdi [1 ])]
449445
450446 for i in [0 , 1 , 2 ]:
451447 ymin , ymax = ax [i ].get_ylim ()
452448
453- # Vertical line for inferred treatment time
449+ # Vertical line for inferred change point
454450 ax [i ].plot (
455451 [self .changepoint , self .changepoint ],
456452 [ymin , ymax ],
@@ -460,7 +456,7 @@ def _bayesian_plot(
460456 solid_capstyle = "butt" ,
461457 )
462458
463- # Shaded region for HDI of treatment time
459+ # Shaded region for HDI of change point
464460 ax [i ].fill_betweenx (
465461 y = [ymin , ymax ],
466462 x1 = x1 ,
@@ -545,13 +541,13 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
545541 else :
546542 raise ValueError ("Unsupported model type" )
547543
548- def plot_treatment_time (self ):
544+ def plot_change_point (self ):
549545 """
550- display the posterior estimates of the treatment time
546+ display the posterior estimates of the change point
551547 """
552- if "treatment_time " not in self .idata .posterior .data_vars :
548+ if "change_point " not in self .idata .posterior .data_vars :
553549 raise ValueError (
554- "Variable 'treatment_time ' not found in inference data (idata)."
550+ "Variable 'change_point ' not found in inference data (idata)."
555551 )
556552
557- az .plot_trace (self .idata , var_names = "treatment_time " )
553+ az .plot_trace (self .idata , var_names = "change_point " )
0 commit comments