Skip to content

Commit 33eec8c

Browse files
committed
Merge branch 'main' into interrogate
2 parents 40da547 + fc0f7a8 commit 33eec8c

28 files changed

+7301
-1514
lines changed

.pre-commit-config.yaml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,8 @@ repos:
1010
exclude_types: [svg]
1111
- id: check-yaml
1212
- id: check-added-large-files
13-
- repo: https://github.com/asottile/seed-isort-config
14-
rev: v2.2.0
15-
hooks:
16-
- id: seed-isort-config
17-
- repo: https://github.com/pre-commit/mirrors-isort
18-
rev: v5.10.1
13+
- repo: https://github.com/pycqa/isort
14+
rev: 5.11.2
1915
hooks:
2016
- id: isort
2117
args: [--profile, black]
@@ -24,11 +20,19 @@ repos:
2420
rev: 22.10.0
2521
hooks:
2622
- id: black
27-
- id: black-jupyter
2823
- repo: https://github.com/pycqa/flake8
2924
rev: 3.9.2
3025
hooks:
3126
- id: flake8
27+
- repo: https://github.com/nbQA-dev/nbQA
28+
rev: 1.5.3
29+
hooks:
30+
- id: nbqa-black
31+
# additional_dependencies: [jupytext] # optional, only if you're using Jupytext
32+
- id: nbqa-pyupgrade
33+
args: ["--py37-plus"]
34+
- id: nbqa-isort
35+
args: ["--float-to-top"]
3236
- repo: https://github.com/econchick/interrogate
3337
rev: 1.5.0
3438
hooks:

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ check_lint:
1313
flake8 .
1414
isort --check-only .
1515
black --diff --check --fast .
16+
nbqa black --check .
17+
nbqa isort --check-only .
1618

1719
test:
1820
pip install -r requirements-test.txt

README.md

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,26 @@ This is appropriate when you have multiple units, one of which is treated. You b
9191

9292
> The data (treated and untreated units), pre-treatment model fit, and counterfactual (i.e. the synthetic control) are plotted (top). The causal impact is shown as a blue shaded region. The Bayesian analysis shows shaded Bayesian credible regions of the model fit and counterfactual. Also shown is the causal impact (middle) and cumulative causal impact (bottom).
9393
94+
### ANCOVA
95+
96+
This is appropriate for non-equivalent group designs when you have a single pre and post intervention measurement and have a treament and a control group.
97+
98+
| Group | pre | post |
99+
|------|---|-------|
100+
| 0 | $x_1$ | $y_1$ |
101+
| 0 | $x_2$ | $y_2$ |
102+
| 1 | $x_3$ | $y_3$ |
103+
| 1 | $x_4$ | $y_4$ |
104+
105+
| Frequentist | Bayesian |
106+
|--|--|
107+
| coming soon | ![](img/anova_pymc.svg) |
108+
109+
> The data from the control and treatment group are plotted, along with posterior predictive 94% credible intervals. The lower panel shows the estimated treatment effect.
110+
94111
### Difference in Differences
95112

96-
This is appropriate when you have a single pre and post intervention measurement and have a treament and a control group.
113+
This is appropriate for non-equivalent group designs when you have pre and post intervention measurement and have a treament and a control group. Unlike the ANCOVA approach, difference in differences is appropriate when there are multiple pre and/or post treatment measurements.
97114

98115
Data is expected to be in the following form. Shown are just two units - one in the treated group (`group=1`) and one in the untreated group (`group=0`), but there can of course be multiple units per group. This is panel data (also known as repeated measures) where each unit is measured at 2 time points.
99116

@@ -108,7 +125,7 @@ Data is expected to be in the following form. Shown are just two units - one in
108125
|--|--|
109126
| ![](img/difference_in_differences_skl.svg) | ![](img/difference_in_differences_pymc.svg) |
110127

111-
The data, model fit, and counterfactual are plotted. Frequentist model fits result in points estimates, but the Bayesian analysis results in posterior distributions, represented by the violin plots. The causal impact is the difference between the counterfactual prediction (treated group, post treatment) and the observed values for the treated group, post treatment.
128+
>The data, model fit, and counterfactual are plotted. Frequentist model fits result in points estimates, but the Bayesian analysis results in posterior distributions, represented by the violin plots. The causal impact is the difference between the counterfactual prediction (treated group, post treatment) and the observed values for the treated group, post treatment.
112129
113130
### Regression discontinuity designs
114131

