diff --git a/econml/metalearners/_metalearners.py b/econml/metalearners/_metalearners.py index 37dc2f6d2..99878f5f0 100644 --- a/econml/metalearners/_metalearners.py +++ b/econml/metalearners/_metalearners.py @@ -109,7 +109,7 @@ def fit(self, Y, T, *, X, inference=None): self.models[ind].fit(X[T == ind], Y[T == ind]) def const_marginal_effect(self, X): - """Calculate the constant marignal treatment effect on a vector of features for each sample. + """Calculate the constant marginal treatment effect on a vector of features for each sample. Parameters ---------- @@ -127,7 +127,14 @@ def const_marginal_effect(self, X): X = check_array(X) taus = [] for ind in range(self._d_t[0]): - taus.append(self.models[ind + 1].predict(X) - self.models[0].predict(X)) + if ( + hasattr(self.models[ind + 1], 'predict_proba') and + hasattr(self.models[0], 'predict_proba') + ): + taus.append(self.models[ind + 1].predict_proba(X)[:, 1] - self.models[0].predict_proba(X)[:, 1]) + else: + taus.append(self.models[ind + 1].predict(X) - self.models[0].predict(X)) + taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y if self._d_y: taus = transpose(taus, (0, 2, 1)) # shape as of m*d_y*d_t @@ -242,7 +249,12 @@ def const_marginal_effect(self, X=None): X = check_array(X) Xs, Ts = broadcast_unit_treatments(X, self._d_t[0] + 1) feat_arr = np.concatenate((Xs, Ts), axis=1) - prediction = self.overall_model.predict(feat_arr).reshape((-1, self._d_t[0] + 1,) + self._d_y) + + if hasattr(self.overall_model, 'predict_proba'): + prediction = self.overall_model.predict_proba(feat_arr)[:, 1].reshape((-1, self._d_t[0] + 1,) + self._d_y) + else: + prediction = self.overall_model.predict(feat_arr).reshape((-1, self._d_t[0] + 1,) + self._d_y) + if self._d_y: prediction = transpose(prediction, (0, 2, 1)) taus = (prediction - np.repeat(prediction[:, :, 0], self._d_t[0] + 1).reshape(prediction.shape))[:, :, 1:] @@ -393,8 +405,17 @@ def const_marginal_effect(self, X): taus = [] for ind in range(self._d_t[0]): propensity_scores = self.propensity_models[ind].predict_proba(X)[:, 1:] - tau_hat = propensity_scores * self.cate_controls_models[ind].predict(X).reshape(m, -1) \ - + (1 - propensity_scores) * self.cate_treated_models[ind].predict(X).reshape(m, -1) + + if ( + hasattr(self.cate_controls_models[ind], 'predict_proba') and + hasattr(self.cate_treated_models[ind], 'predict_proba') + ): + tau_hat = propensity_scores * self.cate_controls_models[ind].predict_proba(X)[:, 1].reshape(m, -1) \ + + (1 - propensity_scores) * self.cate_treated_models[ind].predict_proba(X)[:, 1].reshape(m, -1) + else: + tau_hat = propensity_scores * self.cate_controls_models[ind].predict(X).reshape(m, -1) \ + + (1 - propensity_scores) * self.cate_treated_models[ind].predict(X).reshape(m, -1) + taus.append(tau_hat) taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y if self._d_y: @@ -549,7 +570,10 @@ def const_marginal_effect(self, X): X = check_array(X) taus = [] for model in self.final_models: - taus.append(model.predict(X)) + if hasattr(model, 'predict_proba'): + taus.append(model.predict_proba(X)[:, 1]) + else: + taus.append(model.predict(X)) taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y if self._d_y: taus = transpose(taus, (0, 2, 1)) # shape as of m*d_y*d_t