Skip to content

Commit 1dd73c7

Browse files
authored
Deprecate n_splits with cv (#362)
* deprecated n_splits with cv
1 parent b5c25cc commit 1dd73c7

26 files changed

+3298
-3232
lines changed

doc/spec/estimation/dml.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,10 @@ Usage FAQs
551551
If one uses cross-validated estimators as first stages, then model selection for the first stage models
552552
is performed automatically.
553553

554-
- **How should I set the parameter `n_splits`?**
554+
- **How should I set the parameter `cv`?**
555555

556556
This parameter defines the number of data partitions to create in order to fit the first stages in a
557-
crossfittin manner (see :class:`._OrthoLearner`). The default is 2, which
557+
crossfitting manner (see :class:`._OrthoLearner`). The default is 2, which
558558
is the minimum. However, larger values like 5 or 6 can lead to greater statistical stability of the method,
559559
especially if the number of samples is small. So we advise that for small datasets, one should raise this
560560
value. This can increase the computational cost as more first stage models are being fitted.

doc/spec/estimation/dr.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ Below we give a brief description of each of these classes:
242242
}, cv=10, n_jobs=-1, scoring='neg_mean_squared_error'
243243
)
244244
est = DRLearner(model_regression=model_reg(), model_propensity=model_clf(),
245-
model_final=model_reg(), n_splits=5)
245+
model_final=model_reg(), cv=5)
246246
est.fit(y, T, X=X, W=W)
247247
point = est.effect(X, T0=T0, T1=T1)
248248

@@ -427,7 +427,7 @@ Usage FAQs
427427
}, cv=5, n_jobs=-1, scoring='neg_mean_squared_error'
428428
)
429429
est = DRLearner(model_regression=model_reg(), model_propensity=model_clf(),
430-
model_final=model_reg(), n_splits=5)
430+
model_final=model_reg(), cv=5)
431431
est.fit(y, T, X=X, W=W)
432432
point = est.effect(X, T0=T0, T1=T1)
433433

@@ -467,10 +467,10 @@ Usage FAQs
467467
If one uses cross-validated estimators as first stages, then model selection for the first stage models
468468
is performed automatically.
469469

470-
- **How should I set the parameter `n_splits`?**
470+
- **How should I set the parameter `cv`?**
471471

472472
This parameter defines the number of data partitions to create in order to fit the first stages in a
473-
crossfittin manner (see :class:`._OrthoLearner`). The default is 2, which
473+
crossfitting manner (see :class:`._OrthoLearner`). The default is 2, which
474474
is the minimum. However, larger values like 5 or 6 can lead to greater statistical stability of the method,
475475
especially if the number of samples is small. So we advise that for small datasets, one should raise this
476476
value. This can increase the computational cost as more first stage models are being fitted.