causalpy/data/ancova_generated.csv

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
group,pre,post
2+
0,8.489252455317661,7.824477455926374
3+
1,12.419853815210045,14.796265003859528
4+
0,11.131001267370635,10.693494775255136
5+
0,10.503789109304071,10.532152654532391
6+
0,10.599760912049257,9.731499501477701
7+
0,9.20585567484764,8.631473651277016
8+
1,12.981286194970231,14.858409910487993
9+
1,11.235838962316713,12.67802892339558
10+
1,11.052268558744334,11.981392508857457
11+
0,11.032876418087163,10.735915114198521
12+
1,10.797986215792887,12.602345899059923
13+
0,9.53432237712092,9.7465648655858
14+
0,8.777175028501961,9.185299260280473
15+
0,10.483186257178556,9.779337261557618
16+
0,8.18693217372839,8.827809032010238
17+
0,9.888550758723023,10.06543301874182
18+
1,12.025545572764294,13.433202294587286
19+
1,10.72641157652606,12.536558873383596
20+
0,8.06197411122309,7.425364373941061
21+
0,11.040779722618025,11.723017219921568
22+
1,11.30603790090078,13.385634372409909
23+
0,10.447600888850474,11.154060930494943
24+
1,13.067201979470134,14.28235454219127
25+
1,12.17750632339526,14.14039333760221
26+
0,10.807168530049271,11.24970995818306
27+
1,11.184307019963528,13.14627760120039
28+
0,9.586159304495181,9.826129608943418
29+
1,11.755038401678656,13.84496736394065
30+
1,10.625903042225632,12.689955968858188
31+
1,12.257685256277034,14.408471522629108
32+
0,8.896622773357983,8.273050312066276
33+
1,12.214940002590598,14.454534799661321
34+
1,11.835561812225908,14.247528917945397
35+
0,9.433030539280715,8.574508097683603
36+
0,10.686248326783621,11.040360143424008
37+
0,9.85838654577244,10.302354190227469
38+
1,9.608614059642928,12.299587472723434
39+
1,11.421790910329685,12.658228985028742
40+
0,10.669638134729697,10.358722385594032
41+
1,12.192272071263583,14.117735405267519
42+
1,13.038834337670504,15.514850208000231
43+
1,12.368314676020525,14.518772266846883
44+
0,10.617557026965056,11.016801339134387
45+
0,10.626532269883139,10.892646857034633
46+
0,10.300614333902987,10.330920058552755
47+
1,10.102849381841654,12.230359457045822
48+
1,11.008791245054956,12.733090909633706
49+
1,12.574573335103288,14.657540816061372
50+
1,13.098366871280593,16.027890469131716
51+
0,10.13314185203344,9.772772883269145
52+
1,12.444696778109526,15.429437827915937
53+
0,11.183433582379267,11.823413215893098
54+
0,8.542795699329139,8.295219192316004
55+
1,11.527270787720067,13.365372906900005
56+
0,9.879649326756322,10.36380907877101
57+
0,10.683112840950638,11.148616607826034
58+
1,11.973614356210454,13.689535712453113
59+
1,12.21887877910744,14.84175871850897
60+
1,11.055525478539108,13.066599895661387
61+
0,8.453381640575538,8.111537419587666
62+
0,9.870089190698051,10.244954780256606
63+
1,12.777355697190721,15.264886253657998
64+
1,12.317011998370008,13.97135908558011
65+
0,7.492808733899154,6.523860954287915
66+
1,12.368024481821628,14.53041599596784
67+
0,9.36416007790942,8.864468005622566
68+
0,10.619712036376995,11.492278368396999
69+
0,10.994893683611522,10.713579618741603
70+
1,10.630885220476767,12.932536617478073
71+
0,10.18920300527993,10.175662908120165
72+
0,11.549276043566842,11.632633532005107
73+
1,10.520615328912166,11.684547910255223
74+
1,11.645072905040799,13.45623886750313
75+
0,10.45875937643886,10.625967437146144
76+
1,10.060866941403077,11.62348553911652
77+
0,9.124687051984628,9.009220889339772
78+
1,11.498360775319536,12.746703456629843
79+
1,11.006238229752102,12.48465079862449
80+
0,10.722870069639583,10.757471649322138
81+
1,12.398698791433201,14.771181680989793
82+
0,9.770691699718377,8.685202159817438
83+
0,11.985158572294631,12.641466377478256
84+
0,8.594048648834946,9.116827003382555
85+
1,10.98177675963099,13.342589713080896
86+
0,9.648471015491806,8.337936823299579
87+
1,11.416965761785413,13.716362069518569
88+
0,11.703769282145567,11.56360776974179
89+
1,13.654731090563079,16.322820102174276
90+
1,10.50982748639981,12.363066668386029
91+
1,12.184260698273881,14.220574568667647
92+
0,9.117870231724542,8.517782751200494
93+
0,10.899705622412764,10.131533553924468
94+
0,10.833022280927885,10.539488123705226
95+
0,9.692716357998181,9.355244323730048
96+
0,11.100332278658298,10.768836085743134
97+
1,12.49699625130767,14.642322214070033
98+
1,11.726863918384433,14.10163077915994
99+
1,12.438566140714228,14.771487851461018
100+
1,10.90730329459493,13.097411299715988
101+
1,12.278901475190365,13.552951756078636
102+
1,12.265156438310129,14.67737747755582
103+
1,13.691286929642066,16.627878023869545
104+
0,10.1951357226745,9.45160547840346
105+
1,11.000528354049498,13.790898849686132
106+
0,8.893284860677767,8.948932364973224
107+
1,11.017152087022655,13.405675844916233
108+
1,12.845189458215584,14.695396946894284
109+
1,11.930183856716571,13.756428807657489
110+
1,12.192407095379489,14.677900510965449
111+
1,11.884725802514529,13.973533220264589
112+
0,9.204366676101822,8.238565832347659
113+
1,11.924056566110686,13.41118690511502
114+
0,9.733192805687251,9.919719653574846
115+
0,8.460665781880559,8.6999387553229
116+
0,9.020778090418924,10.03288785482543
117+
1,11.447580675708176,13.432136586882441
118+
0,8.580959268916429,8.291049437131987
119+
0,9.971720891948188,10.282006486576645
120+
1,11.032668688977394,13.40967978250074
121+
1,12.416818567306535,13.945708586691158
122+
0,10.182172058583925,10.09331090683779
123+
1,13.163171546449826,15.424084801147393
124+
0,10.42095673336904,10.409510045922122
125+
1,13.975016510347821,15.216115879129374
126+
1,11.49732187912045,13.846578205619414
127+
1,13.440602163936193,14.886634015814483
128+
1,11.930700474328532,13.780182905112268
129+
1,11.618453058418556,14.126668329679438
130+
0,9.751172804679456,9.788918717778754
131+
0,9.902487911531106,9.632698810164237
132+
0,9.398150334179403,9.438726118848791
133+
0,10.55877667049362,11.501197656653362
134+
1,12.225712172022552,13.732858661492681
135+
1,11.199763117297856,12.993311863721898
136+
0,8.503962964614457,8.787744515593845
137+
1,12.382539639759282,13.869004229739318
138+
1,11.264860317136874,13.471615470374822
139+
0,10.118365138760117,10.342860170948581
140+
1,13.015502974296197,14.267548895909114
141+
1,13.989377837690588,16.190181459644702
142+
0,11.127158566720949,10.677570054928793
143+
1,10.646242343119749,12.11896339648179
144+
0,10.164330881912127,10.89793960809276
145+
1,12.452471680236114,14.41318894029163
146+
0,8.929478965826593,9.399685573592025
147+
0,10.887105276402668,11.21233752513777
148+
1,12.788793113841145,14.611853869573453
149+
0,10.054823274810728,10.84174421439569
150+
0,10.450628491489136,9.801219714604168
151+
0,10.071076168392345,9.988065245821273
152+
1,14.65619257260125,16.21037548076697
153+
1,10.752644418784662,12.757671745957854
154+
1,12.367512031476597,13.990647903061141
155+
0,10.478471646252299,11.245810158940271
156+
1,11.932590274724973,14.08356956327396
157+
0,11.347258305364148,11.348332988923774
158+
1,11.503201853400503,14.214057326884314
159+
1,12.518869965621,14.137961853644292
160+
0,10.626820133170224,11.188432926449703
161+
1,11.538511555990787,14.002453123582423
162+
0,8.306059310792339,7.727924552988854
163+
1,12.442128762602495,13.810433501760814
164+
0,9.537521007599878,9.334443627143134
165+
1,12.967972843155685,14.36171721318257
166+
0,9.757145635663939,10.003194740646963
167+
1,11.998471541228344,14.73831178011914
168+
1,10.292481711678189,12.737319514475427
169+
1,12.314756886184675,13.116844779993592
170+
0,10.43849975260881,10.346747225019238
171+
1,12.919621939942418,14.681619928311472
172+
1,11.677649306775082,13.487347797970331
173+
0,9.941965032199327,10.269646730835463
174+
1,12.647579585136072,15.046848419329748
175+
0,9.628431468232515,9.735896298922077
176+
0,10.374851996408715,10.583207722350805
177+
0,9.876906469705084,9.830578613616709
178+
0,11.251897875890322,10.639444030967656
179+
1,12.349030894865676,13.492613886511528
180+
1,12.267519542091463,13.059608239132452
181+
1,11.583461058198159,13.230230139069306
182+
0,10.444945878227408,9.922913907294015
183+
1,13.374851599228663,15.856004353433418
184+
1,13.551659152951943,16.30466714176413
185+
1,10.52815952038451,12.209413797528292
186+
1,10.927962912150592,13.376284546621207
187+
0,9.629546929224926,10.112458762424659
188+
0,10.822862686692964,10.323347958244627
189+
1,11.77033081064472,14.573590716769981
190+
0,8.493079379186764,8.209446538006098
191+
1,13.440558253502417,16.094839169420357
192+
1,9.70734572727937,11.653537179686879
193+
1,11.790148012749617,12.855907075040314
194+
1,11.013447374072008,12.774959683965271
195+
1,11.946928386480945,13.715643446567078
196+
1,10.649070940308492,12.60415030897177
197+
0,8.567264509757244,9.237374661074682
198+
0,11.4829954066837,11.979219164829882
199+
0,10.04930108326449,10.7058776385638
200+
0,11.136769789873284,11.40344282008195
201+
0,9.599327203689894,9.927868251548668

