 from sagemaker.estimator import Framework, EstimatorBase
 from sagemaker.inputs import TrainingInput, FileSystemInput
 from sagemaker.job import _Job
-from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
+from sagemaker.jumpstart.utils import (
+    add_jumpstart_tags,
+    get_jumpstart_base_name_if_jumpstart_model,
+)
 from sagemaker.parameter import (
     CategoricalParameter,
     ContinuousParameter,
     IntegerParameter,
     ParameterRange,
 )
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
 
 from sagemaker.session import Session
-from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, to_string
+from sagemaker.utils import (
+    base_from_name,
+    base_name_from_image,
+    name_from_base,
+    to_string,
+)
 
 AMAZON_ESTIMATOR_MODULE = "sagemaker"
 AMAZON_ESTIMATOR_CLS_NAMES = {
 HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName"
 PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs"
 WARM_START_TYPE = "WarmStartType"
+GRID_SEARCH = "GridSearch"
 
 logger = logging.getLogger(__name__)
@@ -165,7 +174,8 @@ def from_job_desc(cls, warm_start_config):
             parents.append(parent[HYPERPARAMETER_TUNING_JOB_NAME])
 
         return cls(
-            warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]), parents=parents
+            warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]),
+            parents=parents,
         )
 
     def to_input_req(self):
@@ -219,7 +229,7 @@ def __init__(
         metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
         strategy: Union[str, PipelineVariable] = "Bayesian",
         objective_type: Union[str, PipelineVariable] = "Maximize",
-        max_jobs: Union[int, PipelineVariable] = 1,
+        max_jobs: Union[int, PipelineVariable] = None,
         max_parallel_jobs: Union[int, PipelineVariable] = 1,
         tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
         base_tuning_job_name: Optional[str] = None,
@@ -258,7 +268,8 @@ def __init__(
                 evaluating training jobs. This value can be either 'Minimize' or
                 'Maximize' (default: 'Maximize').
             max_jobs (int or PipelineVariable): Maximum total number of training jobs to start for
-                the hyperparameter tuning job (default: 1).
+                the hyperparameter tuning job. The default value is unspecified for the
+                GridSearch strategy and is 1 for all other strategies (default: None).
             max_parallel_jobs (int or PipelineVariable): Maximum number of parallel training jobs to
                 start (default: 1).
             tags (list[dict[str, str]] or list[dict[str, PipelineVariable]]): List of tags for
311322
312323 self .strategy = strategy
313324 self .objective_type = objective_type
325+ # For the GridSearch strategy we expect the max_jobs equals None and recalculate it later.
326+ # For all other strategies for the backward compatibility we keep
327+ # the default value as 1 (previous default value).
314328 self .max_jobs = max_jobs
329+ if max_jobs is None and strategy is not GRID_SEARCH :
330+ self .max_jobs = 1
315331 self .max_parallel_jobs = max_parallel_jobs
316332
317333 self .tags = tags
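
A minimal usage sketch of the new default (the estimator object, metric name,
metric regex, and parameter range below are illustrative assumptions, not part
of this change):

    from sagemaker.tuner import CategoricalParameter, HyperparameterTuner

    tuner = HyperparameterTuner(
        estimator=my_estimator,  # assumed: an already-configured Estimator
        objective_metric_name="validation:accuracy",  # hypothetical metric
        metric_definitions=[
            {"Name": "validation:accuracy", "Regex": "accuracy=(\\S+)"}
        ],
        hyperparameter_ranges={
            "batch_size": CategoricalParameter([32, 64, 128]),
        },
        strategy="GridSearch",  # max_jobs stays None; recalculated from the grid
    )
    # With any other strategy (e.g. the default "Bayesian"), omitting max_jobs
    # still falls back to the previous default of 1.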
@@ -373,7 +389,8 @@ def _prepare_job_name_for_tuning(self, job_name=None):
                 self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
             )
             base_name = base_name_from_image(
-                estimator.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
+                estimator.training_image_uri(),
+                default_base_name=EstimatorBase.JOB_CLASS_NAME,
             )
 
             jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
@@ -434,7 +451,15 @@ def _prepare_static_hyperparameters(
     def fit(
         self,
         inputs: Optional[
-            Union[str, Dict, List, TrainingInput, FileSystemInput, RecordSet, FileSystemRecordSet]
+            Union[
+                str,
+                Dict,
+                List,
+                TrainingInput,
+                FileSystemInput,
+                RecordSet,
+                FileSystemRecordSet,
+            ]
         ] = None,
         job_name: Optional[str] = None,
         include_cls_metadata: Union[bool, Dict[str, bool]] = False,
@@ -524,7 +549,9 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estim
             allowed_keys=estimator_names,
         )
         self._validate_dict_argument(
-            name="estimator_kwargs", value=estimator_kwargs, allowed_keys=estimator_names
+            name="estimator_kwargs",
+            value=estimator_kwargs,
+            allowed_keys=estimator_names,
         )
 
         for (estimator_name, estimator) in self.estimator_dict.items():
@@ -546,7 +573,13 @@ def _prepare_estimator_for_tuning(cls, estimator, inputs, job_name, **kwargs):
         estimator._prepare_for_training(job_name)
 
     @classmethod
-    def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estimator_cls=None):
+    def attach(
+        cls,
+        tuning_job_name,
+        sagemaker_session=None,
+        job_details=None,
+        estimator_cls=None,
+    ):
         """Attach to an existing hyperparameter tuning job.
 
         Create a HyperparameterTuner bound to an existing hyperparameter
@@ -959,7 +992,8 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details):
 
         # Default to the BYO estimator
         return getattr(
-            importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE), cls.DEFAULT_ESTIMATOR_CLS_NAME
+            importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE),
+            cls.DEFAULT_ESTIMATOR_CLS_NAME,
        )
 
     @classmethod
@@ -1151,7 +1185,10 @@ def _validate_parameter_ranges(self, estimator, hyperparameter_ranges):
 
     def _validate_parameter_range(self, value_hp, parameter_range):
         """Placeholder docstring"""
-        for (parameter_range_key, parameter_range_value) in parameter_range.__dict__.items():
+        for (
+            parameter_range_key,
+            parameter_range_value,
+        ) in parameter_range.__dict__.items():
             if parameter_range_key == "scaling_type":
                 continue
@@ -1301,7 +1338,7 @@ def create(
         base_tuning_job_name=None,
         strategy="Bayesian",
         objective_type="Maximize",
-        max_jobs=1,
+        max_jobs=None,
         max_parallel_jobs=1,
         tags=None,
         warm_start_config=None,
@@ -1351,7 +1388,8 @@ def create(
             objective_type (str): The type of the objective metric for evaluating training jobs.
                 This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
             max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
-                tuning job (default: 1).
+                tuning job. The default value is unspecified for the GridSearch strategy
+                and is 1 for all other strategies (default: None).
             max_parallel_jobs (int): Maximum number of parallel training jobs to start
                 (default: 1).
             tags (list[dict]): List of tags for labeling the tuning job (default: None). For more,
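
The same None default flows through the multi-estimator factory method shown in
this hunk; a hedged sketch, with hypothetical estimator names, objects, and
metric names (none taken from this diff):

    tuner = HyperparameterTuner.create(
        estimator_dict={"estimator-a": estimator_a},  # hypothetical name/object
        objective_metric_name_dict={"estimator-a": "val:accuracy"},
        hyperparameter_ranges_dict={
            "estimator-a": {"optimizer": CategoricalParameter(["sgd", "adam"])},
        },
        strategy="GridSearch",  # as above, max_jobs is left as None for grid search
    )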