Skip to content

Commit a058d93

Browse files
committed
Support adstock-only and saturation-only transfer functions
Refactored transfer function parameter estimation and plotting to allow models with only adstock or only saturation transforms (or both). Updated validation logic, grid/optimize routines, and plotting to handle optional transforms. Added comprehensive tests for all transform configurations and clarified documentation and notebook to demonstrate adstock-only use case.
1 parent 49c6986 commit a058d93

File tree

6 files changed

+599
-305
lines changed

6 files changed

+599
-305
lines changed

causalpy/experiments/graded_intervention_its.py

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -764,11 +764,11 @@ def plot_transforms(
764764
x_range=None,
765765
**kwargs,
766766
) -> Tuple[plt.Figure, np.ndarray]:
767-
"""Plot estimated saturation and adstock transformation curves.
767+
"""Plot estimated transformation curves (saturation and/or adstock).
768768
769-
Creates a 2-panel figure showing:
770-
1. Saturation curve (input exposure -> saturated exposure)
771-
2. Adstock weights over time (lag distribution)
769+
Creates a figure with 1-2 panels depending on which transforms are present:
770+
- Saturation curve (input exposure -> saturated exposure) if saturation exists
771+
- Adstock weights over time (lag distribution) if adstock exists
772772
773773
Parameters
774774
----------
@@ -784,8 +784,8 @@ def plot_transforms(
784784
Returns
785785
-------
786786
fig : matplotlib.figure.Figure
787-
ax : array of matplotlib.axes.Axes
788-
Array of 2 axes objects (left: saturation, right: adstock).
787+
ax : list of matplotlib.axes.Axes
788+
List of axes objects (1 or 2 panels depending on which transforms exist).
789789
790790
Examples
791791
--------
@@ -810,13 +810,33 @@ def plot_transforms(
810810
est_saturation = treatment.saturation
811811
est_adstock = treatment.adstock
812812

813-
# Create 2-panel subplot
814-
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
813+
# Check which transforms exist
814+
has_saturation = est_saturation is not None
815+
has_adstock = est_adstock is not None
816+
817+
if not has_saturation and not has_adstock:
818+
raise ValueError(
819+
"No transforms to plot (both saturation and adstock are None). "
820+
"At least one transform must be specified."
821+
)
822+
823+
# Determine number of panels based on available transforms
824+
n_panels = int(has_saturation) + int(has_adstock)
825+
826+
# Create subplot with appropriate number of panels
827+
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5))
828+
829+
# Make axes a list for consistent indexing
830+
if n_panels == 1:
831+
axes = [axes]
832+
833+
panel_idx = 0
815834

