Skip to content

Commit 1db0c40

Browse files
authored
Merge pull request #97 from pymc-labs/brexit-example
Add Brexit synthetic control example
2 parents 6d488fb + 692008f commit 1db0c40

File tree

6 files changed

+988
-2
lines changed

6 files changed

+988
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ build/
88
dist/
99
docs/_build/
1010
*.vscode
11+
.coverage
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
Time,Australia,Austria,Belgium,Canada,Denmark,Finland,France,Germany,Iceland,Italy,Japan,Luxemburg,Netherlands,New_Zealand,Norway,Portugal,Spain,Sweden,Switzerland,UK,US
2+
2008-07-01,3.8237,0.835645,0.97149,17.53314,4.67639,0.55785,5.21976,7.07351,5.44478,4.434069,5.192159,0.12152299999999999,1.706577,0.47764,7.82785,0.48438800000000004,2.6986,10.86125,1.5433882,4.79983,157.09562
3+
2008-10-01,3.80552,0.8169080000000001,0.95042,17.32984,4.56596,0.54588,5.14178,6.96078,5.74929,4.317894,5.065315,0.117046,1.695315,0.47741,7.85237,0.478131,2.86664,10.47702,1.4999767000000002,4.70358,153.66607
4+
2009-01-01,3.84048,0.8028360000000001,0.94117,16.93824,4.50096,0.51052,5.0545,6.63471,5.18157,4.201581,4.822338,0.11483600000000001,1.6343910000000001,0.47336,7.78753,0.466098,2.58309,10.3222,1.4765317000000002,4.61881,151.87475
5+
2009-04-01,3.86954,0.796545,0.94162,16.7534,4.41372,0.50829,5.05375,6.6453,5.16171,4.1882090000000005,4.915681,0.116259,1.634432,0.47916,7.71903,0.466515,2.72044,10.32867,1.485509,4.60431,151.61772
6+
2009-07-01,3.88115,0.799937,0.95352,16.82878,4.42898,0.51299,5.06237,6.68237,5.24132,4.212578,4.912656,0.118747,1.6409820000000002,0.48188,7.724,0.47063099999999997,2.6055,10.32328,1.5025060000000001,4.60722,152.16647
7+
2009-10-01,3.91028,0.8038230000000001,0.96117,17.02503,4.433,0.50903,5.09832,6.73155,5.22482,4.226248,4.974765,0.119302,1.6508660000000002,0.48805,7.72812,0.470856,2.7842,10.37107,1.5151385000000002,4.62152,153.79155
8+
2010-01-01,3.92716,0.80051,0.96615,17.23041,4.47128,0.51413,5.11625,6.78621,4.91128,4.237582,5.027625,0.121414,1.647748,0.49349,7.87891,0.47473099999999996,2.57598,10.64833,1.5258639999999999,4.6538,154.56059
9+
2010-04-01,3.95387,0.811277,0.97567,17.32057,4.50184,0.52833,5.14098,6.93903,5.10614,4.269259,5.086985,0.122733,1.654981,0.49609,7.79516,0.477354,2.73434,10.86674,1.5384847,4.70655,156.05628
10+
2010-07-01,3.98175,0.81855,0.97974,17.44332,4.57321,0.52616,5.17273,6.99577,5.00671,4.289668,5.178457,0.120794,1.6622929999999998,0.48887,7.60172,0.477754,2.61706,10.99898,1.5475226999999998,4.73954,157.26282
11+
2010-10-01,4.01663,0.828113,0.98469,17.63825,4.56292,0.53721,5.20808,7.05251,5.19581,4.313207,5.137305,0.121695,1.681046,0.48355,7.81512,0.47682699999999995,2.79971,11.17711,1.5593448,4.74526,158.07995
12+
2011-01-01,4.0046,0.8352289999999999,0.98943,17.77148,4.57286,0.54021,5.26173,7.18926,4.98014,4.333769,5.083540999999999,0.123155,1.690748,0.49175,7.82243,0.473676,2.58719,11.21731,1.5693595999999999,4.7672,157.69911
13+
2011-04-01,4.05656,0.837287,0.99148,17.8061,4.6187,0.53927,5.25703,7.19607,5.11759,4.335718,5.039066,0.121965,1.689237,0.49628,7.79009,0.471644,2.72481,11.25308,1.5767164000000002,4.77202,158.76839
14+
2011-07-01,4.11156,0.842913,0.99432,18.05176,4.56125,0.53993,5.28283,7.25962,5.20997,4.313057,5.16075,0.12346,1.6892429999999998,0.50209,7.91561,0.467962,2.59493,11.40015,1.5707735999999999,4.78726,158.70684
15+
2011-10-01,4.15576,0.84265,0.99722,18.19392,4.59853,0.54008,5.29302,7.23616,5.28546,4.27216,5.154163,0.123072,1.679051,0.5062,7.91969,0.46104300000000004,2.7307,11.24116,1.5725503,4.79335,160.48702
16+
2012-01-01,4.19499,0.8493639999999999,0.99936,18.20558,4.59533,0.53772,5.29395,7.2513,5.11087,4.225777,5.22599,0.123716,1.675556,0.50695,8.09152,0.45852699999999996,2.52313,11.26089,1.5888463000000002,4.82583,161.79968
17+
2012-04-01,4.22634,0.845857,0.99935,18.26496,4.59879,0.53203,5.28606,7.26643,5.18428,4.194572,5.178464,0.12370600000000001,1.67641,0.51008,8.1106,0.45234599999999997,2.62527,11.2776,1.587918,4.82299,162.53726
18+
2012-07-01,4.25166,0.844361,1.00178,18.28984,4.60311,0.52998,5.29839,7.28685,5.21788,4.172838,5.158104,0.125012,1.669214,0.513,7.99303,0.447237,2.5334,11.26257,1.5990306,4.88205,162.82151
19+
2012-10-01,4.27205,0.84392,1.00132,18.32766,4.59567,0.52957,5.29273,7.25432,5.29916,4.141455,5.154555,0.127251,1.6575339999999998,0.51894,8.05836,0.44016900000000003,2.62919,11.18393,1.5975844000000001,4.87,163.00035
20+
2013-01-01,4.29019,0.8424010000000001,0.99857,18.49206,4.62158,0.52503,5.29616,7.22122,5.33395,4.102825,5.226566,0.127256,1.662886,0.52179,8.0705,0.442,2.46846,11.33892,1.6044659,4.89236,164.41485
21+
2013-04-01,4.30869,0.843032,1.0044,18.59938,4.62501,0.52745,5.33164,7.29839,5.42017,4.103628,5.273670999999999,0.129726,1.6598810000000002,0.52008,8.1288,0.44535199999999997,2.59228,11.32686,1.6188392,4.92451,164.64402
22+
2013-07-01,4.34261,0.846905,1.00751,18.75096,4.65431,0.52929,5.32621,7.33924,5.4072,4.113084000000001,5.323485,0.12981399999999998,1.669891,0.52474,8.20853,0.44473199999999996,2.51199,11.37816,1.6313214,4.96229,165.94743
23+
2013-10-01,4.37749,0.85024,1.00971,18.94795,4.66367,0.52834,5.36163,7.36043,5.59834,4.1039829999999995,5.317018,0.128722,1.68049,0.52764,8.19398,0.44960199999999995,2.63075,11.45938,1.6367806,4.98839,167.1276
24+
2014-01-01,4.41048,0.8483289999999999,1.01405,18.97892,4.67548,0.52426,5.36183,7.43211,5.43661,4.11047,5.3616,0.131386,1.678897,0.53332,8.23666,0.446716,2.48197,11.55366,1.6470401000000001,5.03492,166.54247
25+
2014-04-01,4.43084,0.853993,1.01674,19.15226,4.67209,0.52511,5.36653,7.43211,5.43975,4.109159,5.264161,0.129623,1.688726,0.53587,8.29753,0.448202,2.60733,11.64891,1.6581819,5.07551,168.68109
26+
2014-07-01,4.45209,0.85319,1.02376,19.33594,4.7503,0.52714,5.39723,7.46994,5.63207,4.1133809999999995,5.267956999999999,0.13231600000000002,1.692948,0.5425,8.33241,0.448627,2.5478,11.73151,1.6673167999999998,5.10988,170.64616
27+
2014-10-01,4.46827,0.854363,1.02911,19.46974,4.76734,0.52588,5.39537,7.53046,5.61836,4.102336,5.29257,0.135719,1.7081110000000002,0.54982,8.43427,0.452256,2.68448,11.82275,1.6761612,5.13852,171.41235
28+
2015-01-01,4.50884,0.856751,1.03446,19.36275,4.79998,0.52179,5.4238,7.49188,5.51916,4.113187,5.375484,0.134703,1.718281,0.55348,8.41262,0.45517199999999997,2.57233,11.98956,1.6733119,5.17165,172.80647
29+
2015-04-01,4.5137,0.858388,1.04141,19.31005,4.82243,0.52973,5.42994,7.5471,5.86313,4.129486,5.382865,0.13561600000000001,1.723575,0.5585,8.45721,0.45652699999999996,2.72132,12.10879,1.6853151000000002,5.20984,173.80875
30+
2015-07-01,4.56084,0.863285,1.0434,19.37835,4.83698,0.52968,5.44213,7.58115,5.81588,4.1394,5.387195999999999,0.134972,1.729509,0.56733,8.55631,0.457046,2.66375,12.27554,1.6959884,5.23783,174.3708
31+
2015-10-01,4.58825,0.8643789999999999,1.04775,19.39286,4.84776,0.53265,5.45297,7.61443,5.9103,4.161275,5.378318,0.13584,1.729793,0.57376,8.48118,0.45923699999999995,2.8185,12.36629,1.7029607999999998,5.27344,174.62579
32+
2016-01-01,4.62773,0.873659,1.04835,19.49923,4.9036,0.53922,5.48707,7.67836,5.73509,4.172618,5.419015,0.139863,1.745688,0.57866,8.49252,0.46138,2.65844,12.35977,1.7101123999999999,5.29673,175.65465
33+
2016-04-01,4.658,0.872602,1.05411,19.40335,4.96993,0.54006,5.47082,7.71241,6.13539,4.181198,5.411239,0.14183,1.7497989999999999,0.58517,8.48983,0.46276300000000004,2.81788,12.36546,1.7190254999999999,5.32731,176.18581
34+
2016-07-01,4.66437,0.877404,1.05596,19.60344,5.0121,0.54616,5.48611,7.73208,6.24267,4.204893,5.42041,0.142868,1.7693329999999998,0.58918,8.46376,0.468196,2.75793,12.40342,1.7283486,5.35103,177.24489
35+
2016-10-01,4.71319,0.886402,1.06138,19.71351,5.04821,0.54784,5.51862,7.76007,6.45202,4.216942,5.427695,0.14357899999999998,1.78415,0.59054,8.57712,0.472559,2.90415,12.51088,1.7343463000000001,5.39059,178.1256
36+
2017-01-01,4.72423,0.8904139999999999,1.0683,19.92778,5.08124,0.55343,5.56005,7.85351,6.19735,4.2401,5.47353,0.14136100000000001,1.7935070000000002,0.59832,8.66243,0.47809199999999996,2.76574,12.57115,1.7372078,5.42562,178.96623
37+
2017-04-01,4.75627,0.895741,1.07088,20.13165,5.13804,0.56048,5.60345,7.92008,6.41612,4.25706,5.495453,0.143345,1.809282,0.60619,8.74589,0.48050800000000005,2.94157,12.73491,1.7424922,5.44248,179.96802
38+
2017-07-01,4.80272,0.9003669999999999,1.07019,20.21658,5.1192,0.56257,5.65092,7.97833,6.3905,4.27331,5.537455,0.145172,1.822236,0.61264,8.80582,0.48384699999999997,2.87429,12.85438,1.75342,5.46579,181.26226
39+
2017-10-01,4.82362,0.906802,1.07877,20.3213,5.15784,0.56621,5.68609,8.05096,6.59169,4.29624,5.5409180000000005,0.145756,1.8360329999999998,0.62027,8.75565,0.48784099999999997,3.03707,12.87352,1.7726275,5.48781,182.96685
40+
2018-01-01,4.86504,0.914896,1.08326,20.49916,5.18425,0.56768,5.69028,8.00425,6.62937,4.294731,5.552345,0.14615,1.844232,0.62654,8.82385,0.491389,2.86562,12.92577,1.7915179,5.5008,184.36262
41+
2018-04-01,4.90809,0.91826,1.08833,20.6601,5.21031,0.56749,5.7137,8.06099,6.75676,4.295742,5.573858,0.14558100000000002,1.856078,0.63435,8.85039,0.49522,3.04836,13.06163,1.8080886999999999,5.53087,185.90004
42+
2018-07-01,4.92894,0.9195169999999999,1.09254,20.80268,5.24168,0.56655,5.73572,7.99744,6.68631,4.300846,5.537012,0.147342,1.860633,0.63601,8.90615,0.497823,2.96887,12.95859,1.8039333,5.56581,186.79599
43+
2018-10-01,4.93524,0.928973,1.10307,20.87359,5.26786,0.56651,5.77073,8.06099,6.77453,4.313607,5.522476,0.148252,1.86831,0.64537,8.90915,0.500855,3.14974,13.12264,1.8081592000000002,5.58448,187.21281
44+
2019-01-01,4.95551,0.936745,1.10577,20.89251,5.25637,0.57061,5.8107,8.11867,6.90307,4.322133,5.5518469999999995,0.150063,1.8813479999999998,0.64928,8.88936,0.50527,2.99029,13.19624,1.8128046,5.62033,188.33195
45+
2019-04-01,4.98973,0.9325439999999999,1.11054,21.0995,5.3092,0.57579,5.8489,8.10506,6.87472,4.334204000000001,5.57382,0.15265299999999998,1.889184,0.65371,8.89098,0.508123,3.15507,13.27584,1.8202787,5.62779,189.82528
46+
2019-07-01,5.0306,0.934451,1.11798,21.16842,5.33348,0.57554,5.8492,8.11413,6.74077,4.334605,5.572678000000001,0.152008,1.89659,0.66126,8.91728,0.5104730000000001,3.05611,13.3004,1.8296510999999998,5.65362,191.12653
47+
2019-10-01,5.05036,0.932082,1.12513,21.23207,5.31726,0.57405,5.83324,8.13532,6.97884,4.299539,5.408475,0.151945,1.9040839999999999,0.66254,9.05169,0.514682,3.24228,13.34498,1.8385548,5.65109,192.0231
48+
2020-01-01,5.03496,0.9083310000000001,1.08966,20.77403,5.27728,0.57338,5.50817,8.01957,6.62288,4.045370999999999,5.438035999999999,0.149451,1.87627,0.65584,8.93389,0.491961,2.89305,13.32553,1.8093797,5.50835,189.51992
49+
2020-04-01,4.69275,0.80489,0.96287,18.48005,4.96031,0.53833,4.76258,7.25923,6.13008,3.533511,5.002934000000001,0.14042,1.728407,0.59528,8.5071,0.416939,2.51187,12.25111,1.6983431,4.43817,172.58205
50+
2020-07-01,4.85727,0.891973,1.0774,20.14029,5.27494,0.56453,5.63803,7.9129,6.23727,4.10219,5.276015,0.152086,1.83597,0.67859,8.89123,0.478115,2.82149,13.15607,1.8052762,5.2191,185.60774
51+
2020-10-01,5.01644,0.875193,1.07616,20.58185,5.28059,0.56862,5.58853,7.96207,6.56287,4.037671,5.362419,0.15393600000000002,1.8364179999999999,0.6717,8.95496,0.47943,2.99307,13.14522,1.8059775,5.29647,187.67778
52+
2021-01-01,5.1059,0.87101,1.09033,20.8068,5.27907,0.56794,5.59176,7.84519,6.42829,4.046204,5.343331999999999,0.157143,1.837799,0.68192,8.95172,0.46540400000000004,2.79732,13.35595,1.8016901,5.2344,190.55655
53+
2021-04-01,5.14784,0.908738,1.10899,20.642,5.40941,0.57661,5.64981,7.99649,6.74324,4.153119,5.367795,0.158468,1.90723,0.69828,9.00515,0.48568,3.00089,13.45755,1.8372548000000002,5.52521,193.6831
54+
2021-07-01,5.05413,0.939523,1.1323,20.91196,5.47941,0.5818,5.83899,8.05928,6.67206,4.266435,5.33882,0.159873,1.936301,0.66858,9.36779,0.49891599999999997,2.97498,13.72724,1.871626,5.577,194.78893
55+
2021-10-01,5.23725,0.932196,1.13713,21.24709,5.63489,0.58634,5.87112,8.05701,6.82061,4.294219,5.391243,0.161441,1.9498929999999999,0.68504,9.37914,0.507602,3.27744,13.88566,1.8745770000000002,5.64812,198.0629
56+
2022-01-01,5.27676,0.9466310000000001,1.1433,21.40751,5.60826,0.5893,5.85862,8.12132,6.89605,4.298887,5.391991999999999,0.163454,1.9587679999999998,0.6846,9.29927,0.520615,3.10012,13.7735,1.8835115,5.69182,197.27918

