1717This module implements Transfer-Function ITS for estimating the causal effects
1818of graded interventions in single-market time series using saturation and
1919adstock transforms.
20+
21+ Parameter Estimation
22+ --------------------
23+ Transform parameters (saturation and adstock) are estimated via nested optimization:
24+
25+ 1. **Outer Loop**: Search over transform parameters (saturation slope/kappa,
26+ adstock half-life) using either:
27+ - Grid search: Exhaustive evaluation of discrete parameter combinations
28+ - Continuous optimization: scipy.optimize.minimize for faster convergence
29+
30+ 2. **Inner Loop**: For each candidate set of transform parameters:
31+ - Apply transforms to the treatment variable
32+ - Fit OLS model with HAC standard errors
33+ - Compute RMSE as the optimization metric
34+
35+ 3. **Selection**: The transform parameters that yield the lowest RMSE are selected
36+ as the final estimates. These parameters, along with the OLS coefficients from
37+ the best-fitting model, define the complete fitted model.
38+
39+ This nested approach is efficient because OLS has a closed-form solution, making
40+ the inner loop fast even when evaluating many parameter combinations.
2041"""
2142
2243from typing import Dict , List , Optional , Tuple , Union
@@ -46,26 +67,8 @@ class TransferFunctionITS(BaseExperiment):
4667 spend, policy intensity) in a single market using transfer functions that model
4768 saturation and adstock (carryover) effects.
4869
49- Parameters
50- ----------
51- data : pd.DataFrame
52- Time series data with datetime or numeric index. Must contain the outcome
53- variable and treatment exposure columns.
54- y_column : str
55- Name of the outcome variable column in data.
56- base_formula : str
57- Patsy formula for the baseline model (trend, seasonality, controls).
58- Example: "1 + t + np.sin(2*np.pi*t/52) + np.cos(2*np.pi*t/52)"
59- where t is a time index. FUTURE: Custom helpers like trend(),
60- season_fourier(), holidays() can be added.
61- treatments : List[Treatment]
62- List of Treatment objects specifying channels and their transforms.
63- hac_maxlags : int, optional
64- Maximum lags for Newey-West HAC covariance estimation. Default is
65- int(4 * (n / 100) ** (2/9)) as suggested by Newey & West.
66- model : None
67- Not used in MVP (OLS only), but parameter kept for future Bayesian
68- extension compatibility with CausalPy architecture.
70+ Transform parameters (saturation and adstock) are estimated from the data via
71+ grid search or continuous optimization to find the best fit.
6972
7073 Attributes
7174 ----------
@@ -78,7 +81,7 @@ class TransferFunctionITS(BaseExperiment):
7881 base_formula : str
7982 Baseline model formula.
8083 treatments : List[Treatment]
81- Treatment specifications.
84+ Treatment specifications with estimated transforms .
8285 ols_result : statsmodels.regression.linear_model.RegressionResultsWrapper
8386 Fitted OLS model with HAC standard errors.
8487 beta_baseline : np.ndarray
@@ -89,59 +92,70 @@ class TransferFunctionITS(BaseExperiment):
8992 Fitted values.
9093 residuals : np.ndarray
9194 Model residuals.
95+ transform_estimation_method : str
96+ Method used for parameter estimation ("grid" or "optimize").
97+ transform_estimation_results : dict
98+ Full results from parameter estimation including best_score, best_params.
99+ transform_search_space : dict
100+ Parameter grids or bounds that were searched.
92101
93102 Examples
94103 --------
95104 >>> import causalpy as cp
96105 >>> import pandas as pd
97106 >>> import numpy as np
98107 >>> # Create sample data
99- >>> dates = pd.date_range("2020 -01-01", periods=104, freq="W")
108+ >>> dates = pd.date_range("2022 -01-01", periods=104, freq="W")
100109 >>> df = pd.DataFrame(
101110 ... {
102111 ... "date": dates,
103- ... "sales": np.random.normal(1000, 100, 104),
104- ... "tv_spend": np.random.uniform(0, 10000, 104),
112+ ... "water_consumption": np.random.normal(5000, 500, 104),
113+ ... "comm_intensity": np.random.uniform(0, 10, 104),
114+ ... "temperature": 25 + 10 * np.sin(2 * np.pi * np.arange(104) / 52),
115+ ... "rainfall": 8 - 8 * np.sin(2 * np.pi * np.arange(104) / 52),
105116 ... }
106117 ... )
107118 >>> df = df.set_index("date")
108- >>> # Add time index for formula
109119 >>> df["t"] = np.arange(len(df))
110- >>> # Define treatment with saturation and adstock
111- >>> treatment = cp.Treatment(
112- ... name="tv_spend",
113- ... transforms=[
114- ... cp.Saturation(kind="hill", slope=2.0, kappa=5000),
115- ... cp.Adstock(half_life=3, normalize=True),
116- ... ],
117- ... )
118- >>> # Fit model
119- >>> result = cp.TransferFunctionITS(
120+ >>>
121+ >>> # Estimate transform parameters via grid search
122+ >>> result = cp.TransferFunctionITS.with_estimated_transforms(
120123 ... data=df,
121- ... y_column="sales",
122- ... base_formula="1 + t",
123- ... treatments=[treatment],
124- ... hac_maxlags=8,
124+ ... y_column="water_consumption",
125+ ... treatment_name="comm_intensity",
126+ ... base_formula="1 + t + temperature + rainfall",
127+ ... estimation_method="grid",
128+ ... saturation_type="hill",
129+ ... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
130+ ... adstock_grid={"half_life": [2, 3, 4, 5]},
125131 ... )
126- >>> # Estimate effect of zeroing TV spend in weeks 50-60
132+ >>>
133+ >>> # View estimated parameters
134+ >>> print(result.transform_estimation_results["best_params"])
135+ >>>
136+ >>> # Estimate effect of policy over entire period
127137 >>> effect_result = result.effect(
128- ... window=(df.index[50 ], df.index[60 ]), channels=["tv_spend "], scale=0.0
138+ ... window=(df.index[0 ], df.index[-1 ]), channels=["comm_intensity "], scale=0.0
129139 ... )
130- >>> # Plot results
140+ >>> print(f"Total effect: {effect_result['total_effect']:.2f}")
141+ >>>
142+ >>> # Visualize results
131143 >>> result.plot()
132- >>> # Show diagnostics
133144 >>> result.diagnostics()
134145
135146 Notes
136147 -----
137- **MVP Limitations:**
138- - OLS with HAC standard errors only (no Bayesian inference)
139- - Point estimates only (no bootstrap uncertainty intervals)
140- - Fixed transform parameters (no grid search)
141- - Basic diagnostics only
148+ **Instantiation:**
149+ Models are created via the `with_estimated_transforms()` class method, which
150+ estimates optimal transform parameters from the data. Direct instantiation
151+ is not supported.
152+
153+ **Transform Estimation:**
154+ Two methods are available:
155+ - Grid search: Exhaustive search over discrete parameter values (slower, guaranteed best)
156+ - Continuous optimization: Uses scipy.optimize (faster, may find local optima)
142157
143158 **Future Extensions:**
144- - Grid search for optimal transform parameters (estimate_transforms=True)
145159 - Bootstrap or asymptotic confidence intervals for effects
146160 - Additional error models (GLSAR, ARIMAX)
147161 - Bayesian inference via PyMC model (reusing transform pipeline)
@@ -158,7 +172,7 @@ class TransferFunctionITS(BaseExperiment):
158172 supports_ols = True
159173 supports_bayes = False # FUTURE: Will be True when PyMC model is implemented
160174
161- def __init__ (
175+ def _init_from_treatments (
162176 self ,
163177 data : pd .DataFrame ,
164178 y_column : str ,
@@ -168,7 +182,11 @@ def __init__(
168182 model = None ,
169183 ** kwargs ,
170184 ) -> None :
171- """Initialize and fit the Transfer Function ITS model."""
185+ """Initialize and fit the Transfer Function ITS model with given treatments.
186+
187+ This is a private method called by with_estimated_transforms().
188+ Users should not call this directly - use with_estimated_transforms() instead.
189+ """
172190 # For MVP, we only support OLS. The model parameter is kept for future
173191 # compatibility with CausalPy's architecture.
174192 if model is not None :
@@ -424,8 +442,9 @@ def with_estimated_transforms(
424442 coef_constraint = coef_constraint ,
425443 )
426444
427- # Create TransferFunctionITS with estimated transforms
428- result = cls (
445+ # Create TransferFunctionITS instance and initialize with estimated transforms
446+ result = cls .__new__ (cls )
447+ result ._init_from_treatments (
429448 data = data ,
430449 y_column = y_column ,
431450 base_formula = base_formula ,
0 commit comments