Skip to content

Commit aa071ab

Browse files
committed
Add ARIMAX error model support to TransferFunctionITS
Extended TransferFunctionITS and transform optimization to support ARIMAX (ARIMA with exogenous variables) error models in addition to HAC standard errors. Updated model fitting, parameter estimation, and documentation to allow users to specify error_model ('hac' or 'arimax') and ARIMA order. Added comprehensive tests for ARIMAX functionality and updated the notebook to demonstrate ARIMAX usage and comparison with HAC.
1 parent 8f37426 commit aa071ab

File tree

4 files changed

+849
-64
lines changed

4 files changed

+849
-64
lines changed

causalpy/experiments/transfer_function_its.py

Lines changed: 177 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,36 @@
2929
3030
2. **Inner Loop**: For each candidate set of transform parameters:
3131
- Apply transforms to the treatment variable
32-
- Fit OLS model with HAC standard errors
32+
- Fit regression model (OLS with HAC or ARIMAX)
3333
- Compute RMSE as the optimization metric
3434
3535
3. **Selection**: The transform parameters that yield the lowest RMSE are selected
36-
as the final estimates. These parameters, along with the OLS coefficients from
37-
the best-fitting model, define the complete fitted model.
36+
as the final estimates. These parameters, along with the regression coefficients
37+
from the best-fitting model, define the complete fitted model.
3838
3939
This nested approach is efficient because OLS has a closed-form solution, making
4040
the inner loop fast even when evaluating many parameter combinations.
41+
42+
Error Models
43+
------------
44+
Two error model options are available:
45+
46+
**HAC (Heteroskedasticity and Autocorrelation Consistent) Standard Errors:**
47+
- Default and recommended for most applications
48+
- Uses Newey-West robust standard errors
49+
- No specification required - automatically robust to autocorrelation
50+
- Works with any autocorrelation pattern
51+
- More conservative (wider confidence intervals)
52+
53+
**ARIMAX (ARIMA with eXogenous variables):**
54+
- Explicitly models ARIMA(p,d,q) structure of residuals
55+
- More efficient (narrower confidence intervals) when correctly specified
56+
- Requires manual specification of (p,d,q) orders
57+
- Sensitive to misspecification
58+
- Follows classical Box & Tiao (1975) intervention analysis
59+
60+
Choose HAC for robust, specification-free inference. Use ARIMAX when you have
61+
strong evidence for a specific ARIMA structure and want maximum efficiency.
4162
"""
4263

4364
from typing import Dict, List, Optional, Tuple, Union
@@ -83,7 +104,7 @@ class TransferFunctionITS(BaseExperiment):
83104
treatments : List[Treatment]
84105
Treatment specifications with estimated transforms.
85106
ols_result : statsmodels.regression.linear_model.RegressionResultsWrapper
86-
Fitted OLS model with HAC standard errors.
107+
Fitted regression model (OLS with HAC standard errors or ARIMAX).
87108
beta_baseline : np.ndarray
88109
Baseline model coefficients.
89110
theta_treatment : np.ndarray
@@ -98,9 +119,17 @@ class TransferFunctionITS(BaseExperiment):
98119
Full results from parameter estimation including best_score, best_params.
99120
transform_search_space : dict
100121
Parameter grids or bounds that were searched.
122+
error_model : str
123+
Error model type: "hac" (default) or "arimax".
124+
arima_order : Optional[Tuple[int, int, int]]
125+
ARIMA(p,d,q) order when error_model="arimax". None for HAC.
126+
hac_maxlags : Optional[int]
127+
Maximum lags for HAC standard errors. None for ARIMAX.
101128
102129
Examples
103130
--------
131+
**Example 1: HAC Standard Errors (Default)**
132+
104133
.. code-block:: python
105134
106135
import causalpy as cp
@@ -121,8 +150,8 @@ class TransferFunctionITS(BaseExperiment):
121150
df = df.set_index("date")
122151
df["t"] = np.arange(len(df))
123152
124-
# Estimate transform parameters via grid search
125-
result = cp.TransferFunctionITS.with_estimated_transforms(
153+
# Estimate transform parameters via grid search with HAC errors
154+
result_hac = cp.TransferFunctionITS.with_estimated_transforms(
126155
data=df,
127156
y_column="water_consumption",
128157
treatment_name="comm_intensity",
@@ -131,20 +160,45 @@ class TransferFunctionITS(BaseExperiment):
131160
saturation_type="hill",
132161
saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
133162
adstock_grid={"half_life": [2, 3, 4, 5]},
163+
error_model="hac", # HAC standard errors (default)
134164
)
135165
136-
# View estimated parameters
137-
print(result.transform_estimation_results["best_params"])
166+
# View estimated parameters and summary
167+
print(result_hac.transform_estimation_results["best_params"])
168+
result_hac.summary()
138169
139170
# Estimate effect of policy over entire period
140-
effect_result = result.effect(
171+
effect_result = result_hac.effect(
141172
window=(df.index[0], df.index[-1]), channels=["comm_intensity"], scale=0.0
142173
)
143174
print(f"Total effect: {effect_result['total_effect']:.2f}")
144175
145176
# Visualize results
146-
result.plot()
147-
result.diagnostics()
177+
result_hac.plot()
178+
result_hac.diagnostics()
179+
180+
**Example 2: ARIMAX Error Model**
181+
182+
.. code-block:: python
183+
184+
# Fit with ARIMAX errors if you know the error structure
185+
# (use ACF/PACF plots to determine p, d, q)
186+
result_arimax = cp.TransferFunctionITS.with_estimated_transforms(
187+
data=df,
188+
y_column="water_consumption",
189+
treatment_name="comm_intensity",
190+
base_formula="1 + t + temperature + rainfall",
191+
estimation_method="grid",
192+
saturation_type="hill",
193+
saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
194+
adstock_grid={"half_life": [2, 3, 4, 5]},
195+
error_model="arimax", # Use ARIMAX
196+
arima_order=(1, 0, 0), # AR(1) errors: p=1 (AR), d=0 (no diff), q=0 (no MA)
197+
)
198+
199+
# ARIMAX typically gives narrower confidence intervals
200+
# when the ARIMA structure is correctly specified
201+
result_arimax.summary() # Shows ARIMA order details
148202
149203
Notes
150204
-----
@@ -158,9 +212,13 @@ class TransferFunctionITS(BaseExperiment):
158212
- Grid search: Exhaustive search over discrete parameter values (slower, guaranteed best)
159213
- Continuous optimization: Uses scipy.optimize (faster, may find local optima)
160214
215+
**Error Model Selection:**
216+
- HAC (default): Robust standard errors, no specification required
217+
- ARIMAX: More efficient when ARIMA structure is known, requires (p,d,q) specification
218+
161219
**Future Extensions:**
162220
- Bootstrap or asymptotic confidence intervals for effects
163-
- Additional error models (GLSAR, ARIMAX)
221+
- Additional error models (GLSAR for known autocorrelation structure)
164222
- Bayesian inference via PyMC model (reusing transform pipeline)
165223
- Custom formula helpers (trend(), season_fourier(), holidays())
166224
- Multi-channel collinearity diagnostics
@@ -182,6 +240,8 @@ def _init_from_treatments(
182240
base_formula: str,
183241
treatments: List[Treatment],
184242
hac_maxlags: Optional[int] = None,
243+
error_model: str = "hac",
244+
arima_order: Optional[Tuple[int, int, int]] = None,
185245
model=None,
186246
**kwargs,
187247
) -> None:
@@ -190,11 +250,11 @@ def _init_from_treatments(
190250
This is a private method called by with_estimated_transforms().
191251
Users should not call this directly - use with_estimated_transforms() instead.
192252
"""
193-
# For MVP, we only support OLS. The model parameter is kept for future
253+
# For MVP, we only support OLS or ARIMAX. The model parameter is kept for future
194254
# compatibility with CausalPy's architecture.
195255
if model is not None:
196256
raise NotImplementedError(
197-
"Custom models not yet supported. MVP uses OLS with HAC standard errors only."
257+
"Custom models not yet supported. Use error_model='hac' or 'arimax'."
198258
)
199259

