Skip to content

Commit 0b5535a

Browse files
committed
Refactor graded intervention ITS to use unfitted model pattern
Refactors GradedInterventionTimeSeries and TransferFunctionOLS to follow the standard CausalPy pattern: the experiment class now takes an unfitted model and handles transform parameter estimation, fitting, and result extraction. Removes the with_estimated_transforms factory method, updates all docstrings, and adapts tests and documentation to the new workflow. This enables more flexible and consistent usage for multi-treatment and advanced modeling scenarios.
1 parent 4889f72 commit 0b5535a

File tree

4 files changed

+300
-435
lines changed

4 files changed

+300
-435
lines changed

causalpy/experiments/graded_intervention_its.py

Lines changed: 148 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import numpy as np
3030
import pandas as pd
3131
from patsy import dmatrix
32+
from sklearn.base import RegressorMixin
3233
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
3334
from statsmodels.stats.diagnostic import acorr_ljungbox
3435

@@ -47,28 +48,28 @@ class GradedInterventionTimeSeries(BaseExperiment):
4748
4849
This experiment class handles causal inference for time series with graded
4950
(non-binary) interventions, incorporating saturation and adstock effects.
50-
It works with a pre-fitted TransferFunctionOLS model to provide visualization,
51-
diagnostics, and counterfactual effect estimation.
51+
Following the standard CausalPy pattern, it takes data and an unfitted model,
52+
performs transform parameter estimation, fits the model, and provides
53+
visualization, diagnostics, and counterfactual effect estimation.
5254
5355
Typical workflow:
54-
1. Create and fit a TransferFunctionOLS model using the with_estimated_transforms() method
55-
2. Pass the fitted model to this experiment class
56-
3. Use experiment methods for visualization and effect estimation
56+
1. Create an UNFITTED TransferFunctionOLS model with configuration
57+
2. Pass data + model to this experiment class
58+
3. Experiment estimates transforms, fits model, and provides results
59+
4. Use experiment methods for visualization and effect estimation
5760
5861
Parameters
5962
----------
6063
data : pd.DataFrame
6164
Time series data with datetime or numeric index.
6265
y_column : str
6366
Name of the outcome variable column.
64-
treatment_name : str
65-
Name of the treatment variable column.
67+
treatment_names : List[str]
68+
List of treatment variable names (e.g., ["comm_intensity"]).
6669
base_formula : str
67-
Patsy formula for baseline model.
68-
treatments : List[Treatment]
69-
List of Treatment objects with configured transforms.
70+
Patsy formula for baseline model (e.g., "1 + t + temperature").
7071
model : TransferFunctionOLS
71-
Pre-fitted model instance.
72+
UNFITTED model with configuration for transform parameter estimation.
7273
7374
Attributes
7475
----------
@@ -92,24 +93,20 @@ class GradedInterventionTimeSeries(BaseExperiment):
9293
Examples
9394
--------
9495
>>> import causalpy as cp
95-
>>> # Step 1: Create and fit model
96-
>>> model = cp.skl_models.TransferFunctionOLS.with_estimated_transforms(
97-
... data=df,
98-
... y_column="water_consumption",
99-
... treatment_name="comm_intensity",
100-
... base_formula="1 + t + temperature + rainfall",
101-
... estimation_method="grid",
96+
>>> # Step 1: Create UNFITTED model with configuration
97+
>>> model = cp.skl_models.TransferFunctionOLS(
98+
... saturation_type="hill",
10299
... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
103100
... adstock_grid={"half_life": [2, 3, 4, 5]},
101+
... estimation_method="grid",
104102
... error_model="hac",
105103
... )
106-
>>> # Step 2: Create experiment with fitted model
104+
>>> # Step 2: Pass to experiment (experiment estimates transforms and fits model)
107105
>>> result = cp.GradedInterventionTimeSeries(
108106
... data=df,
109107
... y_column="water_consumption",
110-
... treatment_name="comm_intensity",
108+
... treatment_names=["comm_intensity"],
111109
... base_formula="1 + t + temperature + rainfall",
112-
... treatments=model.treatments,
113110
... model=model,
114111
... )
115112
>>> # Step 3: Use experiment methods
@@ -127,63 +124,152 @@ def __init__(
127124
self,
128125
data: pd.DataFrame,
129126
y_column: str,
130-
treatment_name: str,
127+
treatment_names: List[str],
131128
base_formula: str,
132-
treatments: List[Treatment],
133129
model=None,
134130
**kwargs,
135131
):
136132
"""
137-
Initialize experiment with pre-configured treatments and fitted model.
133+
Initialize experiment with data and unfitted model (standard CausalPy pattern).
138134
139-
The model should be a fitted TransferFunctionOLS instance. For most use cases,
140-
create the model using TransferFunctionOLS.with_estimated_transforms() first.
135+
This method:
136+
1. Validates inputs and builds baseline design matrix
137+
2. Estimates transform parameters for each treatment
138+
3. Applies transforms and builds full design matrix
139+
4. Calls model.fit(X_full, y)
140+
5. Extracts results for visualization and analysis
141+
142+
Parameters
143+
----------
144+
data : pd.DataFrame
145+
Time series data.
146+
y_column : str
147+
Name of outcome variable.
148+
treatment_names : List[str]
149+
List of treatment variable names (e.g., ["comm_intensity"]).
150+
base_formula : str
151+
Patsy formula for baseline model.
152+
model : TransferFunctionOLS
153+
UNFITTED model with configuration for transform estimation.
141154
"""
142155
super().__init__(model=model)
143156

144-
# Validate model
145-
if model is None:
146-
raise ValueError(
147-
"A fitted model is required. Use TransferFunctionOLS.with_estimated_transforms() "
148-
"to create and fit a model, then pass it to this experiment class."
149-
)
150-
151157
# Validate inputs
152-
self._validate_inputs(data, y_column, treatments)
158+
self._validate_inputs(data, y_column, treatment_names)
153159

154160
# Store attributes
155161
self.data = data.copy()
156162
self.y_column = y_column
157-
self.treatment_name = treatment_name # Store for backwards compatibility
163+
self.treatment_names = treatment_names
158164
self.base_formula = base_formula
159-
self.treatments = treatments
160-
self.treatment_names = [t.name for t in treatments]
161165

162166
# Extract outcome variable
163167
self.y = data[y_column].values
164168

165-
# Build baseline design matrix
169+
# Build baseline design matrix (like other experiments do)
166170
self.X_baseline = np.asarray(dmatrix(base_formula, data))
167171
self.baseline_labels = dmatrix(base_formula, data).design_info.column_names
168172

169-
# Build treatment design matrix
170-
self.Z_treatment, self.treatment_labels = self._build_treatment_matrix(
171-
data, treatments
173+
# Estimate transform parameters for each treatment
174+
from causalpy.transform_optimization import (
175+
estimate_transform_params_grid,
176+
estimate_transform_params_optimize,
172177
)
178+
from causalpy.transforms import Treatment
173179

174-
# Combine matrices
180+
self.treatments = []
181+
Z_columns = []
182+
self.treatment_labels = []
183+
184+
for name in treatment_names:
185+
# Run parameter estimation using model configuration
186+
if self.model.estimation_method == "grid":
187+
est_results = estimate_transform_params_grid(
188+
data=data,
189+
y_column=y_column,
190+
treatment_name=name,
191+
base_formula=base_formula,
192+
saturation_type=self.model.saturation_type,
193+
saturation_grid=self.model.saturation_grid,
194+
adstock_grid=self.model.adstock_grid,
195+
coef_constraint=self.model.coef_constraint,
196+
hac_maxlags=self.model.hac_maxlags,
197+
error_model=self.model.error_model,
198+
arima_order=self.model.arima_order,
199+
)
200+
search_space = {
201+
"saturation_grid": self.model.saturation_grid,
202+
"adstock_grid": self.model.adstock_grid,
203+
}
204+
elif self.model.estimation_method == "optimize":
205+
est_results = estimate_transform_params_optimize(
206+
data=data,
207+
y_column=y_column,
208+
treatment_name=name,
209+
base_formula=base_formula,
210+
saturation_type=self.model.saturation_type,
211+
saturation_bounds=self.model.saturation_bounds,
212+
adstock_bounds=self.model.adstock_bounds,
213+
initial_params=None,
214+
coef_constraint=self.model.coef_constraint,
215+
hac_maxlags=self.model.hac_maxlags,
216+
method="L-BFGS-B",
217+
error_model=self.model.error_model,
218+
arima_order=self.model.arima_order,
219+
)
220+
search_space = {
221+
"saturation_bounds": self.model.saturation_bounds,
222+
"adstock_bounds": self.model.adstock_bounds,
223+
}
224+
225+
# Store estimation metadata on model
226+
self.model.transform_estimation_results = est_results
227+
self.model.transform_search_space = search_space
228+
229+
# Create Treatment with estimated transforms
230+
treatment = Treatment(
231+
name=name,
232+
saturation=est_results["best_saturation"],
233+
adstock=est_results["best_adstock"],
234+
coef_constraint=self.model.coef_constraint,
235+
)
236+
self.treatments.append(treatment)
237+
238+
# Apply transforms
239+
x_raw = data[name].values
240+
x_transformed = x_raw
241+
if treatment.saturation is not None:
242+
x_transformed = treatment.saturation.apply(x_transformed)
243+
if treatment.adstock is not None:
244+
x_transformed = treatment.adstock.apply(x_transformed)
245+
if treatment.lag is not None:
246+
x_transformed = treatment.lag.apply(x_transformed)
247+
248+
Z_columns.append(x_transformed)
249+
self.treatment_labels.append(name)
250+
251+
# Build full design matrix
252+
self.Z_treatment = np.column_stack(Z_columns)
175253
self.X_full = np.column_stack([self.X_baseline, self.Z_treatment])
176254
self.all_labels = self.baseline_labels + self.treatment_labels
177255

178-
# Extract information from fitted model
179-
self.model = model
180-
self.ols_result = model.ols_result
181-
self.predictions = model.ols_result.fittedvalues
182-
self.residuals = model.ols_result.resid
183-
self.score = model.score
256+
# Store treatments on model for later use
257+
self.model.treatments = self.treatments
258+
259+
# Fit the model (standard CausalPy pattern)
260+
if isinstance(self.model, RegressorMixin):
261+
self.model.fit(X=self.X_full, y=self.y)
262+
else:
263+
raise ValueError("Model type not recognized")
264+
265+
# Extract results from fitted model
266+
self.ols_result = self.model.ols_result
267+
self.predictions = self.model.ols_result.fittedvalues
268+
self.residuals = self.model.ols_result.resid
269+
self.score = self.model.score
184270

185271
# Extract coefficients (handling ARIMAX correctly)
186-
if hasattr(model, "error_model") and model.error_model == "arimax":
272+
if self.model.error_model == "arimax":
187273
# ARIMAX: extract only exogenous coefficients
188274
n_exog = self.ols_result.model.k_exog
189275
exog_params = self.ols_result.params[:n_exog]
@@ -197,44 +283,38 @@ def __init__(
197283
self.theta_treatment = self.ols_result.params[n_baseline:]
198284

199285
# Store model metadata for summary output
200-
self.error_model = getattr(model, "error_model", "hac")
201-
self.hac_maxlags = getattr(model, "hac_maxlags", None)
202-
self.arima_order = getattr(model, "arima_order", None)
203-
self.transform_estimation_method = getattr(
204-
model, "transform_estimation_method", None
205-
)
206-
self.transform_estimation_results = getattr(
207-
model, "transform_estimation_results", None
208-
)
209-
self.transform_search_space = getattr(model, "transform_search_space", None)
286+
self.error_model = self.model.error_model
287+
self.hac_maxlags = self.model.hac_maxlags
288+
self.arima_order = self.model.arima_order
289+
self.transform_estimation_method = self.model.estimation_method
290+
self.transform_estimation_results = self.model.transform_estimation_results
291+
self.transform_search_space = self.model.transform_search_space
210292

211293
def _validate_inputs(
212294
self,
213295
data: pd.DataFrame,
214296
y_column: str,
215-
treatments: List[Treatment],
297+
treatment_names: List[str],
216298
) -> None:
217299
"""Validate input data and parameters."""
218300
# Check that y_column exists
219301
if y_column not in data.columns:
220302
raise ValueError(f"y_column '{y_column}' not found in data columns")
221303

222304
# Check that treatment columns exist
223-
for treatment in treatments:
224-
if treatment.name not in data.columns:
225-
raise ValueError(
226-
f"Treatment column '{treatment.name}' not found in data columns"
227-
)
305+
for name in treatment_names:
306+
if name not in data.columns:
307+
raise ValueError(f"Treatment column '{name}' not found in data columns")
228308

229309
# Check for missing values in outcome
230310
if data[y_column].isna().any():
231311
raise ValueError("Outcome variable contains missing values")
232312

233313
# Warn about missing values in treatment columns
234-
for treatment in treatments:
235-
if data[treatment.name].isna().any():
314+
for name in treatment_names:
315+
if data[name].isna().any():
236316
print(
237-
f"Warning: Treatment column '{treatment.name}' contains missing values. "
317+
f"Warning: Treatment column '{name}' contains missing values. "
238318
f"Consider forward-filling if justified by the context."
239319
)
240320

0 commit comments

Comments
 (0)