816835
# ============================================================================
817-
# LEFT PLOT: Saturation curves
836+
# SATURATION PLOT (if present)
818837
# ============================================================================
819838
if est_saturation is not None:
839+
ax = axes[panel_idx]
820840
# Determine x range
821841
if x_range is None:
822842
# Use range from data
@@ -831,7 +851,7 @@ def plot_transforms(
831851
# Plot true saturation if provided
832852
if true_saturation is not None:
833853
y_true_sat = true_saturation.apply(x_sat)
834-
axes[0].plot(
854+
ax.plot(
835855
x_sat,
836856
y_true_sat,
837857
"k--",
@@ -842,13 +862,13 @@ def plot_transforms(
842862

843863
# Plot estimated saturation
844864
y_est_sat = est_saturation.apply(x_sat)
845-
axes[0].plot(x_sat, y_est_sat, "C0-", linewidth=2.5, label="Estimated")
865+
ax.plot(x_sat, y_est_sat, "C0-", linewidth=2.5, label="Estimated")
846866

847-
axes[0].set_xlabel(f"{treatment.name} (raw)", fontsize=11)
848-
axes[0].set_ylabel("Saturated Value", fontsize=11)
849-
axes[0].set_title("Saturation Function", fontsize=12, fontweight="bold")
850-
axes[0].legend(fontsize=LEGEND_FONT_SIZE, framealpha=0.9)
851-
axes[0].grid(True, alpha=0.3)
867+
ax.set_xlabel(f"{treatment.name} (raw)", fontsize=11)
868+
ax.set_ylabel("Saturated Value", fontsize=11)
869+
ax.set_title("Saturation Function", fontsize=12, fontweight="bold")
870+
ax.legend(fontsize=LEGEND_FONT_SIZE, framealpha=0.9)
871+
ax.grid(True, alpha=0.3)
852872

853873
# Add parameter text
854874
est_params = est_saturation.get_params()
@@ -864,30 +884,22 @@ def plot_transforms(
864884
if key not in ["alpha", "l_max", "normalize"]:
865885
param_text += f" {key}={val:.2f}\n"
866886

867-
axes[0].text(
887+
ax.text(
868888
0.05,
869889
0.95,
870890
param_text.strip(),
871-
transform=axes[0].transAxes,
891+
transform=ax.transAxes,
872892
fontsize=9,
873893
verticalalignment="top",
874894
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
875895
)
876-
else:
877-
axes[0].text(
878-
0.5,
879-
0.5,
880-
"No saturation transform",
881-
ha="center",
882-
va="center",
883-
transform=axes[0].transAxes,
884-
)
885-
axes[0].set_title("Saturation Function", fontsize=12, fontweight="bold")
896+
panel_idx += 1
886897

887898
# ============================================================================
888-
# RIGHT PLOT: Adstock weights
899+
# ADSTOCK PLOT (if present)
889900
# ============================================================================
890901
if est_adstock is not None:
902+
ax = axes[panel_idx]
891903
est_adstock_params = est_adstock.get_params()
892904
l_max = est_adstock_params.get("l_max", 12)
893905
lags = np.arange(l_max + 1)
@@ -908,15 +920,15 @@ def plot_transforms(
908920
true_weights = true_weights / true_weights.sum()
909921

910922
width = 0.35
911-
axes[1].bar(
923+
ax.bar(
912924
lags - width / 2,
913925
true_weights,
914926
width,
915927
alpha=0.8,
916928
label="True",
917929
color="gray",
918930
)
919-
axes[1].bar(
931+
ax.bar(
920932
lags + width / 2,
921933
est_weights,
922934
width,
@@ -925,15 +937,15 @@ def plot_transforms(
925937
color="C0",
926938
)
927939
else:
928-
axes[1].bar(lags, est_weights, alpha=0.7, color="C0", label="Estimated")
940+
ax.bar(lags, est_weights, alpha=0.7, color="C0", label="Estimated")
929941

930-
axes[1].set_xlabel("Lag (periods)", fontsize=11)
931-
axes[1].set_ylabel("Adstock Weight", fontsize=11)
932-
axes[1].set_title(
942+
ax.set_xlabel("Lag (periods)", fontsize=11)
943+
ax.set_ylabel("Adstock Weight", fontsize=11)
944+
ax.set_title(
933945
"Adstock Function (Carryover Effect)", fontsize=12, fontweight="bold"
934946
)
935-
axes[1].legend(fontsize=LEGEND_FONT_SIZE, framealpha=0.9)
936-
axes[1].grid(True, alpha=0.3, axis="y")
947+
ax.legend(fontsize=LEGEND_FONT_SIZE, framealpha=0.9)
948+
ax.grid(True, alpha=0.3, axis="y")
937949

938950
# Add parameter text
939951
param_text = "Estimated:\n"
@@ -948,28 +960,16 @@ def plot_transforms(
948960
param_text += f" half_life={half_life_true:.2f}\n"
949961
param_text += f" alpha={true_alpha:.3f}\n"
950962

951-
axes[1].text(
963+
ax.text(
952964
0.95,
953965
0.95,
954966
param_text.strip(),
955-
transform=axes[1].transAxes,
967+
transform=ax.transAxes,
956968
fontsize=9,
957969
verticalalignment="top",
958970
horizontalalignment="right",
959971
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
960972
)
961-
else:
962-
axes[1].text(
963-
0.5,
964-
0.5,
965-
"No adstock transform",
966-
ha="center",
967-
va="center",
968-
transform=axes[1].transAxes,
969-
)
970-
axes[1].set_title(
971-
"Adstock Function (Carryover Effect)", fontsize=12, fontweight="bold"
972-
)
973973

974974
plt.tight_layout()
975975
return fig, axes

causalpy/skl_models.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -195,26 +195,32 @@ def __init__(
195195

196196
# Validate estimation method and required parameters
197197
if estimation_method == "grid":
198-
if saturation_grid is None:
198+
# At least one transform must be specified
199+
if saturation_grid is None and adstock_grid is None:
199200
raise ValueError(
200-
"saturation_grid is required for grid search method. "
201-
"E.g., saturation_grid={'slope': [1.0, 2.0], 'kappa': [3, 5]}"
201+
"At least one of saturation_grid or adstock_grid must be provided for grid search. "
202+
"To use only adstock: set saturation_type=None and provide adstock_grid. "
203+
"To use only saturation: provide saturation_grid and set adstock_grid=None."
202204
)
203-
if adstock_grid is None:
205+
# If saturation_type is specified, grid must be provided
206+
if saturation_type is not None and saturation_grid is None:
204207
raise ValueError(
205-
"adstock_grid is required for grid search method. "
206-
"E.g., adstock_grid={'half_life': [2, 3, 4]}"
208+
f"saturation_grid is required when saturation_type='{saturation_type}'. "
209+
"E.g., saturation_grid={'lam': [0.2, 0.5, 0.8]}"
207210
)
208211
elif estimation_method == "optimize":
209-
if saturation_bounds is None:
212+
# At least one transform must be specified
213+
if saturation_bounds is None and adstock_bounds is None:
210214
raise ValueError(
211-
"saturation_bounds is required for optimize method. "
212-
"E.g., saturation_bounds={'slope': (0.5, 5.0), 'kappa': (2, 10)}"
215+
"At least one of saturation_bounds or adstock_bounds must be provided for optimize method. "
216+
"To use only adstock: set saturation_type=None and provide adstock_bounds. "
217+
"To use only saturation: provide saturation_bounds and set adstock_bounds=None."
213218
)
214-
if adstock_bounds is None:
219+
# If saturation_type is specified, bounds must be provided
220+
if saturation_type is not None and saturation_bounds is None:
215221
raise ValueError(
216-
"adstock_bounds is required for optimize method. "
217-
"E.g., adstock_bounds={'half_life': (1, 10)}"
222+
f"saturation_bounds is required when saturation_type='{saturation_type}'. "
223+
"E.g., saturation_bounds={'lam': (0.1, 1.0)}"
218224
)
219225
else:
220226
raise ValueError(

0 commit comments

Comments
 (0)