Commit 2364c4e

Merge pull request #141 from pymc-labs/rename-predictionmodel

global find and replace: `prediction_model` -> `model`

2 parents 5aead6b + 285fb3f

23 files changed: +100 -106 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ result = cp.pymc_experiments.RegressionDiscontinuity(
     df,
     formula="all ~ 1 + age + treated",
     running_variable_name="age",
-    prediction_model=cp.pymc_models.LinearRegression(),
+    model=cp.pymc_models.LinearRegression(),
     treatment_threshold=21,
 )
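Because the attribute was renamed along with the keyword argument, downstream code that reaches into a fitted experiment object needs the same one-word update. A minimal sketch of attribute access after this commit, assuming `result` is the RegressionDiscontinuity object built in the README snippet above; the attribute names are taken from the diffs below, nothing else is added.

# Hedged sketch: `result` is the fitted experiment object from the README
# example above; attribute names are taken from this commit's diff.
fitted = result.model    # before this commit: result.prediction_model
idata = result.idata     # base-class property, now returning self.model.idata
score = result.score     # goodness-of-fit stored by __init__ (see diffs below)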

causalpy/pymc_experiments.py

Lines changed: 31 additions & 33 deletions
@@ -15,19 +15,19 @@
 class ExperimentalDesign:
     """Base class"""

-    prediction_model = None
+    model = None
     expt_type = None

-    def __init__(self, prediction_model=None, **kwargs):
-        if prediction_model is not None:
-            self.prediction_model = prediction_model
-        if self.prediction_model is None:
+    def __init__(self, model=None, **kwargs):
+        if model is not None:
+            self.model = model
+        if self.model is None:
             raise ValueError("fitting_model not set or passed.")

     @property
     def idata(self):
         """Access to the InferenceData object"""
-        return self.prediction_model.idata
+        return self.model.idata

     def print_coefficients(self):
         """Prints the model coefficients"""
@@ -41,9 +41,7 @@ def print_coefficients(self):
             f"{name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]" # noqa: E501
         )
         # add coeff for measurement std
-        coeff_samples = az.extract(
-            self.prediction_model.idata.posterior, var_names="sigma"
-        )
+        coeff_samples = az.extract(self.model.idata.posterior, var_names="sigma")
         name = "sigma"
         print(
             f"{name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]" # noqa: E501
@@ -58,10 +56,10 @@ def __init__(
         data: pd.DataFrame,
         treatment_time: int,
         formula: str,
-        prediction_model=None,
+        model=None,
         **kwargs,
     ) -> None:
-        super().__init__(prediction_model=prediction_model, **kwargs)
+        super().__init__(model=model, **kwargs)
         self.treatment_time = treatment_time
         # split data in to pre and post intervention
         self.datapre = data[data.index <= self.treatment_time]
@@ -86,17 +84,17 @@ def __init__(
         # DEVIATION FROM SKL EXPERIMENT CODE =============================
         # fit the model to the observed (pre-intervention) data
         COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}
-        self.prediction_model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
+        self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
         # ================================================================

         # score the goodness of fit to the pre-intervention data
-        self.score = self.prediction_model.score(X=self.pre_X, y=self.pre_y)
+        self.score = self.model.score(X=self.pre_X, y=self.pre_y)

         # get the model predictions of the observed (pre-intervention) data
-        self.pre_pred = self.prediction_model.predict(X=self.pre_X)
+        self.pre_pred = self.model.predict(X=self.pre_X)

         # calculate the counterfactual
-        self.post_pred = self.prediction_model.predict(X=self.post_X)
+        self.post_pred = self.model.predict(X=self.post_X)

         # causal impact pre (ie the residuals of the model fit to observed)
         pre_data = xr.DataArray(self.pre_y[:, 0], dims=["obs_ind"])
@@ -242,10 +240,10 @@ def __init__(
         group_variable_name: str,
         treated: str,
         untreated: str,
-        prediction_model=None,
+        model=None,
         **kwargs,
     ):
-        super().__init__(prediction_model=prediction_model, **kwargs)
+        super().__init__(model=model, **kwargs)
         self.data = data
         self.expt_type = "Difference in Differences"
         self.formula = formula
@@ -291,7 +289,7 @@ def __init__(

         # DEVIATION FROM SKL EXPERIMENT CODE =============================
         COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
-        self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
+        self.model.fit(X=self.X, y=self.y, coords=COORDS)
         # ================================================================

         # predicted outcome for control group
@@ -308,7 +306,7 @@ def __init__(
         )
         assert not self.x_pred_control.empty
         (new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
-        self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))
+        self.y_pred_control = self.model.predict(np.asarray(new_x))

         # predicted outcome for treatment group
         self.x_pred_treatment = (
@@ -324,7 +322,7 @@ def __init__(
         )
         assert not self.x_pred_treatment.empty
         (new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
-        self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))
+        self.y_pred_treatment = self.model.predict(np.asarray(new_x))

         # predicted outcome for counterfactual
         self.x_pred_counterfactual = (
@@ -346,7 +344,7 @@ def __init__(
         (new_x,) = build_design_matrices(
             [self._x_design_info], self.x_pred_counterfactual
         )
-        self.y_pred_counterfactual = self.prediction_model.predict(np.asarray(new_x))
+        self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))

         # calculate causal impact
         self.causal_impact = (
@@ -489,7 +487,7 @@ class RegressionDiscontinuity(ExperimentalDesign):
     :param formula: A statistical model formula
     :param treatment_threshold: A scalar threshold value at which the treatment
         is applied
-    :param prediction_model: A PyMC model
+    :param model: A PyMC model
     :param running_variable_name: The name of the predictor variable that the treatment
         threshold is based upon

@@ -504,11 +502,11 @@ def __init__(
         data: pd.DataFrame,
         formula: str,
         treatment_threshold: float,
-        prediction_model=None,
+        model=None,
         running_variable_name: str = "x",
         **kwargs,
     ):
-        super().__init__(prediction_model=prediction_model, **kwargs)
+        super().__init__(model=model, **kwargs)
         self.expt_type = "Regression Discontinuity"
         self.data = data
         self.formula = formula
@@ -527,11 +525,11 @@ def __init__(
         # DEVIATION FROM SKL EXPERIMENT CODE =============================
         # fit the model to the observed (pre-intervention) data
         COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
-        self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
+        self.model.fit(X=self.X, y=self.y, coords=COORDS)
         # ================================================================

         # score the goodness of fit to all data
-        self.score = self.prediction_model.score(X=self.X, y=self.y)
+        self.score = self.model.score(X=self.X, y=self.y)

         # get the model predictions of the observed data
         xi = np.linspace(
@@ -543,7 +541,7 @@ def __init__(
             {self.running_variable_name: xi, "treated": self._is_treated(xi)}
         )
         (new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
-        self.pred = self.prediction_model.predict(X=np.asarray(new_x))
+        self.pred = self.model.predict(X=np.asarray(new_x))

         # calculate discontinuity by evaluating the difference in model expectation on
         # either side of the discontinuity
@@ -558,7 +556,7 @@ def __init__(
             }
         )
         (new_x,) = build_design_matrices([self._x_design_info], self.x_discon)
-        self.pred_discon = self.prediction_model.predict(X=np.asarray(new_x))
+        self.pred_discon = self.model.predict(X=np.asarray(new_x))
         self.discontinuity_at_threshold = (
             self.pred_discon["posterior_predictive"].sel(obs_ind=1)["mu"]
             - self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
@@ -633,10 +631,10 @@ def __init__(
         formula: str,
         group_variable_name: str,
         pretreatment_variable_name: str,
-        prediction_model=None,
+        model=None,
         **kwargs,
     ):
-        super().__init__(prediction_model=prediction_model, **kwargs)
+        super().__init__(model=model, **kwargs)
         self.data = data
         self.expt_type = "Pretest/posttest Nonequivalent Group Design"
         self.formula = formula
@@ -663,7 +661,7 @@ def __init__(

         # fit the model to the observed (pre-intervention) data
         COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
-        self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
+        self.model.fit(X=self.X, y=self.y, coords=COORDS)

         # Calculate the posterior predictive for the treatment and control for an
         # interpolated set of pretest values
@@ -681,7 +679,7 @@ def __init__(
             }
         )
         (new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
-        self.pred_untreated = self.prediction_model.predict(X=np.asarray(new_x))
+        self.pred_untreated = self.model.predict(X=np.asarray(new_x))
         # treated
         x_pred_untreated = pd.DataFrame(
             {
@@ -690,7 +688,7 @@ def __init__(
             }
         )
         (new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
-        self.pred_treated = self.prediction_model.predict(X=np.asarray(new_x))
+        self.pred_treated = self.model.predict(X=np.asarray(new_x))

         # Evaluate causal impact as equal to the trestment effect
         self.causal_impact = self.idata.posterior["beta"].sel(
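The base-class hunk at the top of this file is the heart of the rename: whatever is passed as `model` is stored on the instance, a class-level default can stand in for it, and construction fails if neither is set. Below is a minimal, self-contained sketch of that resolution logic, mirroring the diffed `ExperimentalDesign.__init__`; the `DummyModel` stand-in is hypothetical and not part of CausalPy.

# Sketch of the model-resolution pattern from the diffed base class above.
# `DummyModel` is a hypothetical placeholder; in CausalPy this would be e.g.
# a cp.pymc_models.LinearRegression() instance.
class ExperimentalDesign:
    model = None  # class-level default, may be overridden by a subclass

    def __init__(self, model=None, **kwargs):
        if model is not None:
            self.model = model  # an explicitly passed model wins
        if self.model is None:
            raise ValueError("fitting_model not set or passed.")


class DummyModel:
    """Hypothetical object exposing the fit/score/predict interface used here."""

    def fit(self, X, y, coords=None):
        pass

    def score(self, X, y):
        return 0.0

    def predict(self, X):
        return None


design = ExperimentalDesign(model=DummyModel())  # resolves to the passed model
# ExperimentalDesign()  # would raise ValueError("fitting_model not set or passed.")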

causalpy/skl_experiments.py

Lines changed: 25 additions & 25 deletions
@@ -10,13 +10,13 @@
 class ExperimentalDesign:
     """Base class for experiment designs"""

-    prediction_model = None
+    model = None
     outcome_variable_name = None

-    def __init__(self, prediction_model=None, **kwargs):
-        if prediction_model is not None:
-            self.prediction_model = prediction_model
-        if self.prediction_model is None:
+    def __init__(self, model=None, **kwargs):
+        if model is not None:
+            self.model = model
+        if self.model is None:
             raise ValueError("fitting_model not set or passed.")


@@ -26,10 +26,10 @@ def __init__(
         data,
         treatment_time,
         formula,
-        prediction_model=None,
+        model=None,
         **kwargs,
     ):
-        super().__init__(prediction_model=prediction_model, **kwargs)
+        super().__init__(model=model, **kwargs)
         self.treatment_time = treatment_time
         # split data in to pre and post intervention
         self.datapre = data[data.index <= self.treatment_time]
@@ -52,16 +52,16 @@ def __init__(
         self.post_y = np.asarray(new_y)

         # fit the model to the observed (pre-intervention) data
-        self.prediction_model.fit(X=self.pre_X, y=self.pre_y)
+        self.model.fit(X=self.pre_X, y=self.pre_y)

         # score the goodness of fit to the pre-intervention data
-        self.score = self.prediction_model.score(X=self.pre_X, y=self.pre_y)
+        self.score = self.model.score(X=self.pre_X, y=self.pre_y)

         # get the model predictions of the observed (pre-intervention) data
-        self.pre_pred = self.prediction_model.predict(X=self.pre_X)
+        self.pre_pred = self.model.predict(X=self.pre_X)

         # calculate the counterfactual
-        self.post_pred = self.prediction_model.predict(X=self.post_X)
+        self.post_pred = self.model.predict(X=self.post_X)

         # causal impact pre (ie the residuals of the model fit to observed)
         self.pre_impact = self.pre_y - self.pre_pred
@@ -134,7 +134,7 @@ def plot(self):
         return (fig, ax)

     def get_coeffs(self):
-        return np.squeeze(self.prediction_model.coef_)
+        return np.squeeze(self.model.coef_)

     def plot_coeffs(self):
         df = pd.DataFrame(
@@ -176,10 +176,10 @@ def __init__(
         data: pd.DataFrame,
         formula: str,
         time_variable_name: str,
-        prediction_model=None,
+        model=None,
         **kwargs,
     ):
-        super().__init__(prediction_model=prediction_model, **kwargs)
+        super().__init__(model=model, **kwargs)
         self.data = data
         self.formula = formula
         self.time_variable_name = time_variable_name
@@ -191,23 +191,23 @@ def __init__(
         self.outcome_variable_name = y.design_info.column_names[0]

         # fit the model to all the data
-        self.prediction_model.fit(X=self.X, y=self.y)
+        self.model.fit(X=self.X, y=self.y)

         # predicted outcome for control group
         self.x_pred_control = pd.DataFrame(
             {"group": [0, 0], "t": [0.0, 1.0], "post_treatment": [0, 0]}
         )
         assert not self.x_pred_control.empty
         (new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
-        self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))
+        self.y_pred_control = self.model.predict(np.asarray(new_x))

         # predicted outcome for treatment group
         self.x_pred_treatment = pd.DataFrame(
             {"group": [1, 1], "t": [0.0, 1.0], "post_treatment": [0, 1]}
         )
         assert not self.x_pred_treatment.empty
         (new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
-        self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))
+        self.y_pred_treatment = self.model.predict(np.asarray(new_x))

         # predicted outcome for counterfactual
         self.x_pred_counterfactual = pd.DataFrame(
@@ -217,7 +217,7 @@ def __init__(
         (new_x,) = build_design_matrices(
             [self._x_design_info], self.x_pred_counterfactual
         )
-        self.y_pred_counterfactual = self.prediction_model.predict(np.asarray(new_x))
+        self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))

         # calculate causal impact
         self.causal_impact = self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
@@ -309,11 +309,11 @@ def __init__(
         data,
         formula,
         treatment_threshold,
-        prediction_model=None,
+        model=None,
         running_variable_name="x",
         **kwargs,
     ):
-        super().__init__(prediction_model=prediction_model, **kwargs)
+        super().__init__(model=model, **kwargs)
         self.data = data
         self.formula = formula
         self.running_variable_name = running_variable_name
@@ -329,10 +329,10 @@ def __init__(
         # this could be a function rather than supplied data

         # fit the model to all the data
-        self.prediction_model.fit(X=self.X, y=self.y)
+        self.model.fit(X=self.X, y=self.y)

         # score the goodness of fit to all data
-        self.score = self.prediction_model.score(X=self.X, y=self.y)
+        self.score = self.model.score(X=self.X, y=self.y)

         # get the model predictions of the observed data
         xi = np.linspace(
@@ -344,7 +344,7 @@ def __init__(
             {self.running_variable_name: xi, "treated": self._is_treated(xi)}
         )
         (new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
-        self.pred = self.prediction_model.predict(X=np.asarray(new_x))
+        self.pred = self.model.predict(X=np.asarray(new_x))

         # calculate discontinuity by evaluating the difference in model expectation on
         # either side of the discontinuity
@@ -359,7 +359,7 @@ def __init__(
             }
         )
         (new_x,) = build_design_matrices([self._x_design_info], self.x_discon)
-        self.pred_discon = self.prediction_model.predict(X=np.asarray(new_x))
+        self.pred_discon = self.model.predict(X=np.asarray(new_x))
         self.discontinuity_at_threshold = np.squeeze(self.pred_discon[1]) - np.squeeze(
             self.pred_discon[0]
         )
@@ -416,5 +416,5 @@ def summary(self):
         print("\nResults:")
         print(f"Discontinuity at threshold = {self.discontinuity_at_threshold:.2f}")
         print("Model coefficients:")
-        for name, val in zip(self.labels, self.prediction_model.coef_[0]):
+        for name, val in zip(self.labels, self.model.coef_[0]):
             print(f"\t{name}\t\t{val}")
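The scikit-learn variants take the same renamed keyword, and because they only call `fit`, `predict`, `score`, and read `coef_`, a plain sklearn estimator can be passed as `model`. Below is a hedged usage sketch of the `skl_experiments.RegressionDiscontinuity` constructor whose signature appears in the diff above; the synthetic dataframe and the formula are illustrative assumptions, not taken from the repository.

# Hedged sketch of calling the skl RegressionDiscontinuity with the renamed
# `model` keyword. Constructor arguments come from the diff above; the data
# and formula below are made-up illustrations.
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

import causalpy as cp

rng = np.random.default_rng(0)
x = rng.uniform(0, 1, 200)
df = pd.DataFrame(
    {
        "x": x,
        "treated": (x > 0.5).astype(int),
        "y": 1.0 + 2.0 * x + 0.7 * (x > 0.5) + rng.normal(0, 0.1, 200),
    }
)

result = cp.skl_experiments.RegressionDiscontinuity(
    df,
    formula="y ~ 1 + x + treated",
    model=LinearRegression(),      # was prediction_model=... before this commit
    treatment_threshold=0.5,
    running_variable_name="x",
)
result.summary()  # prints the discontinuity estimate and model coefficients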
