Skip to content

Commit 061ab17

Browse files
committed
increase code coverage
1 parent c437c91 commit 061ab17

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,9 +1132,26 @@ def test_transfer_function_bayesian_adstock_only(mock_pymc_sample):
11321132
assert isinstance(fig, plt.Figure)
11331133
assert len(ax) == 2 # Should have 2 subplots
11341134

1135+
# Test plot_transforms (Bayesian-specific path)
1136+
fig_trans, ax_trans = result.plot_transforms()
1137+
assert isinstance(fig_trans, plt.Figure)
1138+
assert len(ax_trans) >= 1 # Should have at least 1 panel (adstock)
1139+
11351140
# Test summary (should not raise)
11361141
result.summary()
11371142

1143+
# Test effect() method (Bayesian-specific path)
1144+
effect_result = result.effect(
1145+
window=(df.index[0], df.index[-1]), channels=["treatment"], scale=0.0
1146+
)
1147+
assert "effect_df" in effect_result
1148+
assert "total_effect" in effect_result
1149+
1150+
# Test plot_effect() (Bayesian-specific path)
1151+
fig_eff, ax_eff = result.plot_effect(effect_result)
1152+
assert isinstance(fig_eff, plt.Figure)
1153+
assert len(ax_eff) == 2
1154+
11381155
# Test that half_life posterior is reasonable (should be positive)
11391156
half_life_samples = az.extract(result.model.idata, var_names=["half_life"])
11401157
assert (half_life_samples > 0).all(), "Half-life should be positive"
@@ -1251,9 +1268,26 @@ def test_transfer_function_ar_bayesian(mock_pymc_sample):
12511268
assert isinstance(fig, plt.Figure)
12521269
assert len(ax) == 2 # Should have 2 subplots
12531270

1271+
# Test plot_transforms (Bayesian-specific path)
1272+
fig_trans, ax_trans = result.plot_transforms()
1273+
assert isinstance(fig_trans, plt.Figure)
1274+
assert len(ax_trans) >= 1 # Should have at least 1 panel (adstock)
1275+
12541276
# Test summary (should not raise)
12551277
result.summary()
12561278

1279+
# Test effect() method (Bayesian-specific path)
1280+
effect_result = result.effect(
1281+
window=(df.index[0], df.index[-1]), channels=["treatment"], scale=0.0
1282+
)
1283+
assert "effect_df" in effect_result
1284+
assert "total_effect" in effect_result
1285+
1286+
# Test plot_effect() (Bayesian-specific path)
1287+
fig_eff, ax_eff = result.plot_effect(effect_result)
1288+
assert isinstance(fig_eff, plt.Figure)
1289+
assert len(ax_eff) == 2
1290+
12571291
# Test that half_life posterior is reasonable (should be positive)
12581292
half_life_samples = az.extract(result.model.idata, var_names=["half_life"])
12591293
assert (half_life_samples > 0).all(), "Half-life should be positive"

causalpy/tests/test_transfer_function_its.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,3 +1818,58 @@ def test_effect_with_arimax_model(self):
18181818
fig, ax = result.plot_effect(effect_result)
18191819
assert isinstance(fig, plt.Figure)
18201820
plt.close(fig)
1821+
1822+
1823+
class TestPlotTransformsEdgeCases:
1824+
"""Test plot_transforms edge cases."""
1825+
1826+
def test_plot_transforms_with_lag(self):
1827+
"""Test that lag transforms are applied correctly in build_treatment_matrix."""
1828+
np.random.seed(42)
1829+
n = 50
1830+
t = np.arange(n)
1831+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
1832+
1833+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1834+
treatment_raw = np.maximum(treatment_raw, 0)
1835+
y = 100.0 + 0.5 * t + treatment_raw + np.random.normal(0, 5, n)
1836+
1837+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
1838+
df = df.set_index("date")
1839+
1840+
model = TransferFunctionOLS(
1841+
saturation_type=None,
1842+
adstock_grid={"half_life": [3]},
1843+
estimation_method="grid",
1844+
error_model="hac",
1845+
)
1846+
1847+
result = GradedInterventionTimeSeries(
1848+
data=df,
1849+
y_column="y",
1850+
treatment_names=["treatment"],
1851+
base_formula="1 + t",
1852+
model=model,
1853+
)
1854+
1855+
# Manually add a lag to the treatment and test _build_treatment_matrix
1856+
from causalpy.transforms import DiscreteLag
1857+
1858+
treatments_with_lag = []
1859+
for treatment in result.treatments:
1860+
# Create a new treatment object with lag
1861+
treatment_lagged = Treatment(
1862+
name=treatment.name,
1863+
saturation=treatment.saturation,
1864+
adstock=treatment.adstock,
1865+
lag=DiscreteLag(k=1), # Add 1-period lag
1866+
)
1867+
treatments_with_lag.append(treatment_lagged)
1868+
1869+
# Test _build_treatment_matrix with lag
1870+
Z, labels = result._build_treatment_matrix(df, treatments_with_lag)
1871+
1872+
assert Z.shape == (n, 1)
1873+
assert labels == ["treatment"]
1874+
# First value should be 0 due to lag
1875+
assert Z[0, 0] == 0

0 commit comments

Comments
 (0)