Skip to content

Commit a417ce6

Browse files
authored
Merge branch 'main' into isort
2 parents 04615fc + c12d11a commit a417ce6

File tree

8 files changed

+349
-70
lines changed

8 files changed

+349
-70
lines changed

causalpy/data/ancova_generated.csv

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
group,pre,post
2+
0,8.489252455317661,7.824477455926374
3+
1,12.419853815210045,14.796265003859528
4+
0,11.131001267370635,10.693494775255136
5+
0,10.503789109304071,10.532152654532391
6+
0,10.599760912049257,9.731499501477701
7+
0,9.20585567484764,8.631473651277016
8+
1,12.981286194970231,14.858409910487993
9+
1,11.235838962316713,12.67802892339558
10+
1,11.052268558744334,11.981392508857457
11+
0,11.032876418087163,10.735915114198521
12+
1,10.797986215792887,12.602345899059923
13+
0,9.53432237712092,9.7465648655858
14+
0,8.777175028501961,9.185299260280473
15+
0,10.483186257178556,9.779337261557618
16+
0,8.18693217372839,8.827809032010238
17+
0,9.888550758723023,10.06543301874182
18+
1,12.025545572764294,13.433202294587286
19+
1,10.72641157652606,12.536558873383596
20+
0,8.06197411122309,7.425364373941061
21+
0,11.040779722618025,11.723017219921568
22+
1,11.30603790090078,13.385634372409909
23+
0,10.447600888850474,11.154060930494943
24+
1,13.067201979470134,14.28235454219127
25+
1,12.17750632339526,14.14039333760221
26+
0,10.807168530049271,11.24970995818306
27+
1,11.184307019963528,13.14627760120039
28+
0,9.586159304495181,9.826129608943418
29+
1,11.755038401678656,13.84496736394065
30+
1,10.625903042225632,12.689955968858188
31+
1,12.257685256277034,14.408471522629108
32+
0,8.896622773357983,8.273050312066276
33+
1,12.214940002590598,14.454534799661321
34+
1,11.835561812225908,14.247528917945397
35+
0,9.433030539280715,8.574508097683603
36+
0,10.686248326783621,11.040360143424008
37+
0,9.85838654577244,10.302354190227469
38+
1,9.608614059642928,12.299587472723434
39+
1,11.421790910329685,12.658228985028742
40+
0,10.669638134729697,10.358722385594032
41+
1,12.192272071263583,14.117735405267519
42+
1,13.038834337670504,15.514850208000231
43+
1,12.368314676020525,14.518772266846883
44+
0,10.617557026965056,11.016801339134387
45+
0,10.626532269883139,10.892646857034633
46+
0,10.300614333902987,10.330920058552755
47+
1,10.102849381841654,12.230359457045822
48+
1,11.008791245054956,12.733090909633706
49+
1,12.574573335103288,14.657540816061372
50+
1,13.098366871280593,16.027890469131716
51+
0,10.13314185203344,9.772772883269145
52+
1,12.444696778109526,15.429437827915937
53+
0,11.183433582379267,11.823413215893098
54+
0,8.542795699329139,8.295219192316004
55+
1,11.527270787720067,13.365372906900005
56+
0,9.879649326756322,10.36380907877101
57+
0,10.683112840950638,11.148616607826034
58+
1,11.973614356210454,13.689535712453113
59+
1,12.21887877910744,14.84175871850897
60+
1,11.055525478539108,13.066599895661387
61+
0,8.453381640575538,8.111537419587666
62+
0,9.870089190698051,10.244954780256606
63+
1,12.777355697190721,15.264886253657998
64+
1,12.317011998370008,13.97135908558011
65+
0,7.492808733899154,6.523860954287915
66+
1,12.368024481821628,14.53041599596784
67+
0,9.36416007790942,8.864468005622566
68+
0,10.619712036376995,11.492278368396999
69+
0,10.994893683611522,10.713579618741603
70+
1,10.630885220476767,12.932536617478073
71+
0,10.18920300527993,10.175662908120165
72+
0,11.549276043566842,11.632633532005107
73+
1,10.520615328912166,11.684547910255223
74+
1,11.645072905040799,13.45623886750313
75+
0,10.45875937643886,10.625967437146144
76+
1,10.060866941403077,11.62348553911652
77+
0,9.124687051984628,9.009220889339772
78+
1,11.498360775319536,12.746703456629843
79+
1,11.006238229752102,12.48465079862449
80+
0,10.722870069639583,10.757471649322138
81+
1,12.398698791433201,14.771181680989793
82+
0,9.770691699718377,8.685202159817438
83+
0,11.985158572294631,12.641466377478256
84+
0,8.594048648834946,9.116827003382555
85+
1,10.98177675963099,13.342589713080896
86+
0,9.648471015491806,8.337936823299579
87+
1,11.416965761785413,13.716362069518569
88+
0,11.703769282145567,11.56360776974179
89+
1,13.654731090563079,16.322820102174276
90+
1,10.50982748639981,12.363066668386029
91+
1,12.184260698273881,14.220574568667647
92+
0,9.117870231724542,8.517782751200494
93+
0,10.899705622412764,10.131533553924468
94+
0,10.833022280927885,10.539488123705226
95+
0,9.692716357998181,9.355244323730048
96+
0,11.100332278658298,10.768836085743134
97+
1,12.49699625130767,14.642322214070033
98+
1,11.726863918384433,14.10163077915994
99+
1,12.438566140714228,14.771487851461018
100+
1,10.90730329459493,13.097411299715988
101+
1,12.278901475190365,13.552951756078636
102+
1,12.265156438310129,14.67737747755582
103+
1,13.691286929642066,16.627878023869545
104+
0,10.1951357226745,9.45160547840346
105+
1,11.000528354049498,13.790898849686132
106+
0,8.893284860677767,8.948932364973224
107+
1,11.017152087022655,13.405675844916233
108+
1,12.845189458215584,14.695396946894284
109+
1,11.930183856716571,13.756428807657489
110+
1,12.192407095379489,14.677900510965449
111+
1,11.884725802514529,13.973533220264589
112+
0,9.204366676101822,8.238565832347659
113+
1,11.924056566110686,13.41118690511502
114+
0,9.733192805687251,9.919719653574846
115+
0,8.460665781880559,8.6999387553229
116+
0,9.020778090418924,10.03288785482543
117+
1,11.447580675708176,13.432136586882441
118+
0,8.580959268916429,8.291049437131987
119+
0,9.971720891948188,10.282006486576645
120+
1,11.032668688977394,13.40967978250074
121+
1,12.416818567306535,13.945708586691158
122+
0,10.182172058583925,10.09331090683779
123+
1,13.163171546449826,15.424084801147393
124+
0,10.42095673336904,10.409510045922122
125+
1,13.975016510347821,15.216115879129374
126+
1,11.49732187912045,13.846578205619414
127+
1,13.440602163936193,14.886634015814483
128+
1,11.930700474328532,13.780182905112268
129+
1,11.618453058418556,14.126668329679438
130+
0,9.751172804679456,9.788918717778754
131+
0,9.902487911531106,9.632698810164237
132+
0,9.398150334179403,9.438726118848791
133+
0,10.55877667049362,11.501197656653362
134+
1,12.225712172022552,13.732858661492681
135+
1,11.199763117297856,12.993311863721898
136+
0,8.503962964614457,8.787744515593845
137+
1,12.382539639759282,13.869004229739318
138+
1,11.264860317136874,13.471615470374822
139+
0,10.118365138760117,10.342860170948581
140+
1,13.015502974296197,14.267548895909114
141+
1,13.989377837690588,16.190181459644702
142+
0,11.127158566720949,10.677570054928793
143+
1,10.646242343119749,12.11896339648179
144+
0,10.164330881912127,10.89793960809276
145+
1,12.452471680236114,14.41318894029163
146+
0,8.929478965826593,9.399685573592025
147+
0,10.887105276402668,11.21233752513777
148+
1,12.788793113841145,14.611853869573453
149+
0,10.054823274810728,10.84174421439569
150+
0,10.450628491489136,9.801219714604168
151+
0,10.071076168392345,9.988065245821273
152+
1,14.65619257260125,16.21037548076697
153+
1,10.752644418784662,12.757671745957854
154+
1,12.367512031476597,13.990647903061141
155+
0,10.478471646252299,11.245810158940271
156+
1,11.932590274724973,14.08356956327396
157+
0,11.347258305364148,11.348332988923774
158+
1,11.503201853400503,14.214057326884314
159+
1,12.518869965621,14.137961853644292
160+
0,10.626820133170224,11.188432926449703
161+
1,11.538511555990787,14.002453123582423
162+
0,8.306059310792339,7.727924552988854
163+
1,12.442128762602495,13.810433501760814
164+
0,9.537521007599878,9.334443627143134
165+
1,12.967972843155685,14.36171721318257
166+
0,9.757145635663939,10.003194740646963
167+
1,11.998471541228344,14.73831178011914
168+
1,10.292481711678189,12.737319514475427
169+
1,12.314756886184675,13.116844779993592
170+
0,10.43849975260881,10.346747225019238
171+
1,12.919621939942418,14.681619928311472
172+
1,11.677649306775082,13.487347797970331
173+
0,9.941965032199327,10.269646730835463
174+
1,12.647579585136072,15.046848419329748
175+
0,9.628431468232515,9.735896298922077
176+
0,10.374851996408715,10.583207722350805
177+
0,9.876906469705084,9.830578613616709
178+
0,11.251897875890322,10.639444030967656
179+
1,12.349030894865676,13.492613886511528
180+
1,12.267519542091463,13.059608239132452
181+
1,11.583461058198159,13.230230139069306
182+
0,10.444945878227408,9.922913907294015
183+
1,13.374851599228663,15.856004353433418
184+
1,13.551659152951943,16.30466714176413
185+
1,10.52815952038451,12.209413797528292
186+
1,10.927962912150592,13.376284546621207
187+
0,9.629546929224926,10.112458762424659
188+
0,10.822862686692964,10.323347958244627
189+
1,11.77033081064472,14.573590716769981
190+
0,8.493079379186764,8.209446538006098
191+
1,13.440558253502417,16.094839169420357
192+
1,9.70734572727937,11.653537179686879
193+
1,11.790148012749617,12.855907075040314
194+
1,11.013447374072008,12.774959683965271
195+
1,11.946928386480945,13.715643446567078
196+
1,10.649070940308492,12.60415030897177
197+
0,8.567264509757244,9.237374661074682
198+
0,11.4829954066837,11.979219164829882
199+
0,10.04930108326449,10.7058776385638
200+
0,11.136769789873284,11.40344282008195
201+
0,9.599327203689894,9.927868251548668

