2929import numpy as np
3030import pandas as pd
3131from patsy import dmatrix
32+ from sklearn .base import RegressorMixin
3233from statsmodels .graphics .tsaplots import plot_acf , plot_pacf
3334from statsmodels .stats .diagnostic import acorr_ljungbox
3435
@@ -47,28 +48,28 @@ class GradedInterventionTimeSeries(BaseExperiment):
4748
4849 This experiment class handles causal inference for time series with graded
4950 (non-binary) interventions, incorporating saturation and adstock effects.
50- It works with a pre-fitted TransferFunctionOLS model to provide visualization,
51- diagnostics, and counterfactual effect estimation.
51+ Following the standard CausalPy pattern, it takes data and an unfitted model,
52+ performs transform parameter estimation, fits the model, and provides
53+ visualization, diagnostics, and counterfactual effect estimation.
5254
5355 Typical workflow:
54- 1. Create and fit a TransferFunctionOLS model using the with_estimated_transforms() method
55- 2. Pass the fitted model to this experiment class
56- 3. Use experiment methods for visualization and effect estimation
56+ 1. Create an UNFITTED TransferFunctionOLS model with configuration
57+ 2. Pass data + model to this experiment class
58+ 3. Experiment estimates transforms, fits model, and provides results
59+ 4. Use experiment methods for visualization and effect estimation
5760
5861 Parameters
5962 ----------
6063 data : pd.DataFrame
6164 Time series data with datetime or numeric index.
6265 y_column : str
6366 Name of the outcome variable column.
64- treatment_name : str
65- Name of the treatment variable column .
67+ treatment_names : List[ str]
68+ List of treatment variable names (e.g., ["comm_intensity"]) .
6669 base_formula : str
67- Patsy formula for baseline model.
68- treatments : List[Treatment]
69- List of Treatment objects with configured transforms.
70+ Patsy formula for baseline model (e.g., "1 + t + temperature").
7071 model : TransferFunctionOLS
71- Pre-fitted model instance .
72+ UNFITTED model with configuration for transform parameter estimation .
7273
7374 Attributes
7475 ----------
@@ -92,24 +93,20 @@ class GradedInterventionTimeSeries(BaseExperiment):
9293 Examples
9394 --------
9495 >>> import causalpy as cp
95- >>> # Step 1: Create and fit model
96- >>> model = cp.skl_models.TransferFunctionOLS.with_estimated_transforms(
97- ... data=df,
98- ... y_column="water_consumption",
99- ... treatment_name="comm_intensity",
100- ... base_formula="1 + t + temperature + rainfall",
101- ... estimation_method="grid",
96+ >>> # Step 1: Create UNFITTED model with configuration
97+ >>> model = cp.skl_models.TransferFunctionOLS(
98+ ... saturation_type="hill",
10299 ... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
103100 ... adstock_grid={"half_life": [2, 3, 4, 5]},
101+ ... estimation_method="grid",
104102 ... error_model="hac",
105103 ... )
106- >>> # Step 2: Create experiment with fitted model
104+ >>> # Step 2: Pass to experiment (experiment estimates transforms and fits model)
107105 >>> result = cp.GradedInterventionTimeSeries(
108106 ... data=df,
109107 ... y_column="water_consumption",
110- ... treatment_name= "comm_intensity",
108+ ... treatment_names=[ "comm_intensity"] ,
111109 ... base_formula="1 + t + temperature + rainfall",
112- ... treatments=model.treatments,
113110 ... model=model,
114111 ... )
115112 >>> # Step 3: Use experiment methods
@@ -127,63 +124,152 @@ def __init__(
127124 self ,
128125 data : pd .DataFrame ,
129126 y_column : str ,
130- treatment_name : str ,
127+ treatment_names : List [ str ] ,
131128 base_formula : str ,
132- treatments : List [Treatment ],
133129 model = None ,
134130 ** kwargs ,
135131 ):
136132 """
137- Initialize experiment with pre-configured treatments and fitted model.
133+ Initialize experiment with data and unfitted model (standard CausalPy pattern) .
138134
139- The model should be a fitted TransferFunctionOLS instance. For most use cases,
140- create the model using TransferFunctionOLS.with_estimated_transforms() first.
135+ This method:
136+ 1. Validates inputs and builds baseline design matrix
137+ 2. Estimates transform parameters for each treatment
138+ 3. Applies transforms and builds full design matrix
139+ 4. Calls model.fit(X_full, y)
140+ 5. Extracts results for visualization and analysis
141+
142+ Parameters
143+ ----------
144+ data : pd.DataFrame
145+ Time series data.
146+ y_column : str
147+ Name of outcome variable.
148+ treatment_names : List[str]
149+ List of treatment variable names (e.g., ["comm_intensity"]).
150+ base_formula : str
151+ Patsy formula for baseline model.
152+ model : TransferFunctionOLS
153+ UNFITTED model with configuration for transform estimation.
141154 """
142155 super ().__init__ (model = model )
143156
144- # Validate model
145- if model is None :
146- raise ValueError (
147- "A fitted model is required. Use TransferFunctionOLS.with_estimated_transforms() "
148- "to create and fit a model, then pass it to this experiment class."
149- )
150-
151157 # Validate inputs
152- self ._validate_inputs (data , y_column , treatments )
158+ self ._validate_inputs (data , y_column , treatment_names )
153159
154160 # Store attributes
155161 self .data = data .copy ()
156162 self .y_column = y_column
157- self .treatment_name = treatment_name # Store for backwards compatibility
163+ self .treatment_names = treatment_names
158164 self .base_formula = base_formula
159- self .treatments = treatments
160- self .treatment_names = [t .name for t in treatments ]
161165
162166 # Extract outcome variable
163167 self .y = data [y_column ].values
164168
165- # Build baseline design matrix
169+ # Build baseline design matrix (like other experiments do)
166170 self .X_baseline = np .asarray (dmatrix (base_formula , data ))
167171 self .baseline_labels = dmatrix (base_formula , data ).design_info .column_names
168172
169- # Build treatment design matrix
170- self .Z_treatment , self .treatment_labels = self ._build_treatment_matrix (
171- data , treatments
173+ # Estimate transform parameters for each treatment
174+ from causalpy .transform_optimization import (
175+ estimate_transform_params_grid ,
176+ estimate_transform_params_optimize ,
172177 )
178+ from causalpy .transforms import Treatment
173179
174- # Combine matrices
180+ self .treatments = []
181+ Z_columns = []
182+ self .treatment_labels = []
183+
184+ for name in treatment_names :
185+ # Run parameter estimation using model configuration
186+ if self .model .estimation_method == "grid" :
187+ est_results = estimate_transform_params_grid (
188+ data = data ,
189+ y_column = y_column ,
190+ treatment_name = name ,
191+ base_formula = base_formula ,
192+ saturation_type = self .model .saturation_type ,
193+ saturation_grid = self .model .saturation_grid ,
194+ adstock_grid = self .model .adstock_grid ,
195+ coef_constraint = self .model .coef_constraint ,
196+ hac_maxlags = self .model .hac_maxlags ,
197+ error_model = self .model .error_model ,
198+ arima_order = self .model .arima_order ,
199+ )
200+ search_space = {
201+ "saturation_grid" : self .model .saturation_grid ,
202+ "adstock_grid" : self .model .adstock_grid ,
203+ }
204+ elif self .model .estimation_method == "optimize" :
205+ est_results = estimate_transform_params_optimize (
206+ data = data ,
207+ y_column = y_column ,
208+ treatment_name = name ,
209+ base_formula = base_formula ,
210+ saturation_type = self .model .saturation_type ,
211+ saturation_bounds = self .model .saturation_bounds ,
212+ adstock_bounds = self .model .adstock_bounds ,
213+ initial_params = None ,
214+ coef_constraint = self .model .coef_constraint ,
215+ hac_maxlags = self .model .hac_maxlags ,
216+ method = "L-BFGS-B" ,
217+ error_model = self .model .error_model ,
218+ arima_order = self .model .arima_order ,
219+ )
220+ search_space = {
221+ "saturation_bounds" : self .model .saturation_bounds ,
222+ "adstock_bounds" : self .model .adstock_bounds ,
223+ }
224+
225+ # Store estimation metadata on model
226+ self .model .transform_estimation_results = est_results
227+ self .model .transform_search_space = search_space
228+
229+ # Create Treatment with estimated transforms
230+ treatment = Treatment (
231+ name = name ,
232+ saturation = est_results ["best_saturation" ],
233+ adstock = est_results ["best_adstock" ],
234+ coef_constraint = self .model .coef_constraint ,
235+ )
236+ self .treatments .append (treatment )
237+
238+ # Apply transforms
239+ x_raw = data [name ].values
240+ x_transformed = x_raw
241+ if treatment .saturation is not None :
242+ x_transformed = treatment .saturation .apply (x_transformed )
243+ if treatment .adstock is not None :
244+ x_transformed = treatment .adstock .apply (x_transformed )
245+ if treatment .lag is not None :
246+ x_transformed = treatment .lag .apply (x_transformed )
247+
248+ Z_columns .append (x_transformed )
249+ self .treatment_labels .append (name )
250+
251+ # Build full design matrix
252+ self .Z_treatment = np .column_stack (Z_columns )
175253 self .X_full = np .column_stack ([self .X_baseline , self .Z_treatment ])
176254 self .all_labels = self .baseline_labels + self .treatment_labels
177255
178- # Extract information from fitted model
179- self .model = model
180- self .ols_result = model .ols_result
181- self .predictions = model .ols_result .fittedvalues
182- self .residuals = model .ols_result .resid
183- self .score = model .score
256+ # Store treatments on model for later use
257+ self .model .treatments = self .treatments
258+
259+ # Fit the model (standard CausalPy pattern)
260+ if isinstance (self .model , RegressorMixin ):
261+ self .model .fit (X = self .X_full , y = self .y )
262+ else :
263+ raise ValueError ("Model type not recognized" )
264+
265+ # Extract results from fitted model
266+ self .ols_result = self .model .ols_result
267+ self .predictions = self .model .ols_result .fittedvalues
268+ self .residuals = self .model .ols_result .resid
269+ self .score = self .model .score
184270
185271 # Extract coefficients (handling ARIMAX correctly)
186- if hasattr ( model , "error_model" ) and model .error_model == "arimax" :
272+ if self . model .error_model == "arimax" :
187273 # ARIMAX: extract only exogenous coefficients
188274 n_exog = self .ols_result .model .k_exog
189275 exog_params = self .ols_result .params [:n_exog ]
@@ -197,44 +283,38 @@ def __init__(
197283 self .theta_treatment = self .ols_result .params [n_baseline :]
198284
199285 # Store model metadata for summary output
200- self .error_model = getattr (model , "error_model" , "hac" )
201- self .hac_maxlags = getattr (model , "hac_maxlags" , None )
202- self .arima_order = getattr (model , "arima_order" , None )
203- self .transform_estimation_method = getattr (
204- model , "transform_estimation_method" , None
205- )
206- self .transform_estimation_results = getattr (
207- model , "transform_estimation_results" , None
208- )
209- self .transform_search_space = getattr (model , "transform_search_space" , None )
286+ self .error_model = self .model .error_model
287+ self .hac_maxlags = self .model .hac_maxlags
288+ self .arima_order = self .model .arima_order
289+ self .transform_estimation_method = self .model .estimation_method
290+ self .transform_estimation_results = self .model .transform_estimation_results
291+ self .transform_search_space = self .model .transform_search_space
210292
def _validate_inputs(
    self,
    data: pd.DataFrame,
    y_column: str,
    treatment_names: List[str],
) -> None:
    """Validate the outcome and treatment columns before any fitting is done.

    Parameters
    ----------
    data : pd.DataFrame
        Data that must contain ``y_column`` and every treatment column.
    y_column : str
        Name of the outcome variable column.
    treatment_names : List[str]
        Names of the treatment variable columns.

    Raises
    ------
    ValueError
        If ``y_column`` is not a column of ``data``, if ``treatment_names``
        is a bare string instead of a list of names, if any treatment
        column is not a column of ``data``, or if the outcome column
        contains missing values.

    Warns
    -----
    UserWarning
        Once per treatment column that contains missing values. Missing
        treatment values are reported but do not stop validation.
    """
    # Function-scope import so this method does not depend on a
    # module-level ``import warnings`` being present.
    import warnings

    # Check that y_column exists
    if y_column not in data.columns:
        raise ValueError(f"y_column '{y_column}' not found in data columns")

    # A bare string would be iterated character by character below and
    # produce a misleading "column 'c' not found" error, so reject it
    # explicitly with the same exception type.
    if isinstance(treatment_names, str):
        raise ValueError(
            "treatment_names must be a list of column names, "
            f"not a single string; did you mean ['{treatment_names}']?"
        )

    # Check that treatment columns exist
    for name in treatment_names:
        if name not in data.columns:
            raise ValueError(f"Treatment column '{name}' not found in data columns")

    # Check for missing values in outcome
    if data[y_column].isna().any():
        raise ValueError("Outcome variable contains missing values")

    # Warn about missing values in treatment columns
    for name in treatment_names:
        if data[name].isna().any():
            # stacklevel=3 attributes the warning to the code that invoked
            # the caller of this method (the constructor calls this method
            # directly), rather than to this helper.
            warnings.warn(
                f"Treatment column '{name}' contains missing values. "
                f"Consider forward-filling if justified by the context.",
                UserWarning,
                stacklevel=3,
            )
240320
0 commit comments