Skip to content

Commit 5d0ef61

Browse files
committed
Merge branch 'feature/fix_regression_models' into dev
2 parents 99f9b70 + 4652f94 commit 5d0ef61

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

sklift/models/models.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def fit(self, X, y, treatment, estimator_trmnt_fit_params=None, estimator_ctrl_f
355355
if self._type_of_target == 'binary':
356356
ddr_treatment = self.estimator_trmnt.predict_proba(X_ctrl)[:, 1]
357357
else:
358-
ddr_treatment = self.estimator_trmnt.predict(X_ctrl)[:, 1]
358+
ddr_treatment = self.estimator_trmnt.predict(X_ctrl)
359359

360360
if isinstance(X_ctrl, np.ndarray):
361361
X_ctrl_mod = np.column_stack((X_ctrl, ddr_treatment))
@@ -393,21 +393,29 @@ def predict(self, X):
393393
X_mod = X.assign(ddr_control=self.ctrl_preds_)
394394
else:
395395
raise TypeError("Expected numpy.ndarray or pandas.DataFrame, got %s" % type(X))
396-
self.trmnt_preds_ = self.estimator_trmnt.predict_proba(X_mod)[:, 1]
396+
397+
if self._type_of_target == 'binary':
398+
self.trmnt_preds_ = self.estimator_trmnt.predict_proba(X_mod)[:, 1]
399+
else:
400+
self.trmnt_preds_ = self.estimator_trmnt.predict(X_mod)
397401

398402
elif self.method == 'ddr_treatment':
399403
if self._type_of_target == 'binary':
400404
self.trmnt_preds_ = self.estimator_trmnt.predict_proba(X)[:, 1]
401405
else:
402-
self.trmnt_preds_ = self.estimator_trmnt.predict_proba(X)[:, 1]
406+
self.trmnt_preds_ = self.estimator_trmnt.predict(X)
403407

404408
if isinstance(X, np.ndarray):
405409
X_mod = np.column_stack((X, self.trmnt_preds_))
406410
elif isinstance(X, pd.DataFrame):
407411
X_mod = X.assign(ddr_treatment=self.trmnt_preds_)
408412
else:
409413
raise TypeError("Expected numpy.ndarray or pandas.DataFrame, got %s" % type(X))
410-
self.ctrl_preds_ = self.estimator_ctrl.predict_proba(X_mod)[:, 1]
414+
415+
if self._type_of_target == 'binary':
416+
self.ctrl_preds_ = self.estimator_ctrl.predict_proba(X_mod)[:, 1]
417+
else:
418+
self.ctrl_preds_ = self.estimator_ctrl.predict(X_mod)
411419

412420
else:
413421
if self._type_of_target == 'binary':

0 commit comments

Comments
 (0)