Skip to content

Commit 1257e87

Browse files
authored
Merge pull request #130 from pymc-labs/idata
access idata through new idata property + show property in autogenerated docs
2 parents d98952c + 120f671 commit 1257e87

File tree

5 files changed

+925
-974
lines changed

5 files changed

+925
-974
lines changed

causalpy/pymc_experiments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def idata(self):
3232
def print_coefficients(self):
3333
"""Prints the model coefficients"""
3434
print("Model coefficients:")
35-
coeffs = az.extract(self.prediction_model.idata.posterior, var_names="beta")
35+
coeffs = az.extract(self.idata.posterior, var_names="beta")
3636
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of
3737
# the stats despite variable names of different lengths
3838
for name in self.labels:
@@ -697,7 +697,7 @@ def __init__(
697697
self.pred_treated = self.prediction_model.predict(X=np.asarray(new_x))
698698

699699
# Evaluate causal impact as equal to the trestment effect
700-
self.causal_impact = self.prediction_model.idata.posterior["beta"].sel(
700+
self.causal_impact = self.idata.posterior["beta"].sel(
701701
{"coeffs": self._get_treatment_effect_coeff()}
702702
)
703703

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 16 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,8 @@ def test_did():
2020
)
2121
assert isinstance(df, pd.DataFrame)
2222
assert isinstance(result, cp.pymc_experiments.DifferenceInDifferences)
23-
assert (
24-
len(result.prediction_model.idata.posterior.coords["chain"])
25-
== sample_kwargs["chains"]
26-
)
27-
assert (
28-
len(result.prediction_model.idata.posterior.coords["draw"])
29-
== sample_kwargs["draws"]
30-
)
23+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
24+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
3125

3226

3327
@pytest.mark.integration
@@ -61,14 +55,8 @@ def test_did_banks():
6155
)
6256
assert isinstance(df, pd.DataFrame)
6357
assert isinstance(result, cp.pymc_experiments.DifferenceInDifferences)
64-
assert (
65-
len(result.prediction_model.idata.posterior.coords["chain"])
66-
== sample_kwargs["chains"]
67-
)
68-
assert (
69-
len(result.prediction_model.idata.posterior.coords["draw"])
70-
== sample_kwargs["draws"]
71-
)
58+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
59+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
7260

7361

7462
@pytest.mark.integration
@@ -82,14 +70,8 @@ def test_rd():
8270
)
8371
assert isinstance(df, pd.DataFrame)
8472
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
85-
assert (
86-
len(result.prediction_model.idata.posterior.coords["chain"])
87-
== sample_kwargs["chains"]
88-
)
89-
assert (
90-
len(result.prediction_model.idata.posterior.coords["draw"])
91-
== sample_kwargs["draws"]
92-
)
73+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
74+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
9375

9476

9577
@pytest.mark.integration
@@ -108,14 +90,8 @@ def test_rd_drinking():
10890
)
10991
assert isinstance(df, pd.DataFrame)
11092
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
111-
assert (
112-
len(result.prediction_model.idata.posterior.coords["chain"])
113-
== sample_kwargs["chains"]
114-
)
115-
assert (
116-
len(result.prediction_model.idata.posterior.coords["draw"])
117-
== sample_kwargs["draws"]
118-
)
93+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
94+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
11995

12096

12197
@pytest.mark.integration
@@ -132,14 +108,8 @@ def test_its():
132108
)
133109
assert isinstance(df, pd.DataFrame)
134110
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
135-
assert (
136-
len(result.prediction_model.idata.posterior.coords["chain"])
137-
== sample_kwargs["chains"]
138-
)
139-
assert (
140-
len(result.prediction_model.idata.posterior.coords["draw"])
141-
== sample_kwargs["draws"]
142-
)
111+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
112+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
143113

144114

145115
@pytest.mark.integration
@@ -156,14 +126,8 @@ def test_its_covid():
156126
)
157127
assert isinstance(df, pd.DataFrame)
158128
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
159-
assert (
160-
len(result.prediction_model.idata.posterior.coords["chain"])
161-
== sample_kwargs["chains"]
162-
)
163-
assert (
164-
len(result.prediction_model.idata.posterior.coords["draw"])
165-
== sample_kwargs["draws"]
166-
)
129+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
130+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
167131

168132

169133
@pytest.mark.integration
@@ -178,14 +142,8 @@ def test_sc():
178142
)
179143
assert isinstance(df, pd.DataFrame)
180144
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
181-
assert (
182-
len(result.prediction_model.idata.posterior.coords["chain"])
183-
== sample_kwargs["chains"]
184-
)
185-
assert (
186-
len(result.prediction_model.idata.posterior.coords["draw"])
187-
== sample_kwargs["draws"]
188-
)
145+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
146+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
189147

190148

191149
@pytest.mark.integration
@@ -210,14 +168,8 @@ def test_sc_brexit():
210168
)
211169
assert isinstance(df, pd.DataFrame)
212170
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
213-
assert (
214-
len(result.prediction_model.idata.posterior.coords["chain"])
215-
== sample_kwargs["chains"]
216-
)
217-
assert (
218-
len(result.prediction_model.idata.posterior.coords["draw"])
219-
== sample_kwargs["draws"]
220-
)
171+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
172+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
221173

222174

223175
@pytest.mark.integration

docs/api_pymc_experiments.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
.. automodule:: causalpy.pymc_experiments
88
:members:
99
:undoc-members:
10+
:inherited-members:

docs/notebooks/rd_pymc_drinking.ipynb

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -446,9 +446,9 @@
446446
"source": [
447447
"fig, ax = plt.subplots(1, 2, figsize=(10, 3))\n",
448448
"\n",
449-
"az.plot_forest(result.prediction_model.idata.posterior, var_names=\"beta\", ax=ax[0])\n",
449+
"az.plot_forest(result.idata.posterior, var_names=\"beta\", ax=ax[0])\n",
450450
"az.plot_posterior(\n",
451-
" result.prediction_model.idata.posterior.beta.sel(coeffs=\"treated[T.True]\"),\n",
451+
" result.idata.posterior.beta.sel(coeffs=\"treated[T.True]\"),\n",
452452
" round_to=3,\n",
453453
" ax=ax[1],\n",
454454
")\n",
@@ -792,9 +792,7 @@
792792
}
793793
],
794794
"source": [
795-
"az.plot_forest(\n",
796-
" result2.prediction_model.idata.posterior, var_names=\"beta\", figsize=(10, 3)\n",
797-
");"
795+
"az.plot_forest(result2.idata.posterior, var_names=\"beta\", figsize=(10, 3));"
798796
]
799797
},
800798
{
@@ -1129,9 +1127,9 @@
11291127
"source": [
11301128
"fig, ax = plt.subplots(1, 2, figsize=(10, 3))\n",
11311129
"\n",
1132-
"az.plot_forest(result3.prediction_model.idata.posterior, var_names=\"beta\", ax=ax[0])\n",
1130+
"az.plot_forest(result3.idata.posterior, var_names=\"beta\", ax=ax[0])\n",
11331131
"az.plot_posterior(\n",
1134-
" result3.prediction_model.idata.posterior.beta.sel(coeffs=\"treated[T.True]\"),\n",
1132+
" result3.idata.posterior.beta.sel(coeffs=\"treated[T.True]\"),\n",
11351133
" round_to=3,\n",
11361134
" ax=ax[1],\n",
11371135
")\n",

docs/notebooks/sc_pymc_brexit.ipynb

Lines changed: 901 additions & 901 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)