econml/_ortho_learner.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
259259
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
260260
The first category will be treated as the control treatment.
261261
262-
n_splits: int, cross-validation generator or an iterable
262+
cv: int, cross-validation generator or an iterable
263263
Determines the cross-validation splitting strategy.
264264
Possible inputs for cv are:
265265
@@ -333,7 +333,7 @@ def _gen_ortho_learner_model_final(self):
333333
np.random.seed(123)
334334
X = np.random.normal(size=(100, 3))
335335
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.1, size=(100,))
336-
est = OrthoLearner(n_splits=2, discrete_treatment=False, discrete_instrument=False,
336+
est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
337337
categories='auto', random_state=None)
338338
est.fit(y, X[:, 0], W=X[:, 1:])
339339
@@ -391,7 +391,7 @@ def _gen_ortho_learner_model_final(self):
391391
import scipy.special
392392
T = np.random.binomial(1, scipy.special.expit(W[:, 0]))
393393
y = T + W[:, 0] + np.random.normal(0, 0.01, size=(100,))
394-
est = OrthoLearner(n_splits=2, discrete_treatment=True, discrete_instrument=False,
394+
est = OrthoLearner(cv=2, discrete_treatment=True, discrete_instrument=False,
395395
categories='auto', random_state=None)
396396
est.fit(y, T, W=W)
397397
@@ -424,8 +424,9 @@ def _gen_ortho_learner_model_final(self):
424424
"""
425425

426426
def __init__(self, *,
427-
discrete_treatment, discrete_instrument, categories, n_splits, random_state,
428-
mc_iters=None, mc_agg='mean'):
427+
discrete_treatment, discrete_instrument, categories, cv, random_state,
428+
n_splits='raise', mc_iters=None, mc_agg='mean'):
429+
self.cv = cv
429430
self.n_splits = n_splits
430431
self.discrete_treatment = discrete_treatment
431432
self.discrete_instrument = discrete_instrument
@@ -566,7 +567,7 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No
566567
Sample variance for each sample
567568
groups: (n,) vector, optional
568569
All rows corresponding to the same group will be kept together during splitting.
569-
If groups is not None, the n_splits argument passed to this class's initializer
570+
If groups is not None, the `cv` argument passed to this class's initializer
570571
must support a 'groups' argument to its split method.
571572
cache_values: bool, default False
572573
Whether to cache the inputs and computed nuisances, which will allow refitting a different final model
@@ -712,16 +713,16 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group
712713
if self.discrete_instrument:
713714
Z = self.z_transformer.transform(reshape(Z, (-1, 1)))
714715

715-
if self.n_splits == 1: # special case, no cross validation
716+
if self.cv == 1: # special case, no cross validation
716717
folds = None
717718
else:
718-
splitter = check_cv(self.n_splits, [0], classifier=stratify)
719+
splitter = check_cv(self.cv, [0], classifier=stratify)
719720
# if check_cv produced a new KFold or StratifiedKFold object, we need to set shuffle and random_state
720721
# TODO: ideally, we'd also infer whether we need a GroupKFold (if groups are passed)
721722
# however, sklearn doesn't support both stratifying and grouping (see
722723
# https://github.com/scikit-learn/scikit-learn/issues/13621), so for now the user needs to supply
723724
# their own object that supports grouping if they want to use groups.
724-
if splitter != self.n_splits and isinstance(splitter, (KFold, StratifiedKFold)):
725+
if splitter != self.cv and isinstance(splitter, (KFold, StratifiedKFold)):
725726
splitter.shuffle = True
726727
splitter.random_state = self._random_state
727728

@@ -856,3 +857,18 @@ def models_nuisance_(self):
856857
if not hasattr(self, '_models_nuisance'):
857858
raise AttributeError("Model is not fitted!")
858859
return self._models_nuisance
860+
861+
#######################################################
862+
# These should be removed once the deprecated `n_splits` parameter is dropped
863+
#######################################################
864+
865+
@property
866+
def n_splits(self):
867+
return self.cv
868+
869+
@n_splits.setter
870+
def n_splits(self, value):
871+
if value != 'raise':
872+
warn("Parameter `n_splits` has been deprecated and will be removed in the next version. "
873+
"Use parameter `cv` instead.")
874+
self.cv = value

econml/dml/_rlearner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class _RLearner(_OrthoLearner):
147147
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
148148
The first category will be treated as the control treatment.
149149
150-
n_splits: int, cross-validation generator or an iterable
150+
cv: int, cross-validation generator or an iterable
151151
Determines the cross-validation splitting strategy.
152152
Possible inputs for cv are:
153153
@@ -216,7 +216,7 @@ def _gen_rlearner_model_final(self):
216216
np.random.seed(123)
217217
X = np.random.normal(size=(1000, 3))
218218
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.01, size=(1000,))
219-
est = RLearner(n_splits=2, discrete_treatment=False, categories='auto', random_state=None)
219+
est = RLearner(cv=2, discrete_treatment=False, categories='auto', random_state=None)
220220
est.fit(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])
221221
222222
>>> est.const_marginal_effect(np.ones((1,1)))
@@ -261,10 +261,12 @@ def _gen_rlearner_model_final(self):
261261
is multidimensional, then the average of the MSEs for each dimension of Y is returned.
262262
"""
263263

264-
def __init__(self, *, discrete_treatment, categories, n_splits, random_state, mc_iters=None, mc_agg='mean'):
264+
def __init__(self, *, discrete_treatment, categories, cv, random_state,
265+
n_splits='raise', mc_iters=None, mc_agg='mean'):
265266
super().__init__(discrete_treatment=discrete_treatment,
266267
discrete_instrument=False, # no instrument, so doesn't matter
267268
categories=categories,
269+
cv=cv,
268270
n_splits=n_splits,
269271
random_state=random_state,
270272
mc_iters=mc_iters,
@@ -345,7 +347,7 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
345347
Sample variance for each sample
346348
groups: (n,) vector, optional
347349
All rows corresponding to the same group will be kept together during splitting.
348-
If groups is not None, the n_splits argument passed to this class's initializer
350+
If groups is not None, the `cv` argument passed to this class's initializer
349351
must support a 'groups' argument to its split method.
350352
cache_values: bool, default False
351353
Whether to cache inputs and first stage results, which will allow refitting a different final model

econml/dml/causal_forest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,8 @@ def __init__(self, *,
439439
cv = self.n_crossfit_splits
440440
super().__init__(discrete_treatment=discrete_treatment,
441441
categories=categories,
442-
# TODO. change to `cv=cv, n_splits='raise` when merged with the `n_splits` deprecation PR
443-
n_splits=cv,
442+
cv=cv,
443+
n_splits=n_crossfit_splits,
444444
mc_iters=mc_iters,
445445
mc_agg=mc_agg,
446446
random_state=random_state)

0 commit comments

Comments
 (0)