causalpy/data/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"its simple": {"filename": "its_simple.csv"},
1515
"rd": {"filename": "regression_discontinuity.csv"},
1616
"sc": {"filename": "synthetic_control.csv"},
17+
"anova1": {"filename": "ancova_generated.csv"},
1718
}
1819

1920

causalpy/data/simulate_data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,13 @@ def impact(x):
210210
y = np.sin(x * 3) + impact(x) + norm.rvs(scale=0.1, size=N)
211211

212212
return pd.DataFrame({"x": x, "y": y, "treated": is_treated(x)})
213+
214+
215+
def generate_ancova_data(
216+
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
217+
):
218+
group = np.random.choice(2, size=N)
219+
pre = np.random.normal(loc=pre_treatment_means[group])
220+
post = pre + treatment_effect * group + np.random.normal(size=N) * sigma
221+
df = pd.DataFrame({"group": group, "pre": pre, "post": post})
222+
return df

causalpy/pymc_experiments.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ def __init__(self, prediction_model=None, **kwargs):
2424
if self.prediction_model is None:
2525
raise ValueError("fitting_model not set or passed.")
2626

27+
@property
28+
def idata(self):
29+
"""Access to the InferenceData object"""
30+
return self.prediction_model.idata
31+
2732
def print_coefficients(self):
2833
"""Prints the model coefficients"""
2934
print("Model coefficients:")

