Skip to content

Commit 4cd357e

Browse files
committed
#69 auto extract outcome_variable_name + better text summary output
1 parent e011c9d commit 4cd357e

File tree

2 files changed

+409
-19
lines changed

2 files changed

+409
-19
lines changed

causalpy/pymc_experiments.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
from causalpy.plot_utils import plot_xY
1010

1111
LEGEND_FONT_SIZE = 12
12+
az.style.use("arviz-darkgrid")
1213

1314

1415
class ExperimentalDesign:
1516
"""Base class"""
1617

1718
prediction_model = None
19+
expt_type = None
1820

1921
def __init__(self, prediction_model=None, **kwargs):
2022
if prediction_model is not None:
@@ -345,20 +347,20 @@ def __init__(
345347
treatment_threshold: float,
346348
prediction_model=None,
347349
running_variable_name: str = "x",
348-
outcome_variable_name="y",
349350
**kwargs,
350351
):
351352
super().__init__(prediction_model=prediction_model, **kwargs)
353+
self.expt_type = "Regression Discontinuity"
352354
self.data = data
353355
self.formula = formula
354356
self.running_variable_name = running_variable_name
355-
self.outcome_variable_name = outcome_variable_name
356357
self.treatment_threshold = treatment_threshold
357358
y, X = dmatrices(formula, self.data)
358359
self._y_design_info = y.design_info
359360
self._x_design_info = X.design_info
360361
self.labels = X.design_info.column_names
361362
self.y, self.X = np.asarray(y), np.asarray(X)
363+
self.outcome_variable_name = y.design_info.column_names[0]
362364

363365
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
364366

@@ -445,7 +447,8 @@ def plot(self):
445447

446448
def summary(self):
447449
"""Print text output summarising the results"""
448-
print("Difference in Differences experiment")
450+
451+
print(f"{self.expt_type:=^80}")
449452
print(f"Formula: {self.formula}")
450453
print(f"Running variable: {self.running_variable_name}")
451454
print(f"Threshold on running variable: {self.treatment_threshold}")
@@ -455,15 +458,17 @@ def summary(self):
455458
)
456459
print("Model coefficients:")
457460
coeffs = az.extract(self.prediction_model.idata.posterior, var_names="beta")
461+
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
458462
for name in self.labels:
459463
coeff_samples = coeffs.sel(coeffs=name)
460464
print(
461-
f"\t{name}\t\t{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
465+
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
462466
)
463467
# add coeff for measurement std
464468
coeff_samples = az.extract(
465469
self.prediction_model.idata.posterior, var_names="sigma"
466470
)
471+
name = "sigma"
467472
print(
468-
f"\tsigma\t\t{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
473+
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
469474
)

docs/notebooks/rd_pymc_drinking.ipynb

Lines changed: 399 additions & 14 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)