@@ -413,6 +413,7 @@ def __init__(
         strategy_config: Optional[StrategyConfig] = None,
         early_stopping_type: Union[str, PipelineVariable] = "Off",
         estimator_name: Optional[str] = None,
+        random_seed: Optional[int] = None,
     ):
         """Creates a ``HyperparameterTuner`` instance.

@@ -470,6 +471,9 @@ def __init__(
             estimator_name (str): A unique name to identify an estimator within the
                 hyperparameter tuning job, when more than one estimator is used with
                 the same tuning job (default: None).
+            random_seed (int): An initial value used to initialize a pseudo-random number generator.
+                Setting a random seed will make the hyperparameter tuning search strategies
+                produce more consistent configurations for the same tuning job.
         """
         if hyperparameter_ranges is None or len(hyperparameter_ranges) == 0:
             raise ValueError("Need to specify hyperparameter ranges")
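For context, a minimal usage sketch of the new parameter. The estimator, metric name, and hyperparameter ranges below are placeholders, not part of this diff:

```python
from sagemaker.estimator import Estimator
from sagemaker.parameter import ContinuousParameter
from sagemaker.tuner import HyperparameterTuner

# Hypothetical estimator; any EstimatorBase subclass works the same way.
estimator = Estimator(
    image_uri="<training-image-uri>",
    role="<execution-role-arn>",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges={"learning_rate": ContinuousParameter(0.01, 0.2)},
    max_jobs=10,
    max_parallel_jobs=2,
    random_seed=42,  # new in this diff: seeds the search strategy for more consistent configs
)
```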
@@ -516,6 +520,7 @@ def __init__(
         self.latest_tuning_job = None
         self.warm_start_config = warm_start_config
         self.early_stopping_type = early_stopping_type
+        self.random_seed = random_seed

     def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
         """Prepare the tuner instance for tuning (fit)."""
@@ -1222,6 +1227,9 @@ def _prepare_init_params_from_job_description(cls, job_details):
             "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]),
         }

+        if "RandomSeed" in tuning_config:
+            params["random_seed"] = tuning_config["RandomSeed"]
+
         if "HyperParameterTuningJobObjective" in tuning_config:
             params["objective_metric_name"] = tuning_config["HyperParameterTuningJobObjective"][
                 "MetricName"
@@ -1483,6 +1491,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
                 warm_start_type=warm_start_type, parents=all_parents
             ),
             early_stopping_type=self.early_stopping_type,
+            random_seed=self.random_seed,
         )

         if len(self.estimator_dict) > 1:
@@ -1508,6 +1517,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato
             max_parallel_jobs=self.max_parallel_jobs,
             warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, parents=all_parents),
             early_stopping_type=self.early_stopping_type,
+            random_seed=self.random_seed,
         )

     @classmethod
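Both branches above mean warm-start child tuners inherit the parent's seed. A sketch using one of the SDK's existing warm-start helpers, which route through `_create_warm_start_tuner`:

```python
# tuner is a previously fitted HyperparameterTuner with random_seed set.
child = tuner.identical_dataset_and_algorithm_tuner()
assert child.random_seed == tuner.random_seed  # seed carries over
```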
@@ -1526,6 +1536,7 @@ def create(
         tags=None,
         warm_start_config=None,
         early_stopping_type="Off",
+        random_seed=None,
     ):
         """Factory method to create a ``HyperparameterTuner`` instance.

@@ -1586,6 +1597,9 @@ def create(
                 Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping
                 will not be attempted. If set to 'Auto', early stopping of some training jobs may
                 happen, but is not guaranteed to.
+            random_seed (int): An initial value used to initialize a pseudo-random number generator.
+                Setting a random seed will make the hyperparameter tuning search strategies
+                produce more consistent configurations for the same tuning job.

         Returns:
             sagemaker.tuner.HyperparameterTuner: a new ``HyperparameterTuner`` object that can
@@ -1624,6 +1638,7 @@ def create(
             tags=tags,
             warm_start_config=warm_start_config,
             early_stopping_type=early_stopping_type,
+            random_seed=random_seed,
         )

         for estimator_name in estimator_names[1:]:
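A hedged sketch of the factory path; `estimator_a`/`estimator_b` and the ranges are placeholders. The single seed is set on the tuner that `create` builds for the first estimator and is shared as the remaining estimators are added to it:

```python
from sagemaker.parameter import ContinuousParameter
from sagemaker.tuner import HyperparameterTuner

tuner = HyperparameterTuner.create(
    estimator_dict={"a": estimator_a, "b": estimator_b},
    objective_metric_name_dict={"a": "validation:auc", "b": "validation:auc"},
    hyperparameter_ranges_dict={
        "a": {"learning_rate": ContinuousParameter(0.01, 0.2)},
        "b": {"learning_rate": ContinuousParameter(0.01, 0.2)},
    },
    max_jobs=20,
    max_parallel_jobs=4,
    random_seed=42,  # one seed for the whole multi-estimator tuning job
)
```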
@@ -1775,6 +1790,9 @@ def _get_tuner_args(cls, tuner, inputs):
             "early_stopping_type": tuner.early_stopping_type,
         }

+        if tuner.random_seed is not None:
+            tuning_config["random_seed"] = tuner.random_seed
+
         if tuner.strategy_config is not None:
             tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()

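Downstream, an illustrative shape of the resulting arguments. The surrounding keys come from the visible hunk and the SDK's existing tuner args; mapping `random_seed` to the API's `RandomSeed` request field happens in the session layer, outside this diff, so that part is an assumption:

```python
# Illustrative tuning_config for a seeded tuner after _get_tuner_args runs:
tuning_config = {
    "strategy": "Bayesian",
    "max_jobs": 10,
    "max_parallel_jobs": 2,
    "early_stopping_type": "Off",
    "random_seed": 42,  # only present when tuner.random_seed is not None
}
```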