200260
# Validate inputs
@@ -224,30 +284,76 @@ def _init_from_treatments(
224284
self.X_full = np.column_stack([self.X_baseline, self.Z_treatment])
225285
self.all_labels = self.baseline_labels + self.treatment_labels
226286

227-
# Fit OLS with HAC standard errors
228-
if hac_maxlags is None:
229-
# Newey & West (1994) rule of thumb
230-
n = len(self.y)
231-
hac_maxlags = int(np.floor(4 * (n / 100) ** (2 / 9)))
287+
# Store error model metadata
288+
self.error_model = error_model
289+
self.arima_order = arima_order
232290

233-
self.hac_maxlags = hac_maxlags
291+
# Fit model with chosen error structure
292+
if error_model == "hac":
293+
# Fit OLS with HAC standard errors
294+
if hac_maxlags is None:
295+
# Newey & West (1994) rule of thumb
296+
n = len(self.y)
297+
hac_maxlags = int(np.floor(4 * (n / 100) ** (2 / 9)))
234298

235-
# Fit the model
236-
self.ols_result = sm.OLS(self.y, self.X_full).fit(
237-
cov_type="HAC", cov_kwds={"maxlags": hac_maxlags}
238-
)
299+
self.hac_maxlags = hac_maxlags
300+
301+
# Fit the model
302+
self.ols_result = sm.OLS(self.y, self.X_full).fit(
303+
cov_type="HAC", cov_kwds={"maxlags": hac_maxlags}
304+
)
305+
elif error_model == "arimax":
306+
# Fit ARIMAX model
307+
import warnings
308+
309+
from statsmodels.tsa.statespace.sarimax import SARIMAX
310+
311+
self.hac_maxlags = None # Not used for ARIMAX
312+
313+
# Suppress convergence warnings
314+
with warnings.catch_warnings():
315+
warnings.simplefilter("ignore")
316+
arimax_model = SARIMAX(self.y, exog=self.X_full, order=arima_order)
317+
self.ols_result = arimax_model.fit(
318+
disp=0,
319+
maxiter=200,
320+
method="lbfgs",
321+
)
322+
self.arimax_model = arimax_model
323+
else:
324+
raise ValueError(
325+
f"error_model must be 'hac' or 'arimax', got '{error_model}'"
326+
)
239327

240328
# Extract coefficients
241-
n_baseline = self.X_baseline.shape[1]
242-
self.beta_baseline = self.ols_result.params[:n_baseline]
243-
self.theta_treatment = self.ols_result.params[n_baseline:]
329+
# For ARIMAX, params includes both exog coefficients and ARIMA parameters
330+
# We need to extract only the exogenous variable coefficients
331+
if self.error_model == "arimax":
332+
# ARIMAX: params = [exog_coefs..., arima_params...]
333+
# Use k_exog to get only the exogenous coefficients
334+
n_exog = self.ols_result.model.k_exog
335+
exog_params = self.ols_result.params[:n_exog]
336+
n_baseline = self.X_baseline.shape[1]
337+
self.beta_baseline = exog_params[:n_baseline]
338+
self.theta_treatment = exog_params[n_baseline:]
339+
else:
340+
# OLS: params are just the regression coefficients
341+
n_baseline = self.X_baseline.shape[1]
342+
self.beta_baseline = self.ols_result.params[:n_baseline]
343+
self.theta_treatment = self.ols_result.params[n_baseline:]
244344

245345
# Store predictions and residuals
246346
self.predictions = self.ols_result.fittedvalues
247347
self.residuals = self.ols_result.resid
248348

249-
# Store score (R-squared)
250-
self.score = self.ols_result.rsquared
349+
# Store score (R-squared if available, otherwise compute from residuals)
350+
if hasattr(self.ols_result, "rsquared"):
351+
self.score = self.ols_result.rsquared
352+
else:
353+
# For ARIMAX, compute R-squared manually
354+
ss_res = np.sum(self.residuals**2)
355+
ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
356+
self.score = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0
251357

252358
# Transform estimation metadata (set by with_estimated_transforms)
253359
self.transform_estimation_method = None # "grid", "optimize", or None
@@ -265,6 +371,8 @@ def with_estimated_transforms(
265371
saturation_type: str = "hill",
266372
coef_constraint: str = "nonnegative",
267373
hac_maxlags: Optional[int] = None,
374+
error_model: str = "hac",
375+
arima_order: Optional[Tuple[int, int, int]] = None,
268376
**estimation_kwargs,
269377
) -> "TransferFunctionITS":
270378
"""
@@ -297,7 +405,19 @@ def with_estimated_transforms(
297405
autocorrelation and heteroskedasticity in residuals. Higher values account
298406
for longer-range dependencies but reduce degrees of freedom. If None, uses
299407
the Newey-West rule of thumb: floor(4*(n/100)^(2/9)). For example, with
300-
n=104 observations, the default is hac_maxlags=4.
408+
n=104 observations, the default is hac_maxlags=4. Ignored if error_model="arimax".
409+
error_model : str, default="hac"
410+
Error model specification: "hac" or "arimax".
411+
- "hac": HAC (Newey-West) standard errors. Robust to any autocorrelation
412+
pattern, requires no specification of error structure.
413+
- "arimax": ARIMA(p,d,q) errors with exogenous variables (Box & Tiao 1975).
414+
More efficient when correctly specified, but requires choosing p, d, q orders.
415+
arima_order : tuple of (int, int, int), optional
416+
ARIMA order (p, d, q) when error_model="arimax". Required if error_model="arimax".
417+
- p: Autoregressive order (number of lagged values of the dependent variable)
418+
- d: Differencing order (usually 0 for stationary data; use 1 for trending data)
419+
- q: Moving average order (number of lagged forecast errors)
420+
Example: (1, 0, 0) for AR(1) errors, (1, 0, 1) for ARMA(1,1) errors.
301421
**estimation_kwargs
302422
Additional keyword arguments for the estimation method:
303423
@@ -372,6 +492,17 @@ def with_estimated_transforms(
372492
estimate_transform_params_optimize,
373493
)
374494

495+
# Validate error model parameters
496+
if error_model not in ["hac", "arimax"]:
497+
raise ValueError(
498+
f"error_model must be 'hac' or 'arimax', got '{error_model}'"
499+
)
500+
if error_model == "arimax" and arima_order is None:
501+
raise ValueError(
502+
"arima_order must be provided when error_model='arimax'. "
503+
"E.g., arima_order=(1, 0, 0) for AR(1) errors"
504+
)
505+
375506
# Run parameter estimation
376507
if estimation_method == "grid":
377508
if "saturation_grid" not in estimation_kwargs:
@@ -395,6 +526,8 @@ def with_estimated_transforms(
395526
adstock_grid=estimation_kwargs["adstock_grid"],
396527
coef_constraint=coef_constraint,
397528
hac_maxlags=hac_maxlags,
529+
error_model=error_model,
530+
arima_order=arima_order,
398531
)
399532

400533
search_space = {
@@ -426,6 +559,8 @@ def with_estimated_transforms(
426559
coef_constraint=coef_constraint,
427560
hac_maxlags=hac_maxlags,
428561
method=estimation_kwargs.get("method", "L-BFGS-B"),
562+
error_model=error_model,
563+
arima_order=arima_order,
429564
)
430565

431566
search_space = {
@@ -459,6 +594,8 @@ def with_estimated_transforms(
459594
base_formula=base_formula,
460595
treatments=[treatment],
461596
hac_maxlags=hac_maxlags,
597+
error_model=error_model,
598+
arima_order=arima_order,
462599
)
463600

464601
# Store estimation metadata
@@ -940,10 +1077,16 @@ def summary(self, round_to: Optional[int] = None) -> None:
9401077
print(f"Outcome variable: {self.y_column}")
9411078
print(f"Number of observations: {len(self.y)}")
9421079
print(f"R-squared: {round_num(self.score, round_to)}")
943-
print(
944-
f"HAC max lags: {self.hac_maxlags} "
945-
f"(robust SEs accounting for {self.hac_maxlags} periods of autocorrelation)"
946-
)
1080+
print(f"Error model: {self.error_model.upper()}")
1081+
if self.error_model == "hac":
1082+
print(
1083+
f" HAC max lags: {self.hac_maxlags} "
1084+
f"(robust SEs accounting for {self.hac_maxlags} periods of autocorrelation)"
1085+
)
1086+
elif self.error_model == "arimax":
1087+
p, d, q = self.arima_order
1088+
print(f" ARIMA order: ({p}, {d}, {q})")
1089+
print(f" p={p}: AR order, d={d}: differencing, q={q}: MA order")
9471090
print("-" * 80)
9481091
print("Baseline coefficients:")
9491092
for label, coef, se in zip(

0 commit comments

Comments
 (0)