Skip to content

Commit 047dcbf

Browse files
authored
Merge pull request #126 from pymc-labs/idata-property
allow direct access to idata via a property
2 parents 4cab835 + 3347266 commit 047dcbf

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

causalpy/pymc_experiments.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ def __init__(self, prediction_model=None, **kwargs):
2424
if self.prediction_model is None:
2525
raise ValueError("fitting_model not set or passed.")
2626

27+
@property
28+
def idata(self):
29+
"""Access to the InferenceData object"""
30+
return self.prediction_model.idata
31+
2732
def print_coefficients(self):
2833
"""Prints the model coefficients"""
2934
print("Model coefficients:")

causalpy/tests/test_pymc_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import pymc as pm
55
import pytest
66

7+
import causalpy as cp
78
from causalpy.pymc_models import ModelBuilder
89

10+
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
11+
912

1013
class MyToyModel(ModelBuilder):
1114
def build_model(self, X, y, coords):
@@ -67,3 +70,19 @@ def test_fit_predict(self, coords, rng) -> None:
6770
assert isinstance(score, pd.Series)
6871
assert score.shape == (2,)
6972
assert isinstance(predictions, az.InferenceData)
73+
74+
75+
def test_idata_property():
76+
"""Test that we can access the idata property of the model"""
77+
df = cp.load_data("did")
78+
result = cp.pymc_experiments.DifferenceInDifferences(
79+
df,
80+
formula="y ~ 1 + group + t + treated:group",
81+
time_variable_name="t",
82+
group_variable_name="group",
83+
treated=1,
84+
untreated=0,
85+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
86+
)
87+
assert hasattr(result, "idata")
88+
assert isinstance(result.idata, az.InferenceData)

0 commit comments

Comments
 (0)