Skip to content

Commit 0c2a69a

Browse files
committed
Ensure use of get_feature_names_or_default
1 parent a37eb3c commit 0c2a69a

File tree

5 files changed

+33
-25
lines changed

5 files changed

+33
-25
lines changed

econml/solutions/causal_analysis/_causal_analysis.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ...inference import NormalInferenceResults
2626
from ...sklearn_extensions.linear_model import WeightedLasso
2727
from ...sklearn_extensions.model_selection import GridSearchCVList
28-
from ...utilities import _RegressionWrapper, inverse_onehot
28+
from ...utilities import _RegressionWrapper, get_feature_names_or_default, inverse_onehot
2929

3030
# TODO: this utility is documented but internal; reimplement?
3131
from sklearn.utils import _safe_indexing
@@ -220,13 +220,16 @@ def transform(self, X):
220220
else:
221221
return rest
222222

223+
# TODO: remove once older sklearn support is no longer needed
223224
def get_feature_names(self, names=None):
225+
return self.get_feature_names_out(names)
226+
227+
def get_feature_names_out(self, names=None):
224228
if names is None:
225229
names = [f"x{i}" for i in range(self.d_x)]
226230
rest = _safe_indexing(names, self.passthrough, axis=0)
227231
if self.has_cats:
228-
cats = self.one_hot_encoder.get_feature_names(
229-
_safe_indexing(names, self.categorical, axis=0))
232+
cats = get_feature_names_or_default(self.one_hot_encoder, _safe_indexing(names, self.categorical, axis=0))
230233
return np.concatenate((rest, cats))
231234
else:
232235
return rest
@@ -1445,7 +1448,7 @@ def _tree(self, is_policy, Xtest, feature_index, *, treatment_costs=0,
14451448
intrp.interpret(result.estimator, Xtest)
14461449
policy_values = None
14471450

1448-
return intrp, result.X_transformer.get_feature_names(self.feature_names_), treatment_names, policy_values
1451+
return intrp, result.X_transformer.get_feature_names_out(self.feature_names_), treatment_names, policy_values
14491452

14501453
# TODO: it seems like it would be better to just return the tree itself rather than plot it;
14511454
# however, the tree can't store the feature and treatment names we compute here...

econml/tests/test_drlearner.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from econml.dr import DRLearner, LinearDRLearner, SparseLinearDRLearner, ForestDRLearner
2222
from econml.inference import BootstrapInference, StatsModelsInferenceDiscrete
23-
from econml.utilities import shape, hstack, vstack, reshape, cross_product
23+
from econml.utilities import get_feature_names_or_default, shape, hstack, vstack, reshape, cross_product
2424
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression
2525
import econml.tests.utilities # bugfix for assertWarns
2626

@@ -451,10 +451,10 @@ def test_drlearner_all_attributes(self):
451451
feature_names = ['A', 'B', 'C']
452452
out_feat_names = feature_names
453453
if featurizer is not None:
454-
out_feat_names = featurizer.fit(
455-
X).get_feature_names(feature_names)
454+
out_feat_names = get_feature_names_or_default(featurizer.fit(X),
455+
feature_names)
456456
np.testing.assert_array_equal(
457-
est.featurizer_.n_input_features_, 3)
457+
est.featurizer_.n_features_in_, 3)
458458
np.testing.assert_array_equal(est.cate_feature_names(feature_names),
459459
out_feat_names)
460460

@@ -631,10 +631,10 @@ def test_drlearner_with_inference_all_attributes(self):
631631
out_feat_names = feature_names
632632
if X is not None:
633633
if (featurizer is not None):
634-
out_feat_names = featurizer.fit(
635-
X).get_feature_names(feature_names)
634+
out_feat_names = get_feature_names_or_default(featurizer.fit(X),
635+
feature_names)
636636
np.testing.assert_array_equal(
637-
est.featurizer_.n_input_features_, 2)
637+
est.featurizer_.n_features_in_, 2)
638638
np.testing.assert_array_equal(est.cate_feature_names(feature_names),
639639
out_feat_names)
640640