causalpy/data/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"its simple": {"filename": "its_simple.csv"},
1515
"rd": {"filename": "regression_discontinuity.csv"},
1616
"sc": {"filename": "synthetic_control.csv"},
17+
"anova1": {"filename": "ancova_generated.csv"},
1718
}
1819

1920

causalpy/data/simulate_data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,13 @@ def impact(x):
210210
y = np.sin(x * 3) + impact(x) + norm.rvs(scale=0.1, size=N)
211211

212212
return pd.DataFrame({"x": x, "y": y, "treated": is_treated(x)})
213+
214+
215+
def generate_ancova_data(
216+
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
217+
):
218+
group = np.random.choice(2, size=N)
219+
pre = np.random.normal(loc=pre_treatment_means[group])
220+
post = pre + treatment_effect * group + np.random.normal(size=N) * sigma
221+
df = pd.DataFrame({"group": group, "pre": pre, "post": post})
222+
return df

causalpy/pymc_experiments.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ def __init__(self, prediction_model=None, **kwargs):
2424
if self.prediction_model is None:
2525
raise ValueError("fitting_model not set or passed.")
2626

27+
@property
28+
def idata(self):
29+
"""Access to the InferenceData object"""
30+
return self.prediction_model.idata
31+
2732
def print_coefficients(self):
2833
"""Prints the model coefficients"""
2934
print("Model coefficients:")

