Skip to content

Commit 81ea342

Browse files
committed
first draft of kink design notebook done
1 parent 15ef4ec commit 81ea342

File tree

3 files changed

+373
-6
lines changed

3 files changed

+373
-6
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ repos:
1010
exclude_types: [svg]
1111
- id: check-yaml
1212
- id: check-added-large-files
13+
args: ['--maxkb=1500']
1314
- repo: https://github.com/pycqa/isort
1415
rev: 5.12.0
1516
hooks:

causalpy/pymc_experiments.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,7 @@ def __init__(
10161016
self.x_pred = pd.DataFrame(
10171017
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
10181018
)
1019+
# self.x_pred = pd.DataFrame({self.running_variable_name: xi})
10191020
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
10201021
self.pred = self.model.predict(X=np.asarray(new_x))
10211022

@@ -1041,19 +1042,19 @@ def __init__(
10411042
self.gradient_left = (
10421043
self.pred_discon["posterior_predictive"].sel(obs_ind=1)["mu"]
10431044
- self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
1044-
)
1045+
) / self.epsilon
10451046
self.gradient_right = (
10461047
self.pred_discon["posterior_predictive"].sel(obs_ind=2)["mu"]
10471048
- self.pred_discon["posterior_predictive"].sel(obs_ind=1)["mu"]
1048-
)
1049+
) / self.epsilon
10491050
self.gradient_change = self.gradient_right - self.gradient_left
10501051

10511052
def _input_validation(self):
10521053
"""Validate the input data and model formula for correctness"""
1053-
# if "treated" not in self.formula:
1054-
# raise FormulaException(
1055-
# "A predictor called `treated` should be in the formula"
1056-
# )
1054+
if "treated" not in self.formula:
1055+
raise FormulaException(
1056+
"A predictor called `treated` should be in the formula"
1057+
)
10571058

10581059
if _is_variable_dummy_coded(self.data["treated"]) is False:
10591060
raise DataException(

0 commit comments

Comments
 (0)