Skip to content

Commit ec98c8a

Browse files
committed
Improve TF-ITS docs and parameter estimation details
Expanded and clarified docstrings in transfer_function_its.py to document the nested parameter estimation approach for saturation and adstock transforms. Updated the example and usage instructions to reflect the new estimation workflow. Revised the notebook to demonstrate transform parameter estimation via grid search, show parameter recovery, and clarify the distinction between grid search and continuous optimization. Removed the outdated and redundant test class for TransferFunctionITS in test_transfer_function_its.py.
1 parent 659b502 commit ec98c8a

File tree

3 files changed

+361
-785
lines changed

3 files changed

+361
-785
lines changed

causalpy/experiments/transfer_function_its.py

Lines changed: 72 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,27 @@
1717
This module implements Transfer-Function ITS for estimating the causal effects
1818
of graded interventions in single-market time series using saturation and
1919
adstock transforms.
20+
21+
Parameter Estimation
22+
--------------------
23+
Transform parameters (saturation and adstock) are estimated via nested optimization:
24+
25+
1. **Outer Loop**: Search over transform parameters (saturation slope/kappa,
26+
adstock half-life) using either:
27+
- Grid search: Exhaustive evaluation of discrete parameter combinations
28+
- Continuous optimization: scipy.optimize.minimize for faster convergence (may converge to a local optimum)
29+
30+
2. **Inner Loop**: For each candidate set of transform parameters:
31+
- Apply transforms to the treatment variable
32+
- Fit OLS model with HAC standard errors
33+
- Compute RMSE as the optimization metric
34+
35+
3. **Selection**: The transform parameters that yield the lowest RMSE are selected
36+
as the final estimates. These parameters, along with the OLS coefficients from
37+
the best-fitting model, define the complete fitted model.
38+
39+
This nested approach is efficient because OLS has a closed-form solution, making
40+
the inner loop fast even when evaluating many parameter combinations.
2041
"""
2142

2243
from typing import Dict, List, Optional, Tuple, Union
@@ -46,26 +67,8 @@ class TransferFunctionITS(BaseExperiment):
4667
spend, policy intensity) in a single market using transfer functions that model
4768
saturation and adstock (carryover) effects.
4869
49-
Parameters
50-
----------
51-
data : pd.DataFrame
52-
Time series data with datetime or numeric index. Must contain the outcome
53-
variable and treatment exposure columns.
54-
y_column : str
55-
Name of the outcome variable column in data.
56-
base_formula : str
57-
Patsy formula for the baseline model (trend, seasonality, controls).
58-
Example: "1 + t + np.sin(2*np.pi*t/52) + np.cos(2*np.pi*t/52)"
59-
where t is a time index. FUTURE: Custom helpers like trend(),
60-
season_fourier(), holidays() can be added.
61-
treatments : List[Treatment]
62-
List of Treatment objects specifying channels and their transforms.
63-
hac_maxlags : int, optional
64-
Maximum lags for Newey-West HAC covariance estimation. Default is
65-
int(4 * (n / 100) ** (2/9)) as suggested by Newey & West.
66-
model : None
67-
Not used in MVP (OLS only), but parameter kept for future Bayesian
68-
extension compatibility with CausalPy architecture.
70+
Transform parameters (saturation and adstock) are estimated from the data via
71+
grid search or continuous optimization to find the best fit.
6972
7073
Attributes
7174
----------
@@ -78,7 +81,7 @@ class TransferFunctionITS(BaseExperiment):
7881
base_formula : str
7982
Baseline model formula.
8083
treatments : List[Treatment]
81-
Treatment specifications.
84+
Treatment specifications with estimated transforms.
8285
ols_result : statsmodels.regression.linear_model.RegressionResultsWrapper
8386
Fitted OLS model with HAC standard errors.
8487
beta_baseline : np.ndarray
@@ -89,59 +92,70 @@ class TransferFunctionITS(BaseExperiment):
8992
Fitted values.
9093
residuals : np.ndarray
9194
Model residuals.
95+
transform_estimation_method : str
96+
Method used for parameter estimation ("grid" or "optimize").
97+
transform_estimation_results : dict
98+
Full results from parameter estimation including best_score, best_params.
99+
transform_search_space : dict
100+
Parameter grids or bounds that were searched.
92101
93102
Examples
94103
--------
95104
>>> import causalpy as cp
96105
>>> import pandas as pd
97106
>>> import numpy as np
98107
>>> # Create sample data
99-
>>> dates = pd.date_range("2020-01-01", periods=104, freq="W")
108+
>>> dates = pd.date_range("2022-01-01", periods=104, freq="W")
100109
>>> df = pd.DataFrame(
101110
... {
102111
... "date": dates,
103-
... "sales": np.random.normal(1000, 100, 104),
104-
... "tv_spend": np.random.uniform(0, 10000, 104),
112+
... "water_consumption": np.random.normal(5000, 500, 104),
113+
... "comm_intensity": np.random.uniform(0, 10, 104),
114+
... "temperature": 25 + 10 * np.sin(2 * np.pi * np.arange(104) / 52),
115+
... "rainfall": 8 - 8 * np.sin(2 * np.pi * np.arange(104) / 52),
105116
... }
106117
... )
107118
>>> df = df.set_index("date")
108-
>>> # Add time index for formula
109119
>>> df["t"] = np.arange(len(df))
110-
>>> # Define treatment with saturation and adstock
111-
>>> treatment = cp.Treatment(
112-
... name="tv_spend",
113-
... transforms=[
114-
... cp.Saturation(kind="hill", slope=2.0, kappa=5000),
115-
... cp.Adstock(half_life=3, normalize=True),
116-
... ],
117-
... )
118-
>>> # Fit model
119-
>>> result = cp.TransferFunctionITS(
120+
>>>
121+
>>> # Estimate transform parameters via grid search
122+
>>> result = cp.TransferFunctionITS.with_estimated_transforms(
120123
... data=df,
121-
... y_column="sales",
122-
... base_formula="1 + t",
123-
... treatments=[treatment],
124-
... hac_maxlags=8,
124+
... y_column="water_consumption",
125+
... treatment_name="comm_intensity",
126+
... base_formula="1 + t + temperature + rainfall",
127+
... estimation_method="grid",
128+
... saturation_type="hill",
129+
... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
130+
... adstock_grid={"half_life": [2, 3, 4, 5]},
125131
... )
126-
>>> # Estimate effect of zeroing TV spend in weeks 50-60
132+
>>>
133+
>>> # View estimated parameters
134+
>>> print(result.transform_estimation_results["best_params"])
135+
>>>
136+
>>> # Estimate effect of policy over entire period
127137
>>> effect_result = result.effect(
128-
... window=(df.index[50], df.index[60]), channels=["tv_spend"], scale=0.0
138+
... window=(df.index[0], df.index[-1]), channels=["comm_intensity"], scale=0.0
129139
... )
130-
>>> # Plot results
140+
>>> print(f"Total effect: {effect_result['total_effect']:.2f}")
141+
>>>
142+
>>> # Visualize results
131143
>>> result.plot()
132-
>>> # Show diagnostics
133144
>>> result.diagnostics()
134145
135146
Notes
136147
-----
137-
**MVP Limitations:**
138-
- OLS with HAC standard errors only (no Bayesian inference)
139-
- Point estimates only (no bootstrap uncertainty intervals)
140-
- Fixed transform parameters (no grid search)
141-
- Basic diagnostics only
148+
**Instantiation:**
149+
Models are created via the `with_estimated_transforms()` class method, which
150+
estimates optimal transform parameters from the data. Direct instantiation
151+
is not supported.
152+
153+
**Transform Estimation:**
154+
Two methods are available:
155+
- Grid search: Exhaustive search over discrete parameter values (slower; guaranteed to find the best combination only among the grid points supplied)
156+
- Continuous optimization: Uses scipy.optimize (faster, may find local optima)
142157
143158
**Future Extensions:**
144-
- Grid search for optimal transform parameters (estimate_transforms=True)
145159
- Bootstrap or asymptotic confidence intervals for effects
146160
- Additional error models (GLSAR, ARIMAX)
147161
- Bayesian inference via PyMC model (reusing transform pipeline)
@@ -158,7 +172,7 @@ class TransferFunctionITS(BaseExperiment):
158172
supports_ols = True
159173
supports_bayes = False # FUTURE: Will be True when PyMC model is implemented
160174

161-
def __init__(
175+
def _init_from_treatments(
162176
self,
163177
data: pd.DataFrame,
164178
y_column: str,
@@ -168,7 +182,11 @@ def __init__(
168182
model=None,
169183
**kwargs,
170184
) -> None:
171-
"""Initialize and fit the Transfer Function ITS model."""
185+
"""Initialize and fit the Transfer Function ITS model with given treatments.
186+
187+
This is a private method called by with_estimated_transforms().
188+
Users should not call this directly - use with_estimated_transforms() instead.
189+
"""
172190
# For MVP, we only support OLS. The model parameter is kept for future
173191
# compatibility with CausalPy's architecture.
174192
if model is not None:
@@ -424,8 +442,9 @@ def with_estimated_transforms(
424442
coef_constraint=coef_constraint,
425443
)
426444

427-
# Create TransferFunctionITS with estimated transforms
428-
result = cls(
445+
# Create TransferFunctionITS instance and initialize with estimated transforms
446+
result = cls.__new__(cls)
447+
result._init_from_treatments(
429448
data=data,
430449
y_column=y_column,
431450
base_formula=base_formula,

0 commit comments

Comments
 (0)