causalpy/data/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
DATASETS = {
88
"banks": {"filename": "banks.csv"},
9+
"brexit": {"filename": "GDP_in_dollars_billions.csv"},
910
"did": {"filename": "did.csv"},
1011
"drinking": {"filename": "drinking.csv"},
1112
"its": {"filename": "its.csv"},

causalpy/pymc_models.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Dict
2+
13
import arviz as az
24
import numpy as np
35
import pymc as pm
@@ -9,9 +11,10 @@ class ModelBuilder(pm.Model):
911
This is a wrapper around pm.Model to give scikit-learn like API
1012
"""
1113

12-
def __init__(self):
14+
def __init__(self, sample_kwargs: Dict = {}):
1315
super().__init__()
1416
self.idata = None
17+
self.sample_kwargs = sample_kwargs
1518

1619
def build_model(self, X, y, coords):
1720
raise NotImplementedError
@@ -26,7 +29,7 @@ def fit(self, X, y, coords):
2629
"""
2730
self.build_model(X, y, coords)
2831
with self.model:
29-
self.idata = pm.sample()
32+
self.idata = pm.sample(**self.sample_kwargs)
3033
self.idata.extend(pm.sample_prior_predictive())
3134
self.idata.extend(pm.sample_posterior_predictive(self.idata))
3235
return self.idata
@@ -69,7 +72,12 @@ def build_model(self, X, y, coords):
6972
n_predictors = X.shape[1]
7073
X = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
7174
y = pm.MutableData("y", y[:, 0], dims="obs_ind")
75+
# TODO: There we should allow user-specified priors here
7276
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
77+
# beta = pm.Dirichlet(
78+
# name="beta", a=(1 / n_predictors) * np.ones(n_predictors),
79+
# dims="coeffs"
80+
# )
7381
sigma = pm.HalfNormal("sigma", 1)
7482
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
7583
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")

docs/examples.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Synthetic Control
88
notebooks/sc_skl.ipynb
99
notebooks/sc2_pymc.ipynb
1010
notebooks/sc2_skl.ipynb
11+
notebooks/sc_pymc_brexit.ipynb
1112

1213

1314
Difference in Differences

docs/notebooks/sc_pymc_brexit.ipynb

Lines changed: 919 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)