econml/tests/test_inference.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from econml.inference import (BootstrapInference, NormalInferenceResults,
1515
EmpiricalInferenceResults, PopulationSummaryResults)
1616
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression, DebiasedLasso
17-
from econml.utilities import get_input_columns
17+
from econml.utilities import get_feature_names_or_default, get_input_columns
1818

1919

2020
class TestInference(unittest.TestCase):
@@ -51,8 +51,9 @@ def test_summary(self):
5151
summary_results = cate_est.summary()
5252
coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0]
5353
default_names = get_input_columns(TestInference.X)
54-
fnames = PolynomialFeatures(degree=2, include_bias=False).fit(
55-
TestInference.X).get_feature_names(default_names)
54+
fnames = get_feature_names_or_default(PolynomialFeatures(degree=2,
55+
include_bias=False).fit(TestInference.X),
56+
default_names)
5657
np.testing.assert_array_equal(coef_rows, fnames)
5758
intercept_rows = np.asarray(summary_results.tables[1].data)[1:, 0]
5859
np.testing.assert_array_equal(intercept_rows, ['cate_intercept'])
@@ -71,8 +72,9 @@ def test_summary(self):
7172
fnames = ['Q' + str(i) for i in range(TestInference.d_x)]
7273
summary_results = cate_est.summary(feature_names=fnames)
7374
coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0]
74-
fnames = PolynomialFeatures(degree=2, include_bias=False).fit(
75-
TestInference.X).get_feature_names(input_features=fnames)
75+
fnames = get_feature_names_or_default(PolynomialFeatures(degree=2,
76+
include_bias=False).fit(TestInference.X),
77+
fnames)
7678
np.testing.assert_array_equal(coef_rows, fnames)
7779
cate_est = LinearDML(model_t=LinearRegression(), model_y=LinearRegression(), featurizer=None)
7880
cate_est.fit(
@@ -145,8 +147,9 @@ def test_summary_discrete(self):
145147
summary_results = cate_est.summary(T=1)
146148
coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0]
147149
default_names = get_input_columns(TestInference.X)
148-
fnames = PolynomialFeatures(degree=2, include_bias=False).fit(
149-
TestInference.X).get_feature_names(default_names)
150+
fnames = get_feature_names_or_default(PolynomialFeatures(degree=2,
151+
include_bias=False).fit(TestInference.X),
152+
default_names)
150153
np.testing.assert_array_equal(coef_rows, fnames)
151154
intercept_rows = np.asarray(summary_results.tables[1].data)[1:, 0]
152155
np.testing.assert_array_equal(intercept_rows, ['cate_intercept'])
@@ -166,8 +169,9 @@ def test_summary_discrete(self):
166169
fnames = ['Q' + str(i) for i in range(TestInference.d_x)]
167170
summary_results = cate_est.summary(T=1, feature_names=fnames)
168171
coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0]
169-
fnames = PolynomialFeatures(degree=2, include_bias=False).fit(
170-
TestInference.X).get_feature_names(input_features=fnames)
172+
fnames = get_feature_names_or_default(PolynomialFeatures(degree=2,
173+
include_bias=False).fit(TestInference.X),
174+
fnames)
171175
np.testing.assert_array_equal(coef_rows, fnames)
172176
cate_est = LinearDRLearner(model_regression=LinearRegression(),
173177
model_propensity=LogisticRegression(), featurizer=None)

econml/tests/test_integration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33

4+
from econml.utilities import get_feature_names_or_default
45
import numpy as np
56
import pandas as pd
67
import unittest
@@ -76,7 +77,7 @@ def test_dml(self):
7677
est.fit(Y, T, X=X, W=W, inference='statsmodels')
7778
self._check_input_names(
7879
est.summary(),
79-
feat_comp=est.original_featurizer.get_feature_names(X.columns))
80+
feat_comp=get_feature_names_or_default(est.original_featurizer, X.columns))
8081
est.featurizer = FunctionTransformer()
8182
est.fit(Y, T, X=X, W=W, inference='statsmodels')
8283
self._check_input_names(

notebooks/Solutions/Causal Interpretation for Employee Attrition Dataset.ipynb

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)