Skip to content

Commit 3a11f25

Browse files
committed
#19, #34 get summary method to output model coefficients for all remaining experiment types
1 parent 9501440 commit 3a11f25

File tree

6 files changed

+164
-45
lines changed

6 files changed

+164
-45
lines changed

causalpy/pymc_experiments.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,25 @@ def __init__(self, prediction_model=None, **kwargs):
2424
if self.prediction_model is None:
2525
raise ValueError("fitting_model not set or passed.")
2626

27+
def print_coefficients(self):
28+
"""Prints the model coefficients"""
29+
print("Model coefficients:")
30+
coeffs = az.extract(self.prediction_model.idata.posterior, var_names="beta")
31+
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
32+
for name in self.labels:
33+
coeff_samples = coeffs.sel(coeffs=name)
34+
print(
35+
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
36+
)
37+
# add coeff for measurement std
38+
coeff_samples = az.extract(
39+
self.prediction_model.idata.posterior, var_names="sigma"
40+
)
41+
name = "sigma"
42+
print(
43+
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
44+
)
45+
2746

2847
class TimeSeriesExperiment(ExperimentalDesign):
2948
"""A class to analyse time series quasi-experiments"""
@@ -147,10 +166,20 @@ def plot(self):
147166

148167
return (fig, ax)
149168

169+
def summary(self):
170+
"""Print text output summarising the results"""
171+
172+
print(f"{self.expt_type:=^80}")
173+
print(f"Formula: {self.formula}")
174+
# TODO: extra experiment specific outputs here
175+
self.print_coefficients()
176+
150177

151178
class SyntheticControl(TimeSeriesExperiment):
152179
"""A wrapper around the TimeSeriesExperiment class"""
153180

181+
expt_type = "Synthetic Control"
182+
154183
def plot(self):
155184
"""Plot the results"""
156185
fig, ax = super().plot()
@@ -163,7 +192,7 @@ def plot(self):
163192
class InterruptedTimeSeries(TimeSeriesExperiment):
164193
"""A wrapper around the TimeSeriesExperiment class"""
165194

166-
pass
195+
expt_type = "Interrupted Time Series"
167196

168197

169198
class DifferenceInDifferences(ExperimentalDesign):
@@ -185,6 +214,7 @@ def __init__(
185214
):
186215
super().__init__(prediction_model=prediction_model, **kwargs)
187216
self.data = data
217+
self.expt_type = "Difference in Differences"
188218
self.formula = formula
189219
self.time_variable_name = time_variable_name
190220
y, X = dmatrices(formula, self.data)
@@ -324,6 +354,14 @@ def plot(self):
324354
ax.legend(fontsize=LEGEND_FONT_SIZE)
325355
return (fig, ax)
326356

357+
def summary(self):
358+
"""Print text output summarising the results"""
359+
360+
print(f"{self.expt_type:=^80}")
361+
print(f"Formula: {self.formula}")
362+
# TODO: extra experiment specific outputs here
363+
self.print_coefficients()
364+
327365

328366
class RegressionDiscontinuity(ExperimentalDesign):
329367
"""
@@ -456,19 +494,4 @@ def summary(self):
456494
print(
457495
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
458496
)
459-
print("Model coefficients:")
460-
coeffs = az.extract(self.prediction_model.idata.posterior, var_names="beta")
461-
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
462-
for name in self.labels:
463-
coeff_samples = coeffs.sel(coeffs=name)
464-
print(
465-
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
466-
)
467-
# add coeff for measurement std
468-
coeff_samples = az.extract(
469-
self.prediction_model.idata.posterior, var_names="sigma"
470-
)
471-
name = "sigma"
472-
print(
473-
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
474-
)
497+
self.print_coefficients()

causalpy/pymc_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def build_model(self, X, y, coords):
6666
n_predictors = X.shape[1]
6767
X = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
6868
y = pm.MutableData("y", y[:, 0], dims="obs_ind")
69-
beta = pm.Dirichlet("beta", a=np.ones(n_predictors))
69+
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
7070
sigma = pm.HalfNormal("sigma", 1)
7171
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
7272
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")

docs/notebooks/did_pymc.ipynb

Lines changed: 21 additions & 4 deletions
Large diffs are not rendered by default.

docs/notebooks/its_pymc.ipynb

Lines changed: 37 additions & 11 deletions
Large diffs are not rendered by default.

docs/notebooks/rd_pymc.ipynb

Lines changed: 34 additions & 1 deletion
Large diffs are not rendered by default.

docs/notebooks/sc_pymc.ipynb

Lines changed: 31 additions & 11 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)