Skip to content

Commit 72ebb8d

Browse files
committed
Adding the abstract class
2 parents cc0979e + fdce5b0 commit 72ebb8d

File tree

11 files changed

+96
-43
lines changed

11 files changed

+96
-43
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- name: Setup environment
2525
run: pip install -e .[test]
2626
- name: Run doctests
27-
run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/
27+
run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
2828
- name: Run extra tests
2929
run: pytest docs/source/.codespell/test_notebook_to_markdown.py
3030
- name: Run tests

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ repos:
2525
exclude: &exclude_pattern 'iv_weak_instruments.ipynb'
2626
args: ["--maxkb=1500"]
2727
- repo: https://github.com/astral-sh/ruff-pre-commit
28-
rev: v0.11.13
28+
rev: v0.12.1
2929
hooks:
3030
# Run the linter
3131
- id: ruff

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.PHONY: init lint check_lint test uml html cleandocs
1+
.PHONY: init lint check_lint test uml html cleandocs doctest
22

33
init:
44
python -m pip install -e . --no-deps
@@ -13,7 +13,7 @@ check_lint:
1313
interrogate .
1414

1515
doctest:
16-
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/
16+
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
1717

1818
test:
1919
pytest

causalpy/experiments/interrupted_time_series.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Interrupted Time Series Analysis
1616
"""
1717

18+
from abc import ABC, abstractmethod
1819
from typing import List, Union
1920

2021
import arviz as az
@@ -34,7 +35,35 @@
3435
LEGEND_FONT_SIZE = 12
3536

3637

37-
class UnknownTreatmentTimeHandler:
38+
class TreatmentTimeHandler(ABC):
39+
@abstractmethod
40+
def data_preprocessing(self, data, treatment_time, model):
41+
pass
42+
43+
@abstractmethod
44+
def data_postprocessing(
45+
self, model, data, idata, treatment_time, y, X, pre_y, pre_X
46+
):
47+
pass
48+
49+
@abstractmethod
50+
def plot_intervention_line(
51+
self, ax, handles, labels, datapre, datapost, pre_pred, post_pred
52+
):
53+
pass
54+
55+
@abstractmethod
56+
def plot_impact_cumulative(self, ax, datapre, datapost, post_impact_cumulative):
57+
pass
58+
59+
def plot_treated_counterfactual(
60+
self, ax, handles, labels, datapre, datapost, pre_pred, post_pred
61+
):
62+
"""Optional: override if needed"""
63+
pass
64+
65+
66+
class UnknownTreatmentTimeHandler(TreatmentTimeHandler):
3867
"""
3968
A utility class for managing data preprocessing, postprocessing,
4069
and plotting steps for models that infer unknown treatment times.
@@ -185,7 +214,7 @@ def plot_intervention_line(
185214
)
186215

187216

188-
class KnownTreatmentTimeHandler:
217+
class KnownTreatmentTimeHandler(TreatmentTimeHandler):
189218
"""
190219
Handles data preprocessing, postprocessing, and plotting logic for models
191220
where the treatment time is known in advance.
@@ -247,14 +276,6 @@ def data_postprocessing(
247276

248277
return res
249278

250-
def plot_treated_counterfactual(
251-
self, sax, handles, labels, datapre, datapost, pre_pred, post_pred
252-
):
253-
"""
254-
Placeholder method to maintain interface compatibility with UnknownTreatmentTimeHandler.
255-
"""
256-
pass
257-
258279
def plot_impact_cumulative(self, ax, datapre, datapost, post_impact_cumulative):
259280
"""
260281
Plot the cumulative causal impact for the post-intervention period.

causalpy/experiments/prepostnegd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class PrePostNEGD(BaseExperiment):
7272
... }
7373
... ),
7474
... )
75-
>>> result.summary(round_to=1) # doctest: +NUMBER
75+
>>> result.summary(round_to=1) # doctest: +SKIP
7676
==================Pretest/posttest Nonequivalent Group Design===================
7777
Formula: post ~ 1 + C(group) + pre
7878
<BLANKLINE>

causalpy/tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,24 @@
2020

2121
import numpy as np
2222
import pytest
23+
from pymc.testing import mock_sample, mock_sample_setup_and_teardown
2324

2425

2526
@pytest.fixture(scope="session")
2627
def rng() -> np.random.Generator:
2728
"""Random number generator that can persist through a pytest session"""
2829
seed: int = sum(map(ord, "causalpy"))
2930
return np.random.default_rng(seed=seed)
31+
32+
33+
mock_pymc_sample = pytest.fixture(mock_sample_setup_and_teardown, scope="session")
34+
35+
36+
@pytest.fixture(autouse=True)
37+
def mock_sample_for_doctest(request):
38+
if not request.config.getoption("--doctest-modules", default=False):
39+
return
40+
41+
import pymc as pm
42+
43+
pm.sample = mock_sample

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@pytest.mark.integration
26-
def test_did():
26+
def test_did(mock_pymc_sample):
2727
"""
2828
Test Difference in Differences (DID) PyMC experiment.
2929
@@ -57,7 +57,7 @@ def test_did():
5757

5858

5959
@pytest.mark.integration
60-
def test_did_banks_simple():
60+
def test_did_banks_simple(mock_pymc_sample):
6161
"""
6262
Test simple Differences In Differences Experiment on the 'banks' data set.
6363
@@ -113,7 +113,7 @@ def test_did_banks_simple():
113113

114114

