Skip to content

Commit 0ed7240

Browse files
committed
#76 fix tests which strangely warn locally but fail remotely
1 parent bcf38f9 commit 0ed7240

File tree

4 files changed

+26
-17
lines changed

4 files changed

+26
-17
lines changed

causalpy/skl_experiments.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,30 +190,30 @@ def __init__(
190190
self.y, self.X = np.asarray(y), np.asarray(X)
191191
self.outcome_variable_name = y.design_info.column_names[0]
192192

193-
# TODO: `treated` is a deterministic function of group and time, so this should
194-
# be a function rather than supplied data
195-
196193
# fit the model to all the data
197194
self.prediction_model.fit(X=self.X, y=self.y)
198195

199196
# predicted outcome for control group
200197
self.x_pred_control = pd.DataFrame(
201-
{"group": [0, 0], "t": [0.0, 1.0], "treated": [0, 0]}
198+
{"group": [0, 0], "t": [0.0, 1.0], "post_treatment": [0, 0]}
202199
)
200+
assert not self.x_pred_control.empty
203201
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
204202
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))
205203

206204
# predicted outcome for treatment group
207205
self.x_pred_treatment = pd.DataFrame(
208-
{"group": [1, 1], "t": [0.0, 1.0], "treated": [0, 1]}
206+
{"group": [1, 1], "t": [0.0, 1.0], "post_treatment": [0, 1]}
209207
)
208+
assert not self.x_pred_treatment.empty
210209
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
211210
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))
212211

213212
# predicted outcome for counterfactual
214213
self.x_pred_counterfactual = pd.DataFrame(
215-
{"group": [1], "t": [1.0], "treated": [0]}
214+
{"group": [1], "t": [1.0], "post_treatment": [0]}
216215
)
216+
assert not self.x_pred_counterfactual.empty
217217
(new_x,) = build_design_matrices(
218218
[self._x_design_info], self.x_pred_counterfactual
219219
)

causalpy/tests/test_integration_skl_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_did():
1212
data = cp.load_data("did")
1313
result = cp.skl_experiments.DifferenceInDifferences(
1414
data,
15-
formula="y ~ 1 + group + t + treated:group",
15+
formula="y ~ 1 + group + t + group:post_treatment",
1616
time_variable_name="t",
1717
prediction_model=LinearRegression(),
1818
)

causalpy/tests/test_pymc_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_idata_property():
7777
df = cp.load_data("did")
7878
result = cp.pymc_experiments.DifferenceInDifferences(
7979
df,
80-
formula="y ~ 1 + group + t + treated:group",
80+
formula="y ~ 1 + group + t + group:post_treatment",
8181
time_variable_name="t",
8282
group_variable_name="group",
8383
treated=1,

docs/notebooks/did_skl.ipynb

Lines changed: 18 additions & 9 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)