     SM_CODE_CONTAINER_PATH,
     SM_DRIVERS,
     SM_DRIVERS_LOCAL_PATH,
+    SM_RECIPE,
+    SM_RECIPE_YAML,
+    SM_RECIPE_CONTAINER_PATH,
     TRAIN_SCRIPT,
     DEFAULT_CONTAINER_ENTRYPOINT,
     DEFAULT_CONTAINER_ARGUMENTS,
 from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
 from sagemaker.telemetry.constants import Feature
 from sagemaker.modules import logger
-from sagemaker.modules.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
+from sagemaker.modules.train.sm_recipes.utils import (
+    _get_args_from_recipe,
+    _determine_device_type,
+    _is_nova_recipe,
+    _load_base_recipe,
+)


 class Mode(Enum):
@@ -242,6 +250,7 @@ class ModelTrainer(BaseModel):
     _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
     _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)

+    _is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
     _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)

     CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [
@@ -449,6 +458,33 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
                     + "Must be a valid file within the 'source_dir'.",
                 )

+    @staticmethod
+    def _validate_and_load_hyperparameters_file(hyperparameters_file: str) -> Dict[str, Any]:
+        """Validate the hyperparameters file."""
+        if not os.path.exists(hyperparameters_file):
+            raise ValueError(f"Hyperparameters file not found: {hyperparameters_file}")
+        logger.info(f"Loading hyperparameters from file: {hyperparameters_file}")
+        with open(hyperparameters_file, "r") as f:
+            contents = f.read()
+            try:
+                hyperparameters = json.loads(contents)
+                logger.debug("Hyperparameters loaded as JSON")
+                return hyperparameters
+            except json.JSONDecodeError:
+                try:
+                    logger.info(f"contents: {contents}")
+                    hyperparameters = yaml.safe_load(contents)
+                    if not isinstance(hyperparameters, dict):
+                        raise ValueError("YAML contents must be a valid mapping")
+                    logger.info(f"hyperparameters: {hyperparameters}")
+                    logger.debug("Hyperparameters loaded as YAML")
+                    return hyperparameters
+                except (yaml.YAMLError, ValueError):
+                    raise ValueError(
+                        f"Invalid hyperparameters file: {hyperparameters_file}. "
+                        "Must be a valid JSON or YAML file."
+                    )
+
     def model_post_init(self, __context: Any):
         """Post init method to perform custom validation and set default values."""
         self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name)
@@ -510,27 +546,9 @@ def model_post_init(self, __context: Any):
             )

         if self.hyperparameters and isinstance(self.hyperparameters, str):
-            if not os.path.exists(self.hyperparameters):
-                raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
-            logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
-            with open(self.hyperparameters, "r") as f:
-                contents = f.read()
-                try:
-                    self.hyperparameters = json.loads(contents)
-                    logger.debug("Hyperparameters loaded as JSON")
-                except json.JSONDecodeError:
-                    try:
-                        logger.info(f"contents: {contents}")
-                        self.hyperparameters = yaml.safe_load(contents)
-                        if not isinstance(self.hyperparameters, dict):
-                            raise ValueError("YAML contents must be a valid mapping")
-                        logger.info(f"hyperparameters: {self.hyperparameters}")
-                        logger.debug("Hyperparameters loaded as YAML")
-                    except (yaml.YAMLError, ValueError):
-                        raise ValueError(
-                            f"Invalid hyperparameters file: {self.hyperparameters}. "
-                            "Must be a valid JSON or YAML file."
-                        )
+            self.hyperparameters = self._validate_and_load_hyperparameters_file(
+                self.hyperparameters
+            )

         if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
             if self.output_data_config is None:
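For reference, a minimal usage sketch of the file-based hyperparameters path added above. The training image URI, the file name, and its contents are illustrative assumptions rather than part of the change, and the import path follows the sagemaker.modules.train package this file lives in:

from sagemaker.modules.train import ModelTrainer

# hyperparameters.yaml (hypothetical contents):
#   learning_rate: 0.0001
#   epochs: 5
trainer = ModelTrainer(
    training_image="<ecr-image-uri>",        # placeholder training image
    hyperparameters="hyperparameters.yaml",  # a str value is parsed by model_post_init
)
# After construction, trainer.hyperparameters holds the parsed mapping,
# e.g. {"learning_rate": 0.0001, "epochs": 5}.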
@@ -613,6 +631,22 @@ def train(

         final_input_data_config = list(existing_channels.values()) + new_channels

+        if self._is_nova_recipe:
+            for input_data in final_input_data_config:
+                if input_data.channel_name == SM_RECIPE:
+                    raise ValueError(
+                        "Cannot use reserved channel name 'recipe' as an input channel name "
+                        "for Nova Recipe"
+                    )
+            recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
+            recipe_channel = self.create_input_data_channel(
+                channel_name=SM_RECIPE,
+                data_source=recipe_file_path,
+                key_prefix=input_data_key_prefix,
+            )
+            final_input_data_config.append(recipe_channel)
+            self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})
+
         if final_input_data_config:
             final_input_data_config = self._get_input_data_config(
                 final_input_data_config, input_data_key_prefix
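To illustrate the reserved-channel check above: for a Nova recipe, train() stages the recipe YAML itself under the 'recipe' channel (SM_RECIPE), so user-supplied channels must use other names. This is a hedged sketch; the S3 paths are placeholders, the InputData import path is assumed from sagemaker.modules.configs, and an existing ModelTrainer instance is assumed from earlier setup:

from sagemaker.modules.configs import InputData

# Allowed: a user channel with a non-reserved name; train() appends its own
# "recipe" channel pointing at the staged recipe file.
train_channel = InputData(channel_name="train", data_source="s3://<bucket>/train/")

# Rejected when _is_nova_recipe is set: "recipe" collides with the reserved channel.
# bad_channel = InputData(channel_name="recipe", data_source="s3://<bucket>/other/")

# trainer.train(input_data_config=[train_channel])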
@@ -1005,6 +1039,7 @@ def from_recipe(
         checkpoint_config: Optional[shapes.CheckpointConfig] = None,
         training_input_mode: Optional[str] = "File",
         environment: Optional[Dict[str, str]] = None,
+        hyperparameters: Optional[Union[Dict[str, Any], str]] = {},
         tags: Optional[List[Tag]] = None,
         sagemaker_session: Optional[Session] = None,
         role: Optional[str] = None,
@@ -1101,14 +1136,21 @@ def from_recipe(
         """
         if compute.instance_type is None:
             raise ValueError(
-                "Must set ``instance_type`` in compute_config when using training recipes."
+                "Must set ``instance_type`` in ``compute`` input when using training recipes."
             )
         device_type = _determine_device_type(compute.instance_type)
-        if device_type == "cpu":
+        recipe = _load_base_recipe(
+            training_recipe=training_recipe, recipe_overrides=recipe_overrides
+        )
+        is_nova = _is_nova_recipe(recipe=recipe)
+
+        if device_type == "cpu" and not is_nova:
             raise ValueError(
-                "Training recipes are not supported for CPU instances. "
+                "Training recipe is not supported for CPU instances. "
                 + "Please provide a GPU or Trainium instance type."
             )
+        if training_image is None and is_nova:
+            raise ValueError("training_image must be provided when using a Nova recipe.")

         if training_image_config and training_image is None:
             raise ValueError("training_image must be provided when using training_image_config.")
@@ -1126,15 +1168,27 @@ def from_recipe(
         # - distributed
         # - compute
         # - hyperparameters
-        model_trainer_args, recipe_train_dir = _get_args_from_recipe(
-            training_recipe=training_recipe,
+        model_trainer_args, tmp_dir = _get_args_from_recipe(
+            training_recipe=recipe,
             recipe_overrides=recipe_overrides,
             requirements=requirements,
             compute=compute,
             region_name=sagemaker_session.boto_region_name,
+            role=role,
         )
         if training_image is not None:
             model_trainer_args["training_image"] = training_image
+        if hyperparameters and not is_nova:
+            logger.warning(
+                "Hyperparameters are not supported for general training recipes. "
+                + "Ignoring hyperparameters input."
+            )
+        if is_nova:
+            if hyperparameters and isinstance(hyperparameters, str):
+                hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters)
+                model_trainer_args["hyperparameters"].update(hyperparameters)
+            elif hyperparameters and isinstance(hyperparameters, dict):
+                model_trainer_args["hyperparameters"].update(hyperparameters)

         model_trainer = cls(
             sagemaker_session=sagemaker_session,
@@ -1151,8 +1205,8 @@ def from_recipe(
             tags=tags,
             **model_trainer_args,
         )
-
-        model_trainer._temp_recipe_train_dir = recipe_train_dir
+        model_trainer._is_nova_recipe = is_nova
+        model_trainer._temp_recipe_train_dir = tmp_dir
         return model_trainer

     def with_tensorboard_output_config(
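Finally, a hedged end-to-end sketch of the from_recipe flow after this change: a Nova recipe requires an explicit training_image, and hyperparameters may be passed as a dict or as a JSON/YAML file path. The recipe name, image URI, instance type, and hyperparameter key below are placeholders, and the Compute import path is assumed from sagemaker.modules.configs:

from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import Compute

trainer = ModelTrainer.from_recipe(
    training_recipe="<nova-recipe-name-or-local-path>",  # placeholder recipe
    training_image="<nova-training-image-uri>",          # required for Nova recipes
    compute=Compute(instance_type="ml.p5.48xlarge", instance_count=1),
    hyperparameters={"max_steps": 100},                  # merged into recipe-derived args
)
trainer.train()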