2929
30302. **Inner Loop**: For each candidate set of transform parameters:
3131 - Apply transforms to the treatment variable
32- - Fit OLS model with HAC standard errors
32+ - Fit regression model (OLS with HAC or ARIMAX)
3333 - Compute RMSE as the optimization metric
3434
35353. **Selection**: The transform parameters that yield the lowest RMSE are selected
36- as the final estimates. These parameters, along with the OLS coefficients from
37- the best-fitting model, define the complete fitted model.
36+ as the final estimates. These parameters, along with the regression coefficients
37+ from the best-fitting model, define the complete fitted model.
3838
3939This nested approach is efficient because OLS has a closed-form solution, making
4040the inner loop fast even when evaluating many parameter combinations. (ARIMAX models are fit iteratively, so the inner loop is slower with that error model.)
41+
42+ Error Models
43+ ------------
44+ Two error model options are available:
45+
46+ **HAC (Heteroskedasticity and Autocorrelation Consistent) Standard Errors:**
47+ - Default and recommended for most applications
48+ - Uses Newey-West robust standard errors
49+ - No specification required - automatically robust to autocorrelation
50+ - Works without assuming a specific autocorrelation pattern (accounts for autocorrelation up to hac_maxlags lags)
51+ - More conservative (wider confidence intervals)
52+
53+ **ARIMAX (ARIMA with eXogenous variables):**
54+ - Explicitly models ARIMA(p,d,q) structure of residuals
55+ - More efficient (narrower confidence intervals) when correctly specified
56+ - Requires manual specification of (p,d,q) orders
57+ - Sensitive to misspecification
58+ - Follows classical Box & Tiao (1975) intervention analysis
59+
60+ Choose HAC for robust, specification-free inference. Use ARIMAX when you have
61+ strong evidence for a specific ARIMA structure and want maximum efficiency.
4162"""
4263
4364from typing import Dict , List , Optional , Tuple , Union
@@ -83,7 +104,7 @@ class TransferFunctionITS(BaseExperiment):
83104 treatments : List[Treatment]
84105 Treatment specifications with estimated transforms.
85106 ols_result : statsmodels.regression.linear_model.RegressionResultsWrapper
86- Fitted OLS model with HAC standard errors.
107+ Fitted regression model (OLS with HAC standard errors or ARIMAX) .
87108 beta_baseline : np.ndarray
88109 Baseline model coefficients.
89110 theta_treatment : np.ndarray
@@ -98,9 +119,17 @@ class TransferFunctionITS(BaseExperiment):
98119 Full results from parameter estimation including best_score, best_params.
99120 transform_search_space : dict
100121 Parameter grids or bounds that were searched.
122+ error_model : str
123+ Error model type: "hac" (default) or "arimax".
124+ arima_order : Optional[Tuple[int, int, int]]
125+ ARIMA(p,d,q) order when error_model="arimax". None for HAC.
126+ hac_maxlags : Optional[int]
127+ Maximum lags for HAC standard errors. None for ARIMAX.
101128
102129 Examples
103130 --------
131+ **Example 1: HAC Standard Errors (Default)**
132+
104133 .. code-block:: python
105134
106135 import causalpy as cp
@@ -121,8 +150,8 @@ class TransferFunctionITS(BaseExperiment):
121150 df = df.set_index("date")
122151 df["t"] = np.arange(len(df))
123152
124- # Estimate transform parameters via grid search
125- result = cp.TransferFunctionITS.with_estimated_transforms(
153+ # Estimate transform parameters via grid search with HAC errors
154+ result_hac = cp.TransferFunctionITS.with_estimated_transforms(
126155 data=df,
127156 y_column="water_consumption",
128157 treatment_name="comm_intensity",
@@ -131,20 +160,45 @@ class TransferFunctionITS(BaseExperiment):
131160 saturation_type="hill",
132161 saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
133162 adstock_grid={"half_life": [2, 3, 4, 5]},
163+ error_model="hac", # HAC standard errors (default)
134164 )
135165
136- # View estimated parameters
137- print(result.transform_estimation_results["best_params"])
166+ # View estimated parameters and summary
167+ print(result_hac.transform_estimation_results["best_params"])
168+ result_hac.summary()
138169
139170 # Estimate effect of policy over entire period
140- effect_result = result .effect(
171+ effect_result = result_hac .effect(
141172 window=(df.index[0], df.index[-1]), channels=["comm_intensity"], scale=0.0
142173 )
143174 print(f"Total effect: {effect_result['total_effect']:.2f}")
144175
145176 # Visualize results
146- result.plot()
147- result.diagnostics()
177+ result_hac.plot()
178+ result_hac.diagnostics()
179+
180+ **Example 2: ARIMAX Error Model**
181+
182+ .. code-block:: python
183+
184+ # Fit with ARIMAX errors if you know the error structure
185+ # (use ACF/PACF plots to determine p, d, q)
186+ result_arimax = cp.TransferFunctionITS.with_estimated_transforms(
187+ data=df,
188+ y_column="water_consumption",
189+ treatment_name="comm_intensity",
190+ base_formula="1 + t + temperature + rainfall",
191+ estimation_method="grid",
192+ saturation_type="hill",
193+ saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
194+ adstock_grid={"half_life": [2, 3, 4, 5]},
195+ error_model="arimax", # Use ARIMAX
196+ arima_order=(1, 0, 0), # AR(1) errors: p=1 (AR), d=0 (no diff), q=0 (no MA)
197+ )
198+
199+ # ARIMAX typically gives narrower confidence intervals
200+ # when the ARIMA structure is correctly specified
201+ result_arimax.summary() # Shows ARIMA order details
148202
149203 Notes
150204 -----
@@ -158,9 +212,13 @@ class TransferFunctionITS(BaseExperiment):
158212 - Grid search: Exhaustive search over discrete parameter values (slower, guaranteed best)
159213 - Continuous optimization: Uses scipy.optimize (faster, may find local optima)
160214
215+ **Error Model Selection:**
216+ - HAC (default): Robust standard errors, no specification required
217+ - ARIMAX: More efficient when ARIMA structure is known, requires (p,d,q) specification
218+
161219 **Future Extensions:**
162220 - Bootstrap or asymptotic confidence intervals for effects
163- - Additional error models (GLSAR, ARIMAX )
221+ - Additional error models (GLSAR for known autocorrelation structure )
164222 - Bayesian inference via PyMC model (reusing transform pipeline)
165223 - Custom formula helpers (trend(), season_fourier(), holidays())
166224 - Multi-channel collinearity diagnostics
@@ -182,6 +240,8 @@ def _init_from_treatments(
182240 base_formula : str ,
183241 treatments : List [Treatment ],
184242 hac_maxlags : Optional [int ] = None ,
243+ error_model : str = "hac" ,
244+ arima_order : Optional [Tuple [int , int , int ]] = None ,
185245 model = None ,
186246 ** kwargs ,
187247 ) -> None :
@@ -190,11 +250,11 @@ def _init_from_treatments(
190250 This is a private method called by with_estimated_transforms().
191251 Users should not call this directly - use with_estimated_transforms() instead.
192252 """
193- # For MVP, we only support OLS. The model parameter is kept for future
253+ # For MVP, we only support OLS or ARIMAX . The model parameter is kept for future
194254 # compatibility with CausalPy's architecture.
195255 if model is not None :
196256 raise NotImplementedError (
197- "Custom models not yet supported. MVP uses OLS with HAC standard errors only ."
257+ "Custom models not yet supported. Use error_model='hac' or 'arimax' ."
198258 )
199259
200260 # Validate inputs
@@ -224,30 +284,76 @@ def _init_from_treatments(
224284 self .X_full = np .column_stack ([self .X_baseline , self .Z_treatment ])
225285 self .all_labels = self .baseline_labels + self .treatment_labels
226286
227- # Fit OLS with HAC standard errors
228- if hac_maxlags is None :
229- # Newey & West (1994) rule of thumb
230- n = len (self .y )
231- hac_maxlags = int (np .floor (4 * (n / 100 ) ** (2 / 9 )))
287+ # Store error model metadata
288+ self .error_model = error_model
289+ self .arima_order = arima_order
232290
233- self .hac_maxlags = hac_maxlags
291+ # Fit model with chosen error structure
292+ if error_model == "hac" :
293+ # Fit OLS with HAC standard errors
294+ if hac_maxlags is None :
295+ # Newey & West (1994) rule of thumb
296+ n = len (self .y )
297+ hac_maxlags = int (np .floor (4 * (n / 100 ) ** (2 / 9 )))
234298
235- # Fit the model
236- self .ols_result = sm .OLS (self .y , self .X_full ).fit (
237- cov_type = "HAC" , cov_kwds = {"maxlags" : hac_maxlags }
238- )
299+ self .hac_maxlags = hac_maxlags
300+
301+ # Fit the model
302+ self .ols_result = sm .OLS (self .y , self .X_full ).fit (
303+ cov_type = "HAC" , cov_kwds = {"maxlags" : hac_maxlags }
304+ )
305+ elif error_model == "arimax" :
306+ # Fit ARIMAX model
307+ import warnings
308+
309+ from statsmodels .tsa .statespace .sarimax import SARIMAX
310+
311+ self .hac_maxlags = None # Not used for ARIMAX
312+
313+ # Suppress all warnings raised during fitting (including convergence warnings)
314+ with warnings .catch_warnings ():
315+ warnings .simplefilter ("ignore" )
316+ arimax_model = SARIMAX (self .y , exog = self .X_full , order = arima_order )
317+ self .ols_result = arimax_model .fit (
318+ disp = 0 ,
319+ maxiter = 200 ,
320+ method = "lbfgs" ,
321+ )
322+ self .arimax_model = arimax_model
323+ else :
324+ raise ValueError (
325+ f"error_model must be 'hac' or 'arimax', got '{ error_model } '"
326+ )
239327
240328 # Extract coefficients
241- n_baseline = self .X_baseline .shape [1 ]
242- self .beta_baseline = self .ols_result .params [:n_baseline ]
243- self .theta_treatment = self .ols_result .params [n_baseline :]
329+ # For ARIMAX, params includes both exog coefficients and ARIMA parameters
330+ # We need to extract only the exogenous variable coefficients
331+ if self .error_model == "arimax" :
332+ # ARIMAX: params = [exog_coefs..., arima_params...]
333+ # Use k_exog to get only the exogenous coefficients
334+ n_exog = self .ols_result .model .k_exog
335+ exog_params = self .ols_result .params [:n_exog ]
336+ n_baseline = self .X_baseline .shape [1 ]
337+ self .beta_baseline = exog_params [:n_baseline ]
338+ self .theta_treatment = exog_params [n_baseline :]
339+ else :
340+ # OLS: params are just the regression coefficients
341+ n_baseline = self .X_baseline .shape [1 ]
342+ self .beta_baseline = self .ols_result .params [:n_baseline ]
343+ self .theta_treatment = self .ols_result .params [n_baseline :]
244344
245345 # Store predictions and residuals
246346 self .predictions = self .ols_result .fittedvalues
247347 self .residuals = self .ols_result .resid
248348
249- # Store score (R-squared)
250- self .score = self .ols_result .rsquared
349+ # Store score (R-squared if available, otherwise compute from residuals)
350+ if hasattr (self .ols_result , "rsquared" ):
351+ self .score = self .ols_result .rsquared
352+ else :
353+ # For ARIMAX, compute R-squared manually
354+ ss_res = np .sum (self .residuals ** 2 )
355+ ss_tot = np .sum ((self .y - np .mean (self .y )) ** 2 )
356+ self .score = 1 - (ss_res / ss_tot ) if ss_tot > 0 else 0.0
251357
252358 # Transform estimation metadata (set by with_estimated_transforms)
253359 self .transform_estimation_method = None # "grid", "optimize", or None
@@ -265,6 +371,8 @@ def with_estimated_transforms(
265371 saturation_type : str = "hill" ,
266372 coef_constraint : str = "nonnegative" ,
267373 hac_maxlags : Optional [int ] = None ,
374+ error_model : str = "hac" ,
375+ arima_order : Optional [Tuple [int , int , int ]] = None ,
268376 ** estimation_kwargs ,
269377 ) -> "TransferFunctionITS" :
270378 """
@@ -297,7 +405,19 @@ def with_estimated_transforms(
297405 autocorrelation and heteroskedasticity in residuals. Higher values account
298406 for longer-range dependencies but reduce degrees of freedom. If None, uses
299407 the Newey-West rule of thumb: floor(4*(n/100)^(2/9)). For example, with
300- n=104 observations, the default is hac_maxlags=4.
408+ n=104 observations, the default is hac_maxlags=4. Ignored if error_model="arimax".
409+ error_model : str, default="hac"
410+ Error model specification: "hac" or "arimax".
411+ - "hac": HAC (Newey-West) standard errors. Robust to any autocorrelation
412+ pattern, requires no specification of error structure.
413+ - "arimax": ARIMA(p,d,q) errors with exogenous variables (Box & Tiao 1975).
414+ More efficient when correctly specified, but requires choosing p, d, q orders.
415+ arima_order : tuple of (int, int, int), optional
416+ ARIMA order (p, d, q) when error_model="arimax". Required if error_model="arimax".
417+ - p: Autoregressive order (number of lagged values of the regression error term)
418+ - d: Differencing order (usually 0 for stationary data; use 1 for trending data)
419+ - q: Moving average order (number of lagged forecast errors)
420+ Example: (1, 0, 0) for AR(1) errors, (1, 0, 1) for ARMA(1,1) errors.
301421 **estimation_kwargs
302422 Additional keyword arguments for the estimation method:
303423
@@ -372,6 +492,17 @@ def with_estimated_transforms(
372492 estimate_transform_params_optimize ,
373493 )
374494
495+ # Validate error model parameters
496+ if error_model not in ["hac" , "arimax" ]:
497+ raise ValueError (
498+ f"error_model must be 'hac' or 'arimax', got '{ error_model } '"
499+ )
500+ if error_model == "arimax" and arima_order is None :
501+ raise ValueError (
502+ "arima_order must be provided when error_model='arimax'. "
503+ "E.g., arima_order=(1, 0, 0) for AR(1) errors"
504+ )
505+
375506 # Run parameter estimation
376507 if estimation_method == "grid" :
377508 if "saturation_grid" not in estimation_kwargs :
@@ -395,6 +526,8 @@ def with_estimated_transforms(
395526 adstock_grid = estimation_kwargs ["adstock_grid" ],
396527 coef_constraint = coef_constraint ,
397528 hac_maxlags = hac_maxlags ,
529+ error_model = error_model ,
530+ arima_order = arima_order ,
398531 )
399532
400533 search_space = {
@@ -426,6 +559,8 @@ def with_estimated_transforms(
426559 coef_constraint = coef_constraint ,
427560 hac_maxlags = hac_maxlags ,
428561 method = estimation_kwargs .get ("method" , "L-BFGS-B" ),
562+ error_model = error_model ,
563+ arima_order = arima_order ,
429564 )
430565
431566 search_space = {
@@ -459,6 +594,8 @@ def with_estimated_transforms(
459594 base_formula = base_formula ,
460595 treatments = [treatment ],
461596 hac_maxlags = hac_maxlags ,
597+ error_model = error_model ,
598+ arima_order = arima_order ,
462599 )
463600
464601 # Store estimation metadata
@@ -940,10 +1077,16 @@ def summary(self, round_to: Optional[int] = None) -> None:
9401077 print (f"Outcome variable: { self .y_column } " )
9411078 print (f"Number of observations: { len (self .y )} " )
9421079 print (f"R-squared: { round_num (self .score , round_to )} " )
943- print (
944- f"HAC max lags: { self .hac_maxlags } "
945- f"(robust SEs accounting for { self .hac_maxlags } periods of autocorrelation)"
946- )
1080+ print (f"Error model: { self .error_model .upper ()} " )
1081+ if self .error_model == "hac" :
1082+ print (
1083+ f" HAC max lags: { self .hac_maxlags } "
1084+ f"(robust SEs accounting for { self .hac_maxlags } periods of autocorrelation)"
1085+ )
1086+ elif self .error_model == "arimax" :
1087+ p , d , q = self .arima_order
1088+ print (f" ARIMA order: ({ p } , { d } , { q } )" )
1089+ print (f" p={ p } : AR order, d={ d } : differencing, q={ q } : MA order" )
9471090 print ("-" * 80 )
9481091 print ("Baseline coefficients:" )
9491092 for label , coef , se in zip (
0 commit comments