@@ -101,47 +101,50 @@ class TransferFunctionITS(BaseExperiment):
101101
102102 Examples
103103 --------
104- >>> import causalpy as cp
105- >>> import pandas as pd
106- >>> import numpy as np
107- >>> # Create sample data
108- >>> dates = pd.date_range("2022-01-01", periods=104, freq="W")
109- >>> df = pd.DataFrame(
110- ... {
111- ... "date": dates,
112- ... "water_consumption": np.random.normal(5000, 500, 104),
113- ... "comm_intensity": np.random.uniform(0, 10, 104),
114- ... "temperature": 25 + 10 * np.sin(2 * np.pi * np.arange(104) / 52),
115- ... "rainfall": 8 - 8 * np.sin(2 * np.pi * np.arange(104) / 52),
116- ... }
117- ... )
118- >>> df = df.set_index("date")
119- >>> df["t"] = np.arange(len(df))
120- >>>
121- >>> # Estimate transform parameters via grid search
122- >>> result = cp.TransferFunctionITS.with_estimated_transforms(
123- ... data=df,
124- ... y_column="water_consumption",
125- ... treatment_name="comm_intensity",
126- ... base_formula="1 + t + temperature + rainfall",
127- ... estimation_method="grid",
128- ... saturation_type="hill",
129- ... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
130- ... adstock_grid={"half_life": [2, 3, 4, 5]},
131- ... )
132- >>>
133- >>> # View estimated parameters
134- >>> print(result.transform_estimation_results["best_params"])
135- >>>
136- >>> # Estimate effect of policy over entire period
137- >>> effect_result = result.effect(
138- ... window=(df.index[0], df.index[-1]), channels=["comm_intensity"], scale=0.0
139- ... )
140- >>> print(f"Total effect: {effect_result['total_effect']:.2f}")
141- >>>
142- >>> # Visualize results
143- >>> result.plot()
144- >>> result.diagnostics()
104+ .. code-block:: python
105+
106+ import causalpy as cp
107+ import pandas as pd
108+ import numpy as np
109+
110+ # Create sample data
111+ dates = pd.date_range("2022-01-01", periods=104, freq="W")
112+ df = pd.DataFrame(
113+ {
114+ "date": dates,
115+ "water_consumption": np.random.normal(5000, 500, 104),
116+ "comm_intensity": np.random.uniform(0, 10, 104),
117+ "temperature": 25 + 10 * np.sin(2 * np.pi * np.arange(104) / 52),
118+ "rainfall": 8 - 8 * np.sin(2 * np.pi * np.arange(104) / 52),
119+ }
120+ )
121+ df = df.set_index("date")
122+ df["t"] = np.arange(len(df))
123+
124+ # Estimate transform parameters via grid search
125+ result = cp.TransferFunctionITS.with_estimated_transforms(
126+ data=df,
127+ y_column="water_consumption",
128+ treatment_name="comm_intensity",
129+ base_formula="1 + t + temperature + rainfall",
130+ estimation_method="grid",
131+ saturation_type="hill",
132+ saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
133+ adstock_grid={"half_life": [2, 3, 4, 5]},
134+ )
135+
136+ # View estimated parameters
137+ print(result.transform_estimation_results["best_params"])
138+
139+ # Estimate effect of policy over entire period
140+ effect_result = result.effect(
141+ window=(df.index[0], df.index[-1]), channels=["comm_intensity"], scale=0.0
142+ )
143+ print(f"Total effect: {effect_result['total_effect']:.2f}")
144+
145+ # Visualize results
146+ result.plot()
147+ result.diagnostics()
145148
146149 Notes
147150 -----
@@ -321,31 +324,33 @@ def with_estimated_transforms(
321324
322325 Examples
323326 --------
324- >>> # Grid search example
325- >>> result = TransferFunctionITS.with_estimated_transforms(
326- ... data=df,
327- ... y_column="water_consumption",
328- ... treatment_name="comm_intensity",
329- ... base_formula="1 + t + temperature + rainfall",
330- ... estimation_method="grid",
331- ... saturation_type="hill",
332- ... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
333- ... adstock_grid={"half_life": [2, 3, 4, 5]},
334- ... )
335- >>> print(f"Best RMSE: {result.transform_estimation_results['best_score']:.2f}")
336-
337- >>> # Optimization example
338- >>> result = TransferFunctionITS.with_estimated_transforms(
339- ... data=df,
340- ... y_column="water_consumption",
341- ... treatment_name="comm_intensity",
342- ... base_formula="1 + t + temperature + rainfall",
343- ... estimation_method="optimize",
344- ... saturation_type="hill",
345- ... saturation_bounds={"slope": (0.5, 5.0), "kappa": (2, 10)},
346- ... adstock_bounds={"half_life": (1, 10)},
347- ... initial_params={"slope": 2.0, "kappa": 5.0, "half_life": 4.0},
348- ... )
327+ .. code-block:: python
328+
329+ # Grid search example
330+ result = TransferFunctionITS.with_estimated_transforms(
331+ data=df,
332+ y_column="water_consumption",
333+ treatment_name="comm_intensity",
334+ base_formula="1 + t + temperature + rainfall",
335+ estimation_method="grid",
336+ saturation_type="hill",
337+ saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
338+ adstock_grid={"half_life": [2, 3, 4, 5]},
339+ )
340+ print(f"Best RMSE: {result.transform_estimation_results['best_score']:.2f}")
341+
342+ # Optimization example
343+ result = TransferFunctionITS.with_estimated_transforms(
344+ data=df,
345+ y_column="water_consumption",
346+ treatment_name="comm_intensity",
347+ base_formula="1 + t + temperature + rainfall",
348+ estimation_method="optimize",
349+ saturation_type="hill",
350+ saturation_bounds={"slope": (0.5, 5.0), "kappa": (2, 10)},
351+ adstock_bounds={"half_life": (1, 10)},
352+ initial_params={"slope": 2.0, "kappa": 5.0, "half_life": 4.0},
353+ )
349354
350355 Notes
351356 -----
@@ -587,12 +592,14 @@ def effect(
587592
588593 Examples
589594 --------
590- >>> # Estimate effect of completely removing TV spend in weeks 50-60
591- >>> effect = result.effect(
592- ... window=(df.index[50], df.index[60]), channels=["tv_spend"], scale=0.0
593- ... )
594- >>> print(f"Total effect: {effect['total_effect']:.2f}")
595- >>> print(f"Mean weekly effect: {effect['mean_effect']:.2f}")
595+ .. code-block:: python
596+
597+ # Estimate effect of completely removing TV spend in weeks 50-60
598+ effect = result.effect(
599+ window=(df.index[50], df.index[60]), channels=["tv_spend"], scale=0.0
600+ )
601+ print(f"Total effect: {effect['total_effect']:.2f}")
602+ print(f"Mean weekly effect: {effect['mean_effect']:.2f}")
596603
597604 Notes
598605 -----
@@ -760,7 +767,9 @@ def plot_irf(self, channel: str, max_lag: Optional[int] = None) -> plt.Figure:
760767
761768 Examples
762769 --------
763- >>> result.plot_irf("tv_spend")
770+ .. code-block:: python
771+
772+ result.plot_irf("tv_spend")
764773
765774 Notes
766775 -----
0 commit comments