@@ -355,7 +355,7 @@ def fit(self, X, y, treatment, estimator_trmnt_fit_params=None, estimator_ctrl_f
355355 if self ._type_of_target == 'binary' :
356356 ddr_treatment = self .estimator_trmnt .predict_proba (X_ctrl )[:, 1 ]
357357 else :
358- ddr_treatment = self .estimator_trmnt .predict (X_ctrl )[:, 1 ]
358+ ddr_treatment = self .estimator_trmnt .predict (X_ctrl )
359359
360360 if isinstance (X_ctrl , np .ndarray ):
361361 X_ctrl_mod = np .column_stack ((X_ctrl , ddr_treatment ))
@@ -393,21 +393,29 @@ def predict(self, X):
393393 X_mod = X .assign (ddr_control = self .ctrl_preds_ )
394394 else :
395395 raise TypeError ("Expected numpy.ndarray or pandas.DataFrame, got %s" % type (X ))
396- self .trmnt_preds_ = self .estimator_trmnt .predict_proba (X_mod )[:, 1 ]
396+
397+ if self ._type_of_target == 'binary' :
398+ self .trmnt_preds_ = self .estimator_trmnt .predict_proba (X_mod )[:, 1 ]
399+ else :
400+ self .trmnt_preds_ = self .estimator_trmnt .predict (X_mod )
397401
398402 elif self .method == 'ddr_treatment' :
399403 if self ._type_of_target == 'binary' :
400404 self .trmnt_preds_ = self .estimator_trmnt .predict_proba (X )[:, 1 ]
401405 else :
402- self .trmnt_preds_ = self .estimator_trmnt .predict_proba (X )[:, 1 ]
406+ self .trmnt_preds_ = self .estimator_trmnt .predict (X )
403407
404408 if isinstance (X , np .ndarray ):
405409 X_mod = np .column_stack ((X , self .trmnt_preds_ ))
406410 elif isinstance (X , pd .DataFrame ):
407411 X_mod = X .assign (ddr_treatment = self .trmnt_preds_ )
408412 else :
409413 raise TypeError ("Expected numpy.ndarray or pandas.DataFrame, got %s" % type (X ))
410- self .ctrl_preds_ = self .estimator_ctrl .predict_proba (X_mod )[:, 1 ]
414+
415+ if self ._type_of_target == 'binary' :
416+ self .ctrl_preds_ = self .estimator_ctrl .predict_proba (X_mod )[:, 1 ]
417+ else :
418+ self .ctrl_preds_ = self .estimator_ctrl .predict (X_mod )
411419
412420 else :
413421 if self ._type_of_target == 'binary' :
0 commit comments