Skip to content

Commit d20bbd2

Browse files
committed
Add get_n_splits to KFold classes for compat
1 parent 0c2a69a commit d20bbd2

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

econml/sklearn_extensions/model_selection.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,27 @@ def split(self, X, y, sample_weight=None):
137137
"""
138138
return _split_weighted_sample(self, X, y, sample_weight, is_stratified=False)
139139

140+
def get_n_splits(self, X, y, groups=None):
141+
"""Return the number of splitting iterations in the cross-validator.
142+
143+
Parameters
144+
----------
145+
X : object
146+
Always ignored, exists for compatibility.
147+
148+
y : object
149+
Always ignored, exists for compatibility.
150+
151+
groups : object
152+
Always ignored, exists for compatibility.
153+
154+
Returns
155+
-------
156+
n_splits : int
157+
Returns the number of splitting iterations in the cross-validator.
158+
"""
159+
return self.n_splits
160+
140161
def _get_folds_from_splits(self, splits, sample_size):
141162
folds = []
142163
sample_indices = np.arange(sample_size)
@@ -213,6 +234,27 @@ def split(self, X, y, sample_weight=None):
213234
"""
214235
return _split_weighted_sample(self, X, y, sample_weight, is_stratified=True)
215236

237+
def get_n_splits(self, X, y, groups=None):
238+
"""Return the number of splitting iterations in the cross-validator.
239+
240+
Parameters
241+
----------
242+
X : object
243+
Always ignored, exists for compatibility.
244+
245+
y : object
246+
Always ignored, exists for compatibility.
247+
248+
groups : object
249+
Always ignored, exists for compatibility.
250+
251+
Returns
252+
-------
253+
n_splits : int
254+
Returns the number of splitting iterations in the cross-validator.
255+
"""
256+
return self.n_splits
257+
216258

217259
class GridSearchCVList(BaseEstimator):
218260
""" An extension of GridSearchCV that allows for passing a list of estimators each with their own

0 commit comments

Comments
 (0)