Skip to content

Commit 0ef90ff

Browse files
committed
Add cross-validation seed property and kwargs in HyperparameterSearchSettings.set_{kfold,single_split}_validation
1 parent 7062cfa commit 0ef90ff

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

dataikuapi/dss/ml.py

Lines changed: 32 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -433,7 +433,7 @@ def _repr_html_(self):
433433
res += self._key_repr("splitRatio")
434434
elif self._raw_settings["mode"] in {"KFOLD", "TIME_SERIES_KFOLD"}:
435435
res += self._key_repr("nFolds")
436-
436+
res += self._key_repr("cvSeed")
437437
res += self._key_repr("stratified")
438438

439439
res += "Execution Settings:\n"
@@ -471,6 +471,14 @@ def _set_seed(self, seed):
471471
else:
472472
self._raw_settings["seed"] = seed
473473

474+
def _set_cv_seed(self, seed):
    """
    Stores the cross-validation seed in the raw settings after validating it.

    ``None`` leaves the current value untouched. A value that is not a plain integer is ignored and reported
    through warnings instead of raising.

    :param seed: value to store under the "cvSeed" key, or None to keep the current setting
    :type seed: int
    """
    if seed is None:
        return
    # bool is a subclass of int, so it has to be excluded explicitly: True/False is not a meaningful seed
    if isinstance(seed, bool) or not isinstance(seed, int):
        warnings.warn("HyperparameterSearchSettings ignoring invalid input: seed")
        warnings.warn("seed must be an integer")
        return
    self._raw_settings["cvSeed"] = seed
481+
474482
@property
475483
def strategy(self):
476484
"""
@@ -544,7 +552,24 @@ def validation_mode(self, mode):
544552
assert mode in {"KFOLD", "SHUFFLE", "TIME_SERIES_KFOLD", "TIME_SERIES_SINGLE_SPLIT", "CUSTOM"}
545553
self._raw_settings["mode"] = mode
546554

547-
def set_kfold_validation(self, n_folds=5, stratified=True):
555+
@property
def cv_seed(self):
    """
    :return: cross-validation seed for splitting the data during hyperparameter search
    :rtype: int
    """
    return self._raw_settings["cvSeed"]

@cv_seed.setter
def cv_seed(self, seed):
    """
    :param seed: cross-validation seed for splitting the data during hyperparameter search
    :type seed: int
    :raises AssertionError: if seed is not an integer (booleans are rejected as well)
    """
    # Raised explicitly rather than through an `assert` statement so the check still runs under `python -O`;
    # AssertionError is kept so the failure type is the same as before.
    # bool is a subclass of int and would otherwise be accepted as a seed.
    if isinstance(seed, bool) or not isinstance(seed, int):
        raise AssertionError("cv_seed must be an integer")
    self._raw_settings["cvSeed"] = seed
571+
572+
def set_kfold_validation(self, n_folds=5, stratified=True, cv_seed=0):
548573
"""
549574
Sets the validation mode to k-fold cross-validation (either "KFOLD" or "TIME_SERIES_KFOLD" if time-based ordering
550575
is enabled).
@@ -570,8 +595,10 @@ def set_kfold_validation(self, n_folds=5, stratified=True):
570595
warnings.warn("stratified must be a boolean")
571596
else:
572597
self._raw_settings["stratified"] = stratified
598+
if cv_seed is not None:
599+
self._set_cv_seed(cv_seed)
573600

574-
def set_single_split_validation(self, split_ratio=0.8, stratified=True):
601+
def set_single_split_validation(self, split_ratio=0.8, stratified=True, cv_seed=0):
575602
"""
576603
Sets the validation mode to single split (either "SHUFFLE" or "TIME_SERIES_SINGLE_SPLIT" if time-based ordering
577604
is enabled).
@@ -597,6 +624,8 @@ def set_single_split_validation(self, split_ratio=0.8, stratified=True):
597624
warnings.warn("stratified must be a boolean")
598625
else:
599626
self._raw_settings["stratified"] = stratified
627+
if cv_seed is not None:
628+
self._set_cv_seed(cv_seed)
600629

601630
def set_custom_validation(self, code=None):
602631
"""

0 commit comments

Comments
 (0)