@@ -117,28 +117,33 @@ class GradedInterventionTimeSeries(BaseExperiment):
117117
118118 Examples
119119 --------
120- >>> import causalpy as cp
121- >>> # Step 1: Create UNFITTED model with configuration
122- >>> model = cp.skl_models.TransferFunctionOLS(
123- ... saturation_type="hill",
124- ... saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
125- ... adstock_grid={"half_life": [2, 3, 4, 5]},
126- ... estimation_method="grid",
127- ... error_model="hac",
128- ... )
129- >>> # Step 2: Pass to experiment (experiment estimates transforms and fits model)
130- >>> result = cp.GradedInterventionTimeSeries(
131- ... data=df,
132- ... y_column="water_consumption",
133- ... treatment_names=["comm_intensity"],
134- ... base_formula="1 + t + temperature + rainfall",
135- ... model=model,
136- ... )
137- >>> # Step 3: Use experiment methods
138- >>> result.summary()
139- >>> result.plot()
140- >>> result.plot_diagnostics()
141- >>> effect = result.effect(window=(df.index[0], df.index[-1]), scale=0.0)
120+ .. code-block:: python
121+
122+ import causalpy as cp
123+
124+ # Step 1: Create UNFITTED model with configuration
125+ model = cp.skl_models.TransferFunctionOLS(
126+ saturation_type="hill",
127+ saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
128+ adstock_grid={"half_life": [2, 3, 4, 5]},
129+ estimation_method="grid",
130+ error_model="hac",
131+ )
132+
133+ # Step 2: Pass to experiment (experiment estimates transforms and fits model)
134+ result = cp.GradedInterventionTimeSeries(
135+ data=df,
136+ y_column="water_consumption",
137+ treatment_names=["comm_intensity"],
138+ base_formula="1 + t + temperature + rainfall",
139+ model=model,
140+ )
141+
142+ # Step 3: Use experiment methods
143+ result.summary()
144+ result.plot()
145+ result.plot_diagnostics()
146+ effect = result.effect(window=(df.index[0], df.index[-1]), scale=0.0)
142147 """
143148
144149 expt_type = "Graded Intervention Time Series"
@@ -427,13 +432,15 @@ def effect(
427432
428433 Examples
429434 --------
430- >>> # Estimate effect of removing treatment completely
431- >>> effect = result.effect(
432- ... window=(df.index[0], df.index[-1]),
433- ... channels=["comm_intensity"],
434- ... scale=0.0,
435- ... )
436- >>> print(f"Total effect: {effect['total_effect']:.2f}")
435+ .. code-block:: python
436+
437+ # Estimate effect of removing treatment completely
438+ effect = result.effect(
439+ window=(df.index[0], df.index[-1]),
440+ channels=["comm_intensity"],
441+ scale=0.0,
442+ )
443+ print(f"Total effect: {effect['total_effect']:.2f}")
437444 """
438445 # Default to all channels if not specified
439446 if channels is None :
@@ -530,13 +537,15 @@ def plot_effect(
530537
531538 Examples
532539 --------
533- >>> # Estimate effect of removing treatment
534- >>> effect_result = result.effect(
535- ... window=(df.index[0], df.index[-1]),
536- ... channels=["comm_intensity"],
537- ... scale=0.0,
538- ... )
539- >>> fig, ax = result.plot_effect(effect_result)
540+ .. code-block:: python
541+
542+ # Estimate effect of removing treatment
543+ effect_result = result.effect(
544+ window=(df.index[0], df.index[-1]),
545+ channels=["comm_intensity"],
546+ scale=0.0,
547+ )
548+ fig, ax = result.plot_effect(effect_result)
540549 """
541550 # Extract data from effect result
542551 effect_df = effect_result ["effect_df" ]
@@ -687,7 +696,9 @@ def plot_irf(self, channel: str, max_lag: Optional[int] = None) -> plt.Figure:
687696
688697 Examples
689698 --------
690- >>> result.plot_irf("comm_intensity", max_lag=12)
699+ .. code-block:: python
700+
701+ fig = result.plot_irf("comm_intensity", max_lag=12)
691702 """
692703 # Find the treatment
693704 treatment = None
@@ -778,14 +789,16 @@ def plot_transforms(
778789
779790 Examples
780791 --------
781- >>> # Plot estimated transforms only
782- >>> fig, ax = result.plot_transforms()
783- >>>
784- >>> # Compare to true transforms (simulation study)
785- >>> fig, ax = result.plot_transforms(
786- ... true_saturation=HillSaturation(slope=2.0, kappa=50),
787- ... true_adstock=GeometricAdstock(half_life=3.0, normalize=True),
788- ... )
792+ .. code-block:: python
793+
794+ # Plot estimated transforms only
795+ fig, ax = result.plot_transforms()
796+
797+ # Compare to true transforms (simulation study)
798+ fig, ax = result.plot_transforms(
799+ true_saturation=HillSaturation(slope=2.0, kappa=50),
800+ true_adstock=GeometricAdstock(half_life=3.0, normalize=True),
801+ )
789802 """
790803 # Currently only supports single treatment
791804 if len (self .treatments ) != 1 :
0 commit comments