@@ -109,7 +109,7 @@ def fit(self, Y, T, *, X, inference=None):
109109 self .models [ind ].fit (X [T == ind ], Y [T == ind ])
110110
111111 def const_marginal_effect (self , X ):
112- """Calculate the constant marignal treatment effect on a vector of features for each sample.
112+ """Calculate the constant marginal treatment effect on a vector of features for each sample.
113113
114114 Parameters
115115 ----------
@@ -127,7 +127,11 @@ def const_marginal_effect(self, X):
127127 X = check_array (X )
128128 taus = []
129129 for ind in range (self ._d_t [0 ]):
130- taus .append (self .models [ind + 1 ].predict (X ) - self .models [0 ].predict (X ))
130+ if hasattr (self .models [ind + 1 ], 'predict_proba' ):
131+ taus .append (self .models [ind + 1 ].predict_proba (X )[:, 1 ] - self .models [0 ].predict_proba (X )[:, 1 ])
132+ else :
133+ taus .append (self .models [ind + 1 ].predict (X ) - self .models [0 ].predict (X ))
134+
131135 taus = np .column_stack (taus ).reshape ((- 1 ,) + self ._d_t + self ._d_y ) # shape as of m*d_t*d_y
132136 if self ._d_y :
133137 taus = transpose (taus , (0 , 2 , 1 )) # shape as of m*d_y*d_t
@@ -242,7 +246,12 @@ def const_marginal_effect(self, X=None):
242246 X = check_array (X )
243247 Xs , Ts = broadcast_unit_treatments (X , self ._d_t [0 ] + 1 )
244248 feat_arr = np .concatenate ((Xs , Ts ), axis = 1 )
245- prediction = self .overall_model .predict (feat_arr ).reshape ((- 1 , self ._d_t [0 ] + 1 ,) + self ._d_y )
249+
250+ if hasattr (self .overall_model , 'predict_proba' ):
251+ prediction = self .overall_model .predict_proba (feat_arr )[:, 1 ].reshape ((- 1 , self ._d_t [0 ] + 1 ,) + self ._d_y )
252+ else :
253+ prediction = self .overall_model .predict (feat_arr ).reshape ((- 1 , self ._d_t [0 ] + 1 ,) + self ._d_y )
254+
246255 if self ._d_y :
247256 prediction = transpose (prediction , (0 , 2 , 1 ))
248257 taus = (prediction - np .repeat (prediction [:, :, 0 ], self ._d_t [0 ] + 1 ).reshape (prediction .shape ))[:, :, 1 :]
@@ -393,8 +402,14 @@ def const_marginal_effect(self, X):
393402 taus = []
394403 for ind in range (self ._d_t [0 ]):
395404 propensity_scores = self .propensity_models [ind ].predict_proba (X )[:, 1 :]
396- tau_hat = propensity_scores * self .cate_controls_models [ind ].predict (X ).reshape (m , - 1 ) \
397- + (1 - propensity_scores ) * self .cate_treated_models [ind ].predict (X ).reshape (m , - 1 )
405+
406+ if hasattr (self .cate_controls_models [ind ], 'predict_proba' ):
407+ tau_hat = propensity_scores * self .cate_controls_models [ind ].predict_proba (X )[:, 1 ].reshape (m , - 1 ) \
408+ + (1 - propensity_scores ) * self .cate_treated_models [ind ].predict_proba (X )[:, 1 ].reshape (m , - 1 )
409+ else :
410+ tau_hat = propensity_scores * self .cate_controls_models [ind ].predict (X ).reshape (m , - 1 ) \
411+ + (1 - propensity_scores ) * self .cate_treated_models [ind ].predict (X ).reshape (m , - 1 )
412+
398413 taus .append (tau_hat )
399414 taus = np .column_stack (taus ).reshape ((- 1 ,) + self ._d_t + self ._d_y ) # shape as of m*d_t*d_y
400415 if self ._d_y :
@@ -549,7 +564,10 @@ def const_marginal_effect(self, X):
549564 X = check_array (X )
550565 taus = []
551566 for model in self .final_models :
552- taus .append (model .predict (X ))
567+ if hasattr (model , 'predict_proba' ):
568+ taus .append (model .predict_proba (X )[:, 1 ])
569+ else :
570+ taus .append (model .predict (X ))
553571 taus = np .column_stack (taus ).reshape ((- 1 ,) + self ._d_t + self ._d_y ) # shape as of m*d_t*d_y
554572 if self ._d_y :
555573 taus = transpose (taus , (0 , 2 , 1 )) # shape as of m*d_y*d_t
0 commit comments