Skip to content

Commit 48193ee

Browse files
committed
more tests to address failing code coverage check
1 parent bfce5c6 commit 48193ee

File tree

4 files changed

+1208
-3
lines changed

4 files changed

+1208
-3
lines changed

causalpy/tests/test_integration_skl_examples.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,114 @@ def test_rd_linear_with_gaussian_process():
275275
fig, ax = result.plot()
276276
assert isinstance(fig, plt.Figure)
277277
assert isinstance(ax, plt.Axes)
278+
279+
280+
@pytest.mark.integration
281+
def test_graded_intervention_time_series_end_to_end():
282+
"""
283+
Test Graded Intervention Time Series end-to-end workflow.
284+
285+
This integration test exercises the full workflow:
286+
1. Create data
287+
2. Configure TransferFunctionOLS model
288+
3. Run GradedInterventionTimeSeries experiment
289+
4. Call all major methods: plot(), plot_transforms(), effect(), plot_effect(), summary()
290+
5. Verify all methods work together
291+
"""
292+
# Generate synthetic data
293+
np.random.seed(42)
294+
n = 80
295+
t = np.arange(n)
296+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
297+
298+
# Create treatment with known transforms
299+
treatment_raw = 50 + 30 * np.sin(2 * np.pi * t / 20) + np.random.uniform(-10, 10, n)
300+
treatment_raw = np.maximum(treatment_raw, 0)
301+
302+
# Generate outcome (we don't know true transforms, model will estimate them)
303+
y = 100.0 + 0.5 * t + 2.0 * treatment_raw + np.random.normal(0, 10, n)
304+
305+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
306+
df = df.set_index("date")
307+
308+
# Create TransferFunctionOLS model
309+
model = cp.skl_models.TransferFunctionOLS(
310+
saturation_type="hill",
311+
saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [40, 50, 60]},
312+
adstock_grid={"half_life": [2, 3, 4], "l_max": [12], "normalize": [True]},
313+
estimation_method="grid",
314+
error_model="hac",
315+
)
316+
317+
# Run experiment
318+
result = cp.GradedInterventionTimeSeries(
319+
data=df,
320+
y_column="y",
321+
treatment_names=["treatment"],
322+
base_formula="1 + t",
323+
model=model,
324+
)
325+
326+
# Verify experiment result
327+
assert isinstance(result, cp.GradedInterventionTimeSeries)
328+
assert result.score > 0.5 # Reasonable fit
329+
330+
# Test plot() method
331+
fig, ax = result.plot()
332+
assert isinstance(fig, plt.Figure)
333+
assert isinstance(ax, np.ndarray)
334+
assert len(ax) == 2
335+
plt.close(fig)
336+
337+
# Test plot_transforms() method
338+
fig, ax = result.plot_transforms()
339+
assert isinstance(fig, plt.Figure)
340+
assert isinstance(ax, np.ndarray)
341+
assert len(ax) == 2
342+
plt.close(fig)
343+
344+
# Test effect() method
345+
effect_result = result.effect(
346+
window=(df.index[0], df.index[-1]), channels=["treatment"], scale=0.0
347+
)
348+
assert "effect_df" in effect_result
349+
assert "total_effect" in effect_result
350+
assert "mean_effect" in effect_result
351+
assert isinstance(effect_result["effect_df"], pd.DataFrame)
352+
353+
# Test plot_effect() method
354+
fig, ax = result.plot_effect(effect_result)
355+
assert isinstance(fig, plt.Figure)
356+
assert isinstance(ax, np.ndarray)
357+
assert len(ax) == 2
358+
plt.close(fig)
359+
360+
# Test summary() method (capture output to avoid cluttering test output)
361+
import io
362+
import sys
363+
364+
old_stdout = sys.stdout
365+
sys.stdout = io.StringIO()
366+
try:
367+
result.summary()
368+
output = sys.stdout.getvalue()
369+
assert "Graded Intervention Time Series Results" in output
370+
assert "Outcome variable" in output
371+
assert "Treatment coefficients" in output
372+
finally:
373+
sys.stdout = old_stdout
374+
375+
# Test plot_diagnostics() method
376+
sys.stdout = io.StringIO()
377+
try:
378+
result.plot_diagnostics(lags=10)
379+
finally:
380+
sys.stdout = old_stdout
381+
plt.close("all")
382+
383+
# Test get_plot_data_ols() method
384+
plot_data = result.get_plot_data_ols()
385+
assert isinstance(plot_data, pd.DataFrame)
386+
assert "observed" in plot_data.columns
387+
assert "fitted" in plot_data.columns
388+
assert "residuals" in plot_data.columns

0 commit comments

Comments
 (0)