@@ -383,6 +383,83 @@ def to_input_req(self):
383383 }
384384
385385
386+ class InstanceConfig :
387+ """Instance configuration for training jobs started by hyperparameter tuning.
388+
389+ Contains the configuration(s) for one or more resources for processing hyperparameter jobs.
390+ These resources include compute instances and storage volumes to use in model training jobs
391+ launched by hyperparameter tuning jobs.
392+ """
393+
394+ def __init__ (
395+ self ,
396+ instance_count : Union [int , PipelineVariable ] = None ,
397+ instance_type : Union [str , PipelineVariable ] = None ,
398+ volume_size : Union [int , PipelineVariable ] = 30 ,
399+ ):
400+ """Creates a ``InstanceConfig`` instance.
401+
402+ It takes instance configuration information for training
403+ jobs that are created as the result of a hyperparameter tuning job.
404+
405+ Args:
406+ * instance_count (str or PipelineVariable): The number of compute instances of type
407+ InstanceType to use. For distributed training, select a value greater than 1.
408+ * instance_type (str or PipelineVariable):
409+ The instance type used to run hyperparameter optimization tuning jobs.
410+ * volume_size (int or PipelineVariable): The volume size in GB of the data to be
411+ processed for hyperparameter optimization
412+ """
413+ self .instance_count = instance_count
414+ self .instance_type = instance_type
415+ self .volume_size = volume_size
416+
417+ @classmethod
418+ def from_job_desc (cls , instance_config ):
419+ """Creates a ``InstanceConfig`` from an instance configuration response.
420+
421+ This is the instance configuration from the DescribeTuningJob response.
422+
423+ Args:
424+ instance_config (dict): The expected format of the
425+ ``instance_config`` contains one first-class field
426+
427+ Returns:
428+ sagemaker.tuner.InstanceConfig: De-serialized instance of
429+ InstanceConfig containing the strategy configuration.
430+ """
431+ return cls (
432+ instance_count = instance_config ["InstanceCount" ],
433+ instance_type = instance_config [" InstanceType " ],
434+ volume_size = instance_config ["VolumeSizeInGB" ],
435+ )
436+
437+ def to_input_req (self ):
438+ """Converts the ``self`` instance to the desired input request format.
439+
440+ Examples:
441+ >>> strategy_config = InstanceConfig(
442+ instance_count=1,
443+ instance_type='ml.m4.xlarge',
444+ volume_size=30
445+ )
446+ >>> strategy_config.to_input_req()
447+ {
448+ "InstanceCount":1,
449+ "InstanceType":"ml.m4.xlarge",
450+ "VolumeSizeInGB":30
451+ }
452+
453+ Returns:
454+ dict: Containing the instance configurations.
455+ """
456+ return {
457+ "InstanceCount" : self .instance_count ,
458+ "InstanceType" : self .instance_type ,
459+ "VolumeSizeInGB" : self .volume_size ,
460+ }
461+
462+
386463class HyperparameterTuner (object ):
387464 """Defines interaction with Amazon SageMaker hyperparameter tuning jobs.
388465
@@ -482,14 +559,14 @@ def __init__(
482559 self .estimator = None
483560 self .objective_metric_name = None
484561 self ._hyperparameter_ranges = None
562+ self .static_hyperparameters = None
485563 self .metric_definitions = None
486564 self .estimator_dict = {estimator_name : estimator }
487565 self .objective_metric_name_dict = {estimator_name : objective_metric_name }
488566 self ._hyperparameter_ranges_dict = {estimator_name : hyperparameter_ranges }
489567 self .metric_definitions_dict = (
490568 {estimator_name : metric_definitions } if metric_definitions is not None else {}
491569 )
492- self .static_hyperparameters = None
493570 else :
494571 self .estimator = estimator
495572 self .objective_metric_name = objective_metric_name
@@ -521,6 +598,31 @@ def __init__(
521598 self .warm_start_config = warm_start_config
522599 self .early_stopping_type = early_stopping_type
523600 self .random_seed = random_seed
601+ self .instance_configs_dict = None
602+ self .instance_configs = None
603+
604+ def override_resource_config (
605+ self , instance_configs : Union [List [InstanceConfig ], Dict [str , List [InstanceConfig ]]]
606+ ):
607+ """Override the instance configuration of the estimators used by the tuner.
608+
609+ Args:
610+ instance_configs (List[InstanceConfig] or Dict[str, List[InstanceConfig]):
611+ The InstanceConfigs to use as an override for the instance configuration
612+ of the estimator. ``None`` will remove the override.
613+ """
614+ if isinstance (instance_configs , dict ):
615+ self ._validate_dict_argument (
616+ name = "instance_configs" ,
617+ value = instance_configs ,
618+ allowed_keys = list (self .estimator_dict .keys ()),
619+ )
620+ self .instance_configs_dict = instance_configs
621+ else :
622+ self .instance_configs = instance_configs
623+ if self .estimator_dict is not None and self .estimator_dict .keys ():
624+ estimator_names = list (self .estimator_dict .keys ())
625+ self .instance_configs_dict = {estimator_names [0 ]: instance_configs }
524626
525627 def _prepare_for_tuning (self , job_name = None , include_cls_metadata = False ):
526628 """Prepare the tuner instance for tuning (fit)."""
@@ -589,7 +691,6 @@ def _prepare_job_name_for_tuning(self, job_name=None):
589691
590692 def _prepare_static_hyperparameters_for_tuning (self , include_cls_metadata = False ):
591693 """Prepare static hyperparameters for all estimators before tuning."""
592- self .static_hyperparameters = None
593694 if self .estimator is not None :
594695 self .static_hyperparameters = self ._prepare_static_hyperparameters (
595696 self .estimator , self ._hyperparameter_ranges , include_cls_metadata
@@ -1817,6 +1918,7 @@ def _get_tuner_args(cls, tuner, inputs):
18171918 estimator = tuner .estimator ,
18181919 static_hyperparameters = tuner .static_hyperparameters ,
18191920 metric_definitions = tuner .metric_definitions ,
1921+ instance_configs = tuner .instance_configs ,
18201922 )
18211923
18221924 if tuner .estimator_dict is not None :
@@ -1830,12 +1932,44 @@ def _get_tuner_args(cls, tuner, inputs):
18301932 tuner .objective_type ,
18311933 tuner .objective_metric_name_dict [estimator_name ],
18321934 tuner .hyperparameter_ranges_dict ()[estimator_name ],
1935+ tuner .instance_configs_dict .get (estimator_name , None )
1936+ if tuner .instance_configs_dict is not None
1937+ else None ,
18331938 )
18341939 for estimator_name in sorted (tuner .estimator_dict .keys ())
18351940 ]
18361941
18371942 return tuner_args
18381943
1944+ @staticmethod
1945+ def _prepare_hp_resource_config (
1946+ instance_configs : List [InstanceConfig ],
1947+ instance_count : int ,
1948+ instance_type : str ,
1949+ volume_size : int ,
1950+ volume_kms_key : str ,
1951+ ):
1952+ """Placeholder hpo resource config for one estimator of the tuner."""
1953+ resource_config = {}
1954+ if volume_kms_key is not None :
1955+ resource_config ["VolumeKmsKeyId" ] = volume_kms_key
1956+
1957+ if instance_configs is None :
1958+ resource_config ["InstanceCount" ] = instance_count
1959+ resource_config ["InstanceType" ] = instance_type
1960+ resource_config ["VolumeSizeInGB" ] = volume_size
1961+ else :
1962+ resource_config ["InstanceConfigs" ] = _TuningJob ._prepare_instance_configs (
1963+ instance_configs
1964+ )
1965+
1966+ return resource_config
1967+
1968+ @staticmethod
1969+ def _prepare_instance_configs (instance_configs : List [InstanceConfig ]):
1970+ """Prepare instance config for create tuning request."""
1971+ return [config .to_input_req () for config in instance_configs ]
1972+
18391973 @staticmethod
18401974 def _prepare_training_config (
18411975 inputs ,
@@ -1846,10 +1980,20 @@ def _prepare_training_config(
18461980 objective_type = None ,
18471981 objective_metric_name = None ,
18481982 parameter_ranges = None ,
1983+ instance_configs = None ,
18491984 ):
18501985 """Prepare training config for one estimator."""
18511986 training_config = _Job ._load_config (inputs , estimator )
18521987
1988+ del training_config ["resource_config" ]
1989+ training_config ["hpo_resource_config" ] = _TuningJob ._prepare_hp_resource_config (
1990+ instance_configs ,
1991+ estimator .instance_count ,
1992+ estimator .instance_type ,
1993+ estimator .volume_size ,
1994+ estimator .volume_kms_key ,
1995+ )
1996+
18531997 training_config ["input_mode" ] = estimator .input_mode
18541998 training_config ["metric_definitions" ] = metric_definitions
18551999
0 commit comments