causalpy/tests/test_data_loading.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@
33

44
import causalpy as cp
55

6-
tests = ["banks", "brexit", "covid", "did", "drinking", "its", "its simple", "rd", "sc"]
6+
tests = [
7+
"banks",
8+
"brexit",
9+
"covid",
10+
"did",
11+
"drinking",
12+
"its",
13+
"its simple",
14+
"rd",
15+
"sc",
16+
"anova1",
17+
]
718

819

920
@pytest.mark.parametrize("dataset_name", tests)

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,19 @@ def test_sc_brexit():
218218
len(result.prediction_model.idata.posterior.coords["draw"])
219219
== sample_kwargs["draws"]
220220
)
221+
222+
223+
@pytest.mark.integration
224+
def test_ancova():
225+
df = cp.load_data("anova1")
226+
result = cp.pymc_experiments.PrePostNEGD(
227+
df,
228+
formula="post ~ 1 + C(group) + pre",
229+
group_variable_name="group",
230+
pretreatment_variable_name="pre",
231+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
232+
)
233+
assert isinstance(df, pd.DataFrame)
234+
assert isinstance(result, cp.pymc_experiments.PrePostNEGD)
235+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
236+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]

causalpy/tests/test_pymc_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import pymc as pm
55
import pytest
66

7+
import causalpy as cp
78
from causalpy.pymc_models import ModelBuilder
89

10+
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
11+
912

1013
class MyToyModel(ModelBuilder):
1114
def build_model(self, X, y, coords):
@@ -67,3 +70,19 @@ def test_fit_predict(self, coords, rng) -> None:
6770
assert isinstance(score, pd.Series)
6871
assert score.shape == (2,)
6972
assert isinstance(predictions, az.InferenceData)
73+
74+
75+
def test_idata_property():
76+
"""Test that we can access the idata property of the model"""
77+
df = cp.load_data("did")
78+
result = cp.pymc_experiments.DifferenceInDifferences(
79+
df,
80+
formula="y ~ 1 + group + t + treated:group",
81+
time_variable_name="t",
82+
group_variable_name="group",
83+
treated=1,
84+
untreated=0,
85+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
86+
)
87+
assert hasattr(result, "idata")
88+
assert isinstance(result.idata, az.InferenceData)

0 commit comments

Comments
 (0)