Skip to content

Commit 71eebeb

Browse files
committed
📝 small refactoring
1 parent a43fcbb commit 71eebeb

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

sklift/models/models.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,23 @@ def fit(self, X, y, treatment, estimator_fit_params=None):
9393
if self.method == 'dummy':
9494
if isinstance(X, np.ndarray):
9595
X_mod = np.column_stack((X, treatment))
96-
elif isinstance(X, pd.core.frame.DataFrame):
96+
elif isinstance(X, pd.DataFrame):
9797
X_mod = X.assign(treatment=treatment)
9898
else:
9999
raise TypeError("Expected numpy.ndarray or pandas.DataFrame in training vector X, got %s" % type(X))
100100

101101
if self.method == 'treatment_interaction':
102102
if isinstance(X, np.ndarray):
103103
X_mod = np.column_stack((X, np.multiply(X, np.array(treatment).reshape(-1, 1)), treatment))
104-
elif isinstance(X, pd.core.frame.DataFrame):
104+
elif isinstance(X, pd.DataFrame):
105105
X_mod = pd.concat([
106106
X,
107107
X.apply(lambda x: x * treatment)
108108
.rename(columns=lambda x: str(x) + '_treatment_interaction')
109-
], axis=1)\
109+
], axis=1) \
110110
.assign(treatment=treatment)
111111
else:
112-
raise TypeError("Expected numpy.ndarray or pandas.DataFrame in training vector X, got %s" % type(X))
112+
raise TypeError("Expected numpy.ndarray or pandas.DataFrame in training vector X, got %s" % type(X))
113113

114114
self._type_of_target = type_of_target(y)
115115

@@ -133,7 +133,7 @@ def predict(self, X):
133133
if isinstance(X, np.ndarray):
134134
X_mod_trmnt = np.column_stack((X, np.ones(X.shape[0])))
135135
X_mod_ctrl = np.column_stack((X, np.zeros(X.shape[0])))
136-
elif isinstance(X, pd.core.frame.DataFrame):
136+
elif isinstance(X, pd.DataFrame):
137137
X_mod_trmnt = X.assign(treatment=np.ones(X.shape[0]))
138138
X_mod_ctrl = X.assign(treatment=np.zeros(X.shape[0]))
139139
else:
@@ -143,18 +143,18 @@ def predict(self, X):
143143
if isinstance(X, np.ndarray):
144144
X_mod_trmnt = np.column_stack((X, np.multiply(X, np.ones((X.shape[0], 1))), np.ones(X.shape[0])))
145145
X_mod_ctrl = np.column_stack((X, np.multiply(X, np.zeros((X.shape[0], 1))), np.zeros(X.shape[0])))
146-
elif isinstance(X, pd.core.frame.DataFrame):
146+
elif isinstance(X, pd.DataFrame):
147147
X_mod_trmnt = pd.concat([
148148
X,
149149
X.apply(lambda x: x * np.ones(X.shape[0]))
150150
.rename(columns=lambda x: str(x) + '_treatment_interaction')
151-
], axis=1)\
151+
], axis=1) \
152152
.assign(treatment=np.ones(X.shape[0]))
153153
X_mod_ctrl = pd.concat([
154154
X,
155155
X.apply(lambda x: x * np.zeros(X.shape[0]))
156156
.rename(columns=lambda x: str(x) + '_treatment_interaction')
157-
], axis=1)\
157+
], axis=1) \
158158
.assign(treatment=np.zeros(X.shape[0]))
159159
else:
160160
raise TypeError("Expected numpy.ndarray or pandas.DataFrame in training vector X, got %s" % type(X))
@@ -209,6 +209,7 @@ class ClassTransformation(BaseEstimator):
209209
.. _ClassTransformation in documentation:
210210
https://scikit-uplift.readthedocs.io/en/latest/api/models.html#class-transformation
211211
"""
212+
212213
def __init__(self, estimator):
213214
self.estimator = estimator
214215
self._type_of_target = None

0 commit comments

Comments
 (0)