115115
@pytest.mark.integration
116-
def test_did_banks_multi():
116+
def test_did_banks_multi(mock_pymc_sample):
117117
"""
118118
Test multiple regression Differences In Differences Experiment on the 'banks'
119119
data set.
@@ -168,7 +168,7 @@ def test_did_banks_multi():
168168

169169

170170
@pytest.mark.integration
171-
def test_rd():
171+
def test_rd(mock_pymc_sample):
172172
"""
173173
Test Regression Discontinuity experiment.
174174
@@ -199,7 +199,7 @@ def test_rd():
199199

200200

201201
@pytest.mark.integration
202-
def test_rd_bandwidth():
202+
def test_rd_bandwidth(mock_pymc_sample):
203203
"""
204204
Test Regression Discontinuity experiment with bandwidth parameter.
205205
@@ -229,7 +229,7 @@ def test_rd_bandwidth():
229229

230230

231231
@pytest.mark.integration
232-
def test_rd_drinking():
232+
def test_rd_drinking(mock_pymc_sample):
233233
"""
234234
Test Regression Discontinuity experiment on drinking age data.
235235
@@ -289,7 +289,7 @@ def reg_kink_function(x, beta, kink):
289289

290290

291291
@pytest.mark.integration
292-
def test_rkink():
292+
def test_rkink(mock_pymc_sample):
293293
"""
294294
Test Regression Kink design.
295295
@@ -320,7 +320,7 @@ def test_rkink():
320320

321321

322322
@pytest.mark.integration
323-
def test_rkink_bandwidth():
323+
def test_rkink_bandwidth(mock_pymc_sample):
324324
"""
325325
Test Regression Kink experiment with bandwidth parameter.
326326
@@ -350,7 +350,7 @@ def test_rkink_bandwidth():
350350

351351

352352
@pytest.mark.integration
353-
def test_its():
353+
def test_its(mock_pymc_sample):
354354
"""
355355
Test Interrupted Time-Series experiment.
356356
@@ -461,7 +461,7 @@ def test_its_no_treatment_time():
461461

462462

463463
@pytest.mark.integration
464-
def test_its_covid():
464+
def test_its_covid(mock_pymc_sample):
465465
"""
466466
Test Interrupted Time-Series experiment on COVID data.
467467
@@ -515,7 +515,7 @@ def test_its_covid():
515515

516516

517517
@pytest.mark.integration
518-
def test_sc():
518+
def test_sc(mock_pymc_sample):
519519
"""
520520
Test Synthetic Control experiment.
521521
@@ -574,7 +574,7 @@ def test_sc():
574574

575575

576576
@pytest.mark.integration
577-
def test_sc_brexit():
577+
def test_sc_brexit(mock_pymc_sample):
578578
"""
579579
Test Synthetic Control experiment on Brexit data.
580580
@@ -637,7 +637,7 @@ def test_sc_brexit():
637637

638638

639639
@pytest.mark.integration
640-
def test_ancova():
640+
def test_ancova(mock_pymc_sample):
641641
"""
642642
Test Pre-PostNEGD experiment on anova1 data.
643643
@@ -669,7 +669,7 @@ def test_ancova():
669669

670670

671671
@pytest.mark.integration
672-
def test_geolift1():
672+
def test_geolift1(mock_pymc_sample):
673673
"""
674674
Test Synthetic Control experiment on geo lift data.
675675
@@ -706,7 +706,7 @@ def test_geolift1():
706706

707707

708708
@pytest.mark.integration
709-
def test_iv_reg():
709+
def test_iv_reg(mock_pymc_sample):
710710
df = cp.load_data("risk")
711711
instruments_formula = "risk ~ 1 + logmort0"
712712
formula = "loggdp ~ 1 + risk"
@@ -734,7 +734,7 @@ def test_iv_reg():
734734

735735

736736
@pytest.mark.integration
737-
def test_inverse_prop():
737+
def test_inverse_prop(mock_pymc_sample):
738738
"""Test the InversePropensityWeighting class."""
739739
df = cp.load_data("nhefs")
740740
sample_kwargs = {

causalpy/tests/test_pymc_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_fit_build_not_implemented(self):
9393
argvalues=[None, {"a": 1}],
9494
ids=["None-coords", "dict-coords"],
9595
)
96-
def test_fit_predict(self, coords, rng) -> None:
96+
def test_fit_predict(self, coords, rng, mock_pymc_sample) -> None:
9797
"""
9898
Test fit and predict methods on MyToyModel.
9999
@@ -122,7 +122,7 @@ def test_fit_predict(self, coords, rng) -> None:
122122
assert isinstance(predictions, az.InferenceData)
123123

124124

125-
def test_idata_property():
125+
def test_idata_property(mock_pymc_sample):
126126
"""Test that we can access the idata property of the model"""
127127
df = cp.load_data("did")
128128
result = cp.DifferenceInDifferences(
@@ -140,7 +140,7 @@ def test_idata_property():
140140

141141

142142
@pytest.mark.parametrize("seed", seeds)
143-
def test_result_reproducibility(seed):
143+
def test_result_reproducibility(seed, mock_pymc_sample):
144144
"""Test that we can reproduce the results from the model. We could in theory test
145145
this with all the model and experiment types, but what is being targeted is
146146
the PyMCModel.fit method, so we should be safe testing with just one model. Here

codecov.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
coverage:
2+
ignore:
3+
- "*/conftest.py"
4+
- "*/tests/conftest.py"
25
status:
36
project:
47
default:

docs/source/_static/interrogate_badge.svg

Lines changed: 4 additions & 4 deletions
Loading

0 commit comments

Comments
 (0)