causalpy/tests/test_data_loading.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@
33

44
import causalpy as cp
55

6-
tests = ["banks", "brexit", "covid", "did", "drinking", "its", "its simple", "rd", "sc"]
6+
tests = [
7+
"banks",
8+
"brexit",
9+
"covid",
10+
"did",
11+
"drinking",
12+
"its",
13+
"its simple",
14+
"rd",
15+
"sc",
16+
"anova1",
17+
]
718

819

920
@pytest.mark.parametrize("dataset_name", tests)

causalpy/tests/test_pymc_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import pymc as pm
55
import pytest
66

7+
import causalpy as cp
78
from causalpy.pymc_models import ModelBuilder
89

10+
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
11+
912

1013
class MyToyModel(ModelBuilder):
1114
def build_model(self, X, y, coords):
@@ -67,3 +70,19 @@ def test_fit_predict(self, coords, rng) -> None:
6770
assert isinstance(score, pd.Series)
6871
assert score.shape == (2,)
6972
assert isinstance(predictions, az.InferenceData)
73+
74+
75+
def test_idata_property():
76+
"""Test that we can access the idata property of the model"""
77+
df = cp.load_data("did")
78+
result = cp.pymc_experiments.DifferenceInDifferences(
79+
df,
80+
formula="y ~ 1 + group + t + treated:group",
81+
time_variable_name="t",
82+
group_variable_name="group",
83+
treated=1,
84+
untreated=0,
85+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
86+
)
87+
assert hasattr(result, "idata")
88+
assert isinstance(result.idata, az.InferenceData)

docs/glossary.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Some of the definitions have been copied from (or inspired by) various resources
66
</div>
77

88

9-
**ANCOVA:** Analysis of covariance is a simple linear model, typically with one continuous predictor (the covariate) and a catgeorical variable (which may correspond to treatment or control group). In the context of this package, ANCOVA could be useful in pretest-postdest designs, either with or without random assignment.
9+
**ANCOVA:** Analysis of covariance is a simple linear model, typically with one continuous predictor (the covariate) and a catgeorical variable (which may correspond to treatment or control group). In the context of this package, ANCOVA could be useful in pre-post treatment designs, either with or without random assignment. This is similar to the approach of difference in differences, but only applicable with a single pre and post treatment measure.
1010

1111
**Average treatment effect (ATE):** The average treatement effect across all units.
1212

0 commit comments

Comments
 (0)