Skip to content

Commit 659b502

Browse files
committed
Refactor transform strategy and add parameter estimation
Refactored transform classes to use a strategy pattern with explicit Adstock, Saturation, and Lag implementations. Added transform_optimization.py for grid search and optimization of transform parameters. Updated TransferFunctionITS to support transform parameter estimation and metadata. Revised tests to use new transform classes and parameter estimation workflows.
1 parent 803b076 commit 659b502

File tree

7 files changed

+1556
-391
lines changed

7 files changed

+1556
-391
lines changed

causalpy/__init__.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,38 @@
2727
from .experiments.regression_kink import RegressionKink
2828
from .experiments.synthetic_control import SyntheticControl
2929
from .experiments.transfer_function_its import TransferFunctionITS
30-
from .transforms import Adstock, Lag, Saturation, Treatment
30+
from .transforms import (
31+
AdstockTransform,
32+
DiscreteLag,
33+
GeometricAdstock,
34+
HillSaturation,
35+
LagTransform,
36+
LogisticSaturation,
37+
MichaelisMentenSaturation,
38+
SaturationTransform,
39+
Treatment,
40+
)
3141

3242
__all__ = [
3343
"__version__",
34-
"Adstock",
44+
"AdstockTransform",
3545
"DifferenceInDifferences",
3646
"create_causalpy_compatible_class",
47+
"DiscreteLag",
48+
"GeometricAdstock",
49+
"HillSaturation",
3750
"InstrumentalVariable",
3851
"InterruptedTimeSeries",
3952
"InversePropensityWeighting",
40-
"Lag",
53+
"LagTransform",
4154
"load_data",
55+
"LogisticSaturation",
56+
"MichaelisMentenSaturation",
4257
"PrePostNEGD",
4358
"pymc_models",
4459
"RegressionDiscontinuity",
4560
"RegressionKink",
46-
"Saturation",
61+
"SaturationTransform",
4762
"skl_models",
4863
"SyntheticControl",
4964
"TransferFunctionITS",

causalpy/experiments/transfer_function_its.py

Lines changed: 236 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from statsmodels.stats.diagnostic import acorr_ljungbox
3131

3232
from causalpy.custom_exceptions import BadIndexException
33-
from causalpy.transforms import Adstock, Treatment, apply_treatment_transforms
33+
from causalpy.transforms import Treatment
3434
from causalpy.utils import round_num
3535

3636
from .base import BaseExperiment
@@ -228,6 +228,218 @@ def __init__(
228228
# Store score (R-squared)
229229
self.score = self.ols_result.rsquared
230230

231+
# Transform estimation metadata (set by with_estimated_transforms)
232+
self.transform_estimation_method = None # "grid", "optimize", or None
233+
self.transform_estimation_results = None # Full results dict
234+
self.transform_search_space = None # Grid or bounds that were searched
235+
236+
@classmethod
237+
def with_estimated_transforms(
238+
cls,
239+
data: pd.DataFrame,
240+
y_column: str,
241+
treatment_name: str,
242+
base_formula: str,
243+
estimation_method: str = "grid",
244+
saturation_type: str = "hill",
245+
coef_constraint: str = "nonnegative",
246+
hac_maxlags: Optional[int] = None,
247+
**estimation_kwargs,
248+
) -> "TransferFunctionITS":
249+
"""
250+
Create a TransferFunctionITS with transform parameters estimated from data.
251+
252+
This method estimates optimal saturation and adstock parameters via grid
253+
search or continuous optimization, then creates a TransferFunctionITS
254+
instance with those estimated transforms.
255+
256+
Parameters
257+
----------
258+
data : pd.DataFrame
259+
Time series data with datetime or numeric index.
260+
y_column : str
261+
Name of the outcome variable column in data.
262+
treatment_name : str
263+
Name of the treatment variable column in data.
264+
base_formula : str
265+
Patsy formula for the baseline model (trend, seasonality, controls).
266+
estimation_method : str, default="grid"
267+
Method for parameter estimation: "grid" or "optimize".
268+
- "grid": Grid search over discrete parameter values
269+
- "optimize": Continuous optimization using scipy.optimize
270+
saturation_type : str, default="hill"
271+
Type of saturation function: "hill", "logistic", or "michaelis_menten".
272+
coef_constraint : str, default="nonnegative"
273+
Constraint on treatment coefficient ("nonnegative" or "unconstrained").
274+
hac_maxlags : int, optional
275+
Maximum lags for HAC standard errors. If None, uses rule of thumb.
276+
**estimation_kwargs
277+
Additional keyword arguments for the estimation method:
278+
279+
For grid search (estimation_method="grid"):
280+
- saturation_grid : dict
281+
Dictionary mapping parameter names to lists of values.
282+
E.g., {"slope": [1.0, 2.0], "kappa": [3, 5, 7]}
283+
- adstock_grid : dict
284+
Dictionary mapping parameter names to lists of values.
285+
E.g., {"half_life": [2, 3, 4]}
286+
287+
For optimization (estimation_method="optimize"):
288+
- saturation_bounds : dict
289+
Dictionary mapping parameter names to (min, max) tuples.
290+
E.g., {"slope": (0.5, 5.0), "kappa": (2, 10)}
291+
- adstock_bounds : dict
292+
Dictionary mapping parameter names to (min, max) tuples.
293+
E.g., {"half_life": (1, 10)}
294+
- initial_params : dict, optional
295+
Initial parameter values for optimization.
296+
- method : str, default="L-BFGS-B"
297+
Scipy optimization method.
298+
299+
Returns
300+
-------
301+
TransferFunctionITS
302+
Fitted model with estimated transform parameters.
303+
304+
Examples
305+
--------
306+
>>> # Grid search example
307+
>>> result = TransferFunctionITS.with_estimated_transforms(
308+
... data=df,
309+
... y_column="water_consumption",
310+
... treatment_name="comm_intensity",
311+
... base_formula="1 + t + temperature + rainfall",
312+
... estimation_method="grid",
313+
... saturation_type="hill",
314+
... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
315+
... adstock_grid={"half_life": [2, 3, 4, 5]},
316+
... )
317+
>>> print(f"Best RMSE: {result.transform_estimation_results['best_score']:.2f}")
318+
319+
>>> # Optimization example
320+
>>> result = TransferFunctionITS.with_estimated_transforms(
321+
... data=df,
322+
... y_column="water_consumption",
323+
... treatment_name="comm_intensity",
324+
... base_formula="1 + t + temperature + rainfall",
325+
... estimation_method="optimize",
326+
... saturation_type="hill",
327+
... saturation_bounds={"slope": (0.5, 5.0), "kappa": (2, 10)},
328+
... adstock_bounds={"half_life": (1, 10)},
329+
... initial_params={"slope": 2.0, "kappa": 5.0, "half_life": 4.0},
330+
... )
331+
332+
Notes
333+
-----
334+
This method performs nested optimization:
335+
- Outer loop: Search over transform parameters
336+
- Inner loop: Fit OLS for each set of transform parameters
337+
- Objective: Minimize RMSE
338+
339+
Grid search is exhaustive but can be slow for large grids. Continuous
340+
optimization is faster but may find local optima. Consider using grid
341+
search first to find good starting points for optimization.
342+
"""
343+
from causalpy.transform_optimization import (
344+
estimate_transform_params_grid,
345+
estimate_transform_params_optimize,
346+
)
347+
348+
# Run parameter estimation
349+
if estimation_method == "grid":
350+
if "saturation_grid" not in estimation_kwargs:
351+
raise ValueError(
352+
"saturation_grid is required for grid search method. "
353+
"E.g., saturation_grid={'slope': [1.0, 2.0], 'kappa': [3, 5]}"
354+
)
355+
if "adstock_grid" not in estimation_kwargs:
356+
raise ValueError(
357+
"adstock_grid is required for grid search method. "
358+
"E.g., adstock_grid={'half_life': [2, 3, 4]}"
359+
)
360+
361+
est_results = estimate_transform_params_grid(
362+
data=data,
363+
y_column=y_column,
364+
treatment_name=treatment_name,
365+
base_formula=base_formula,
366+
saturation_type=saturation_type,
367+
saturation_grid=estimation_kwargs["saturation_grid"],
368+
adstock_grid=estimation_kwargs["adstock_grid"],
369+
coef_constraint=coef_constraint,
370+
hac_maxlags=hac_maxlags,
371+
)
372+
373+
search_space = {
374+
"saturation_grid": estimation_kwargs["saturation_grid"],
375+
"adstock_grid": estimation_kwargs["adstock_grid"],
376+
}
377+
378+
elif estimation_method == "optimize":
379+
if "saturation_bounds" not in estimation_kwargs:
380+
raise ValueError(
381+
"saturation_bounds is required for optimize method. "
382+
"E.g., saturation_bounds={'slope': (0.5, 5.0), 'kappa': (2, 10)}"
383+
)
384+
if "adstock_bounds" not in estimation_kwargs:
385+
raise ValueError(
386+
"adstock_bounds is required for optimize method. "
387+
"E.g., adstock_bounds={'half_life': (1, 10)}"
388+
)
389+
390+
est_results = estimate_transform_params_optimize(
391+
data=data,
392+
y_column=y_column,
393+
treatment_name=treatment_name,
394+
base_formula=base_formula,
395+
saturation_type=saturation_type,
396+
saturation_bounds=estimation_kwargs["saturation_bounds"],
397+
adstock_bounds=estimation_kwargs["adstock_bounds"],
398+
initial_params=estimation_kwargs.get("initial_params"),
399+
coef_constraint=coef_constraint,
400+
hac_maxlags=hac_maxlags,
401+
method=estimation_kwargs.get("method", "L-BFGS-B"),
402+
)
403+
404+
search_space = {
405+
"saturation_bounds": estimation_kwargs["saturation_bounds"],
406+
"adstock_bounds": estimation_kwargs["adstock_bounds"],
407+
"initial_params": estimation_kwargs.get("initial_params"),
408+
"method": estimation_kwargs.get("method", "L-BFGS-B"),
409+
}
410+
411+
else:
412+
raise ValueError(
413+
f"Unknown estimation_method: {estimation_method}. "
414+
"Use 'grid' or 'optimize'."
415+
)
416+
417+
# Create Treatment with best transforms
418+
from causalpy.transforms import Treatment
419+
420+
treatment = Treatment(
421+
name=treatment_name,
422+
saturation=est_results["best_saturation"],
423+
adstock=est_results["best_adstock"],
424+
coef_constraint=coef_constraint,
425+
)
426+
427+
# Create TransferFunctionITS with estimated transforms
428+
result = cls(
429+
data=data,
430+
y_column=y_column,
431+
base_formula=base_formula,
432+
treatments=[treatment],
433+
hac_maxlags=hac_maxlags,
434+
)
435+
436+
# Store estimation metadata
437+
result.transform_estimation_method = estimation_method
438+
result.transform_estimation_results = est_results
439+
result.transform_search_space = search_space
440+
441+
return result
442+
231443
def _validate_inputs(
232444
self,
233445
data: pd.DataFrame,
@@ -303,8 +515,15 @@ def _build_treatment_matrix(
303515
# Get raw exposure series
304516
x_raw = data[treatment.name].values
305517

306-
# Apply transform pipeline
307-
x_transformed = apply_treatment_transforms(x_raw, treatment)
518+
# Apply transform pipeline using strategy pattern
519+
# Transforms are applied in order: Saturation → Adstock → Lag
520+
x_transformed = x_raw
521+
if treatment.saturation is not None:
522+
x_transformed = treatment.saturation.apply(x_transformed)
523+
if treatment.adstock is not None:
524+
x_transformed = treatment.adstock.apply(x_transformed)
525+
if treatment.lag is not None:
526+
x_transformed = treatment.lag.apply(x_transformed)
308527

309528
Z_columns.append(x_transformed)
310529
labels.append(treatment.name)
@@ -540,32 +759,32 @@ def plot_irf(self, channel: str, max_lag: Optional[int] = None) -> plt.Figure:
540759
if treatment is None:
541760
raise ValueError(f"Channel '{channel}' not found in treatments")
542761

543-
# Extract adstock parameters
544-
adstock = None
545-
for transform in treatment.transforms:
546-
if isinstance(transform, Adstock):
547-
adstock = transform
548-
break
762+
# Extract adstock transform (now directly accessible via treatment.adstock)
763+
adstock = treatment.adstock
549764

550765
if adstock is None:
551766
print(f"No adstock transform found for channel '{channel}'")
552767
return None
553768

554-
# Verify alpha is set (should be set by __post_init__)
555-
if adstock.alpha is None:
769+
# Get alpha parameter from adstock transform
770+
adstock_params = adstock.get_params()
771+
alpha = adstock_params.get("alpha")
772+
773+
if alpha is None:
556774
raise ValueError(
557775
f"Adstock transform for channel '{channel}' has alpha=None. "
558776
"This should not happen if half_life or alpha was provided."
559777
)
560778

561779
# Generate IRF (adstock weights)
562780
if max_lag is None:
563-
max_lag = adstock.l_max
781+
max_lag = adstock_params.get("l_max", 12)
564782

565783
lags = np.arange(max_lag + 1)
566-
weights = adstock.alpha**lags
784+
weights = alpha**lags
567785

568-
if adstock.normalize:
786+
normalize = adstock_params.get("normalize", True)
787+
if normalize:
569788
weights = weights / weights.sum()
570789

571790
# Plot
@@ -575,12 +794,12 @@ def plot_irf(self, channel: str, max_lag: Optional[int] = None) -> plt.Figure:
575794
ax.set_ylabel("Weight")
576795

577796
# Calculate half-life: alpha^h = 0.5, so h = log(0.5) / log(alpha)
578-
half_life_calc = np.log(0.5) / np.log(adstock.alpha)
797+
half_life_calc = np.log(0.5) / np.log(alpha)
579798

580799
ax.set_title(
581800
f"Impulse Response Function: {channel}\n"
582-
f"(alpha={adstock.alpha:.3f}, half_life={half_life_calc:.2f}, "
583-
f"normalize={adstock.normalize})"
801+
f"(alpha={alpha:.3f}, half_life={half_life_calc:.2f}, "
802+
f"normalize={normalize})"
584803
)
585804
ax.grid(True, alpha=0.3, axis="y")
586805

0 commit comments

Comments
 (0)