@@ -259,7 +259,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
259259 The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
260260 The first category will be treated as the control treatment.
261261
262- n_splits : int, cross-validation generator or an iterable
262+ cv : int, cross-validation generator or an iterable
263263 Determines the cross-validation splitting strategy.
264264 Possible inputs for cv are:
265265
@@ -333,7 +333,7 @@ def _gen_ortho_learner_model_final(self):
333333 np.random.seed(123)
334334 X = np.random.normal(size=(100, 3))
335335 y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.1, size=(100,))
336- est = OrthoLearner(n_splits =2, discrete_treatment=False, discrete_instrument=False,
336+ est = OrthoLearner(cv =2, discrete_treatment=False, discrete_instrument=False,
337337 categories='auto', random_state=None)
338338 est.fit(y, X[:, 0], W=X[:, 1:])
339339
@@ -391,7 +391,7 @@ def _gen_ortho_learner_model_final(self):
391391 import scipy.special
392392 T = np.random.binomial(1, scipy.special.expit(W[:, 0]))
393393 y = T + W[:, 0] + np.random.normal(0, 0.01, size=(100,))
394- est = OrthoLearner(n_splits =2, discrete_treatment=True, discrete_instrument=False,
394+ est = OrthoLearner(cv =2, discrete_treatment=True, discrete_instrument=False,
395395 categories='auto', random_state=None)
396396 est.fit(y, T, W=W)
397397
@@ -424,8 +424,9 @@ def _gen_ortho_learner_model_final(self):
424424 """
425425
426426 def __init__ (self , * ,
427- discrete_treatment , discrete_instrument , categories , n_splits , random_state ,
428- mc_iters = None , mc_agg = 'mean' ):
427+ discrete_treatment , discrete_instrument , categories , cv , random_state ,
428+ n_splits = 'raise' , mc_iters = None , mc_agg = 'mean' ):
429+ self .cv = cv
429430 self .n_splits = n_splits
430431 self .discrete_treatment = discrete_treatment
431432 self .discrete_instrument = discrete_instrument
@@ -566,7 +567,7 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No
566567 Sample variance for each sample
567568 groups: (n,) vector, optional
568569 All rows corresponding to the same group will be kept together during splitting.
569- If groups is not None, the n_splits argument passed to this class's initializer
570+ If groups is not None, the cv argument passed to this class's initializer
570571 must support a 'groups' argument to its split method.
571572 cache_values: bool, default False
572573 Whether to cache the inputs and computed nuisances, which will allow refitting a different final model
@@ -712,16 +713,16 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group
712713 if self .discrete_instrument :
713714 Z = self .z_transformer .transform (reshape (Z , (- 1 , 1 )))
714715
715- if self .n_splits == 1 : # special case, no cross validation
716+ if self .cv == 1 : # special case, no cross validation
716717 folds = None
717718 else :
718- splitter = check_cv (self .n_splits , [0 ], classifier = stratify )
719+ splitter = check_cv (self .cv , [0 ], classifier = stratify )
719720 # if check_cv produced a new KFold or StratifiedKFold object, we need to set shuffle and random_state
720721 # TODO: ideally, we'd also infer whether we need a GroupKFold (if groups are passed)
721722 # however, sklearn doesn't support both stratifying and grouping (see
722723 # https://github.com/scikit-learn/scikit-learn/issues/13621), so for now the user needs to supply
723724 # their own object that supports grouping if they want to use groups.
724- if splitter != self .n_splits and isinstance (splitter , (KFold , StratifiedKFold )):
725+ if splitter != self .cv and isinstance (splitter , (KFold , StratifiedKFold )):
725726 splitter .shuffle = True
726727 splitter .random_state = self ._random_state
727728
@@ -856,3 +857,18 @@ def models_nuisance_(self):
856857 if not hasattr (self , '_models_nuisance' ):
857858 raise AttributeError ("Model is not fitted!" )
858859 return self ._models_nuisance
860+
861+ #######################################################
862+ # These should be removed once `n_splits` is deprecated
863+ #######################################################
864+
865+ @property
866+ def n_splits (self ):
867+ return self .cv
868+
869+ @n_splits .setter
870+ def n_splits (self , value ):
871+ if value != 'raise' :
872+ warn ("Parameter `n_splits` has been deprecated and will be removed in the next version. "
873+ "Use parameter `cv` instead." )
874+ self .cv = value
0 commit comments