3030from statsmodels .stats .diagnostic import acorr_ljungbox
3131
3232from causalpy .custom_exceptions import BadIndexException
33- from causalpy .transforms import Adstock , Treatment , apply_treatment_transforms
33+ from causalpy .transforms import Treatment
3434from causalpy .utils import round_num
3535
3636from .base import BaseExperiment
@@ -228,6 +228,218 @@ def __init__(
228228 # Store score (R-squared)
229229 self .score = self .ols_result .rsquared
230230
231+ # Transform estimation metadata (set by with_estimated_transforms)
232+ self .transform_estimation_method = None # "grid", "optimize", or None
233+ self .transform_estimation_results = None # Full results dict
234+ self .transform_search_space = None # Grid or bounds that were searched
235+
236+ @classmethod
237+ def with_estimated_transforms (
238+ cls ,
239+ data : pd .DataFrame ,
240+ y_column : str ,
241+ treatment_name : str ,
242+ base_formula : str ,
243+ estimation_method : str = "grid" ,
244+ saturation_type : str = "hill" ,
245+ coef_constraint : str = "nonnegative" ,
246+ hac_maxlags : Optional [int ] = None ,
247+ ** estimation_kwargs ,
248+ ) -> "TransferFunctionITS" :
249+ """
250+ Create a TransferFunctionITS with transform parameters estimated from data.
251+
252+ This method estimates optimal saturation and adstock parameters via grid
253+ search or continuous optimization, then creates a TransferFunctionITS
254+ instance with those estimated transforms.
255+
256+ Parameters
257+ ----------
258+ data : pd.DataFrame
259+ Time series data with datetime or numeric index.
260+ y_column : str
261+ Name of the outcome variable column in data.
262+ treatment_name : str
263+ Name of the treatment variable column in data.
264+ base_formula : str
265+ Patsy formula for the baseline model (trend, seasonality, controls).
266+ estimation_method : str, default="grid"
267+ Method for parameter estimation: "grid" or "optimize".
268+ - "grid": Grid search over discrete parameter values
269+ - "optimize": Continuous optimization using scipy.optimize
270+ saturation_type : str, default="hill"
271+ Type of saturation function: "hill", "logistic", or "michaelis_menten".
272+ coef_constraint : str, default="nonnegative"
273+ Constraint on treatment coefficient ("nonnegative" or "unconstrained").
274+ hac_maxlags : int, optional
275+ Maximum lags for HAC standard errors. If None, uses rule of thumb.
276+ **estimation_kwargs
277+ Additional keyword arguments for the estimation method:
278+
279+ For grid search (estimation_method="grid"):
280+ - saturation_grid : dict
281+ Dictionary mapping parameter names to lists of values.
282+ E.g., {"slope": [1.0, 2.0], "kappa": [3, 5, 7]}
283+ - adstock_grid : dict
284+ Dictionary mapping parameter names to lists of values.
285+ E.g., {"half_life": [2, 3, 4]}
286+
287+ For optimization (estimation_method="optimize"):
288+ - saturation_bounds : dict
289+ Dictionary mapping parameter names to (min, max) tuples.
290+ E.g., {"slope": (0.5, 5.0), "kappa": (2, 10)}
291+ - adstock_bounds : dict
292+ Dictionary mapping parameter names to (min, max) tuples.
293+ E.g., {"half_life": (1, 10)}
294+ - initial_params : dict, optional
295+ Initial parameter values for optimization.
296+ - method : str, default="L-BFGS-B"
297+ Scipy optimization method.
298+
299+ Returns
300+ -------
301+ TransferFunctionITS
302+ Fitted model with estimated transform parameters.
303+
304+ Examples
305+ --------
306+ >>> # Grid search example
307+ >>> result = TransferFunctionITS.with_estimated_transforms(
308+ ... data=df,
309+ ... y_column="water_consumption",
310+ ... treatment_name="comm_intensity",
311+ ... base_formula="1 + t + temperature + rainfall",
312+ ... estimation_method="grid",
313+ ... saturation_type="hill",
314+ ... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
315+ ... adstock_grid={"half_life": [2, 3, 4, 5]},
316+ ... )
317+ >>> print(f"Best RMSE: {result.transform_estimation_results['best_score']:.2f}")
318+
319+ >>> # Optimization example
320+ >>> result = TransferFunctionITS.with_estimated_transforms(
321+ ... data=df,
322+ ... y_column="water_consumption",
323+ ... treatment_name="comm_intensity",
324+ ... base_formula="1 + t + temperature + rainfall",
325+ ... estimation_method="optimize",
326+ ... saturation_type="hill",
327+ ... saturation_bounds={"slope": (0.5, 5.0), "kappa": (2, 10)},
328+ ... adstock_bounds={"half_life": (1, 10)},
329+ ... initial_params={"slope": 2.0, "kappa": 5.0, "half_life": 4.0},
330+ ... )
331+
332+ Notes
333+ -----
334+ This method performs nested optimization:
335+ - Outer loop: Search over transform parameters
336+ - Inner loop: Fit OLS for each set of transform parameters
337+ - Objective: Minimize RMSE
338+
339+ Grid search is exhaustive but can be slow for large grids. Continuous
340+ optimization is faster but may find local optima. Consider using grid
341+ search first to find good starting points for optimization.
342+ """
343+ from causalpy .transform_optimization import (
344+ estimate_transform_params_grid ,
345+ estimate_transform_params_optimize ,
346+ )
347+
348+ # Run parameter estimation
349+ if estimation_method == "grid" :
350+ if "saturation_grid" not in estimation_kwargs :
351+ raise ValueError (
352+ "saturation_grid is required for grid search method. "
353+ "E.g., saturation_grid={'slope': [1.0, 2.0], 'kappa': [3, 5]}"
354+ )
355+ if "adstock_grid" not in estimation_kwargs :
356+ raise ValueError (
357+ "adstock_grid is required for grid search method. "
358+ "E.g., adstock_grid={'half_life': [2, 3, 4]}"
359+ )
360+
361+ est_results = estimate_transform_params_grid (
362+ data = data ,
363+ y_column = y_column ,
364+ treatment_name = treatment_name ,
365+ base_formula = base_formula ,
366+ saturation_type = saturation_type ,
367+ saturation_grid = estimation_kwargs ["saturation_grid" ],
368+ adstock_grid = estimation_kwargs ["adstock_grid" ],
369+ coef_constraint = coef_constraint ,
370+ hac_maxlags = hac_maxlags ,
371+ )
372+
373+ search_space = {
374+ "saturation_grid" : estimation_kwargs ["saturation_grid" ],
375+ "adstock_grid" : estimation_kwargs ["adstock_grid" ],
376+ }
377+
378+ elif estimation_method == "optimize" :
379+ if "saturation_bounds" not in estimation_kwargs :
380+ raise ValueError (
381+ "saturation_bounds is required for optimize method. "
382+ "E.g., saturation_bounds={'slope': (0.5, 5.0), 'kappa': (2, 10)}"
383+ )
384+ if "adstock_bounds" not in estimation_kwargs :
385+ raise ValueError (
386+ "adstock_bounds is required for optimize method. "
387+ "E.g., adstock_bounds={'half_life': (1, 10)}"
388+ )
389+
390+ est_results = estimate_transform_params_optimize (
391+ data = data ,
392+ y_column = y_column ,
393+ treatment_name = treatment_name ,
394+ base_formula = base_formula ,
395+ saturation_type = saturation_type ,
396+ saturation_bounds = estimation_kwargs ["saturation_bounds" ],
397+ adstock_bounds = estimation_kwargs ["adstock_bounds" ],
398+ initial_params = estimation_kwargs .get ("initial_params" ),
399+ coef_constraint = coef_constraint ,
400+ hac_maxlags = hac_maxlags ,
401+ method = estimation_kwargs .get ("method" , "L-BFGS-B" ),
402+ )
403+
404+ search_space = {
405+ "saturation_bounds" : estimation_kwargs ["saturation_bounds" ],
406+ "adstock_bounds" : estimation_kwargs ["adstock_bounds" ],
407+ "initial_params" : estimation_kwargs .get ("initial_params" ),
408+ "method" : estimation_kwargs .get ("method" , "L-BFGS-B" ),
409+ }
410+
411+ else :
412+ raise ValueError (
413+ f"Unknown estimation_method: { estimation_method } . "
414+ "Use 'grid' or 'optimize'."
415+ )
416+
417+ # Create Treatment with best transforms
418+ from causalpy .transforms import Treatment
419+
420+ treatment = Treatment (
421+ name = treatment_name ,
422+ saturation = est_results ["best_saturation" ],
423+ adstock = est_results ["best_adstock" ],
424+ coef_constraint = coef_constraint ,
425+ )
426+
427+ # Create TransferFunctionITS with estimated transforms
428+ result = cls (
429+ data = data ,
430+ y_column = y_column ,
431+ base_formula = base_formula ,
432+ treatments = [treatment ],
433+ hac_maxlags = hac_maxlags ,
434+ )
435+
436+ # Store estimation metadata
437+ result .transform_estimation_method = estimation_method
438+ result .transform_estimation_results = est_results
439+ result .transform_search_space = search_space
440+
441+ return result
442+
231443 def _validate_inputs (
232444 self ,
233445 data : pd .DataFrame ,
@@ -303,8 +515,15 @@ def _build_treatment_matrix(
303515 # Get raw exposure series
304516 x_raw = data [treatment .name ].values
305517
306- # Apply transform pipeline
307- x_transformed = apply_treatment_transforms (x_raw , treatment )
518+ # Apply transform pipeline using strategy pattern
519+ # Transforms are applied in order: Saturation → Adstock → Lag
520+ x_transformed = x_raw
521+ if treatment .saturation is not None :
522+ x_transformed = treatment .saturation .apply (x_transformed )
523+ if treatment .adstock is not None :
524+ x_transformed = treatment .adstock .apply (x_transformed )
525+ if treatment .lag is not None :
526+ x_transformed = treatment .lag .apply (x_transformed )
308527
309528 Z_columns .append (x_transformed )
310529 labels .append (treatment .name )
@@ -540,32 +759,32 @@ def plot_irf(self, channel: str, max_lag: Optional[int] = None) -> plt.Figure:
540759 if treatment is None :
541760 raise ValueError (f"Channel '{ channel } ' not found in treatments" )
542761
543- # Extract adstock parameters
544- adstock = None
545- for transform in treatment .transforms :
546- if isinstance (transform , Adstock ):
547- adstock = transform
548- break
762+ # Extract adstock transform (now directly accessible via treatment.adstock)
763+ adstock = treatment .adstock
549764
550765 if adstock is None :
551766 print (f"No adstock transform found for channel '{ channel } '" )
552767 return None
553768
554- # Verify alpha is set (should be set by __post_init__)
555- if adstock .alpha is None :
769+ # Get alpha parameter from adstock transform
770+ adstock_params = adstock .get_params ()
771+ alpha = adstock_params .get ("alpha" )
772+
773+ if alpha is None :
556774 raise ValueError (
557775 f"Adstock transform for channel '{ channel } ' has alpha=None. "
558776 "This should not happen if half_life or alpha was provided."
559777 )
560778
561779 # Generate IRF (adstock weights)
562780 if max_lag is None :
563- max_lag = adstock . l_max
781+ max_lag = adstock_params . get ( " l_max" , 12 )
564782
565783 lags = np .arange (max_lag + 1 )
566- weights = adstock . alpha ** lags
784+ weights = alpha ** lags
567785
568- if adstock .normalize :
786+ normalize = adstock_params .get ("normalize" , True )
787+ if normalize :
569788 weights = weights / weights .sum ()
570789
571790 # Plot
@@ -575,12 +794,12 @@ def plot_irf(self, channel: str, max_lag: Optional[int] = None) -> plt.Figure:
575794 ax .set_ylabel ("Weight" )
576795
577796 # Calculate half-life: alpha^h = 0.5, so h = log(0.5) / log(alpha)
578- half_life_calc = np .log (0.5 ) / np .log (adstock . alpha )
797+ half_life_calc = np .log (0.5 ) / np .log (alpha )
579798
580799 ax .set_title (
581800 f"Impulse Response Function: { channel } \n "
582- f"(alpha={ adstock . alpha :.3f} , half_life={ half_life_calc :.2f} , "
583- f"normalize={ adstock . normalize } )"
801+ f"(alpha={ alpha :.3f} , half_life={ half_life_calc :.2f} , "
802+ f"normalize={ normalize } )"
584803 )
585804 ax .grid (True , alpha = 0.3 , axis = "y" )
586805
0 commit comments