Skip to content

Commit 8c88413

Browse files
authored
Merge pull request #120 from dataiku/enhancement/cv-seed
Add cross-validation seed
2 parents daa6564 + 56be0f0 commit 8c88413

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

dataikuapi/dss/ml.py

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,7 @@ def _repr_html_(self):
453453
res += self._key_repr("splitRatio")
454454
elif self._raw_settings["mode"] in {"KFOLD", "TIME_SERIES_KFOLD"}:
455455
res += self._key_repr("nFolds")
456-
456+
res += self._key_repr("cvSeed")
457457
res += self._key_repr("stratified")
458458

459459
res += "Execution Settings:\n"
@@ -491,6 +491,11 @@ def _set_seed(self, seed):
491491
else:
492492
self._raw_settings["seed"] = seed
493493

494+
def _set_cv_seed(self, seed):
    """
    Store the cross-validation seed in the raw settings under the "cvSeed" key.

    A seed of None is ignored and leaves the current settings untouched.

    :param seed: cross-validation seed, or None to keep the existing value
    :type seed: int
    :raises AssertionError: if seed is neither None nor an integer
    """
    if seed is None:
        return
    # Raise explicitly instead of using an ``assert`` statement: asserts are stripped
    # when Python runs with -O, which would let a non-integer seed reach the settings.
    # The exception type and message are kept so existing callers see no change.
    if not isinstance(seed, int):
        raise AssertionError("HyperparameterSearchSettings invalid input: cvSeed must be an integer")
    self._raw_settings["cvSeed"] = seed
498+
494499
@property
495500
def strategy(self):
496501
"""
@@ -508,7 +513,7 @@ def strategy(self, strategy):
508513
assert strategy in {"GRID", "RANDOM", "BAYESIAN"}
509514
self._raw_settings["strategy"] = strategy
510515

511-
def set_grid_search(self, shuffle=True, seed=0):
516+
def set_grid_search(self, shuffle=True, seed=1337):
512517
"""
513518
Sets the search strategy to "GRID" to perform a grid-search on the hyperparameters.
514519
@@ -527,7 +532,7 @@ def set_grid_search(self, shuffle=True, seed=0):
527532
self._raw_settings["randomized"] = shuffle
528533
self._set_seed(seed)
529534

530-
def set_random_search(self, seed=0):
535+
def set_random_search(self, seed=1337):
531536
"""
532537
Sets the search strategy to "RANDOM" to perform a random search on the hyperparameters.
533538
@@ -537,7 +542,7 @@ def set_random_search(self, seed=0):
537542
self._raw_settings["strategy"] = "RANDOM"
538543
self._set_seed(seed)
539544

540-
def set_bayesian_search(self, seed=0):
545+
def set_bayesian_search(self, seed=1337):
541546
"""
542547
Sets the search strategy to "BAYESIAN" to perform a Bayesian search on the hyperparameters.
543548
@@ -564,7 +569,23 @@ def validation_mode(self, mode):
564569
assert mode in {"KFOLD", "SHUFFLE", "TIME_SERIES_KFOLD", "TIME_SERIES_SINGLE_SPLIT", "CUSTOM"}
565570
self._raw_settings["mode"] = mode
566571

567-
def set_kfold_validation(self, n_folds=5, stratified=True):
572+
@property
def cv_seed(self):
    """
    :return: cross-validation seed for splitting the data during hyperparameter search, or None when
        the settings do not contain a "cvSeed" entry
    :rtype: int
    """
    # .get() rather than indexing: settings without a "cvSeed" entry yield None
    # instead of raising a KeyError from a plain attribute read.
    return self._raw_settings.get("cvSeed")

@cv_seed.setter
def cv_seed(self, seed):
    """
    :param seed: cross-validation seed for splitting the data during hyperparameter search.
        Validation and storage are delegated to ``_set_cv_seed``.
    :type seed: int
    """
    self._set_cv_seed(seed)
587+
588+
def set_kfold_validation(self, n_folds=5, stratified=True, cv_seed=1337):
568589
"""
569590
Sets the validation mode to k-fold cross-validation (either "KFOLD" or "TIME_SERIES_KFOLD" if time-based ordering
570591
is enabled).
@@ -590,8 +611,9 @@ def set_kfold_validation(self, n_folds=5, stratified=True):
590611
warnings.warn("stratified must be a boolean")
591612
else:
592613
self._raw_settings["stratified"] = stratified
614+
self._set_cv_seed(cv_seed)
593615

594-
def set_single_split_validation(self, split_ratio=0.8, stratified=True):
616+
def set_single_split_validation(self, split_ratio=0.8, stratified=True, cv_seed=1337):
595617
"""
596618
Sets the validation mode to single split (either "SHUFFLE" or "TIME_SERIES_SINGLE_SPLIT" if time-based ordering
597619
is enabled).
@@ -617,6 +639,7 @@ def set_single_split_validation(self, split_ratio=0.8, stratified=True):
617639
warnings.warn("stratified must be a boolean")
618640
else:
619641
self._raw_settings["stratified"] = stratified
642+
self._set_cv_seed(cv_seed)
620643

621644
def set_custom_validation(self, code=None):
622645
"""

0 commit comments

Comments
 (0)