@@ -475,12 +475,22 @@ def model_post_init(self, __context: Any):
475475 if not os .path .exists (self .hyperparameters ):
476476 raise ValueError (f"Hyperparameters file not found: { self .hyperparameters } " )
477477 logger .info (f"Loading hyperparameters from file: { self .hyperparameters } " )
478- if self .hyperparameters .endswith (".json" ):
479- with open (self .hyperparameters , "r" ) as f :
480- self .hyperparameters = json .load (f )
481- elif self .hyperparameters .endswith (".yaml" ):
482- with open (self .hyperparameters , "r" ) as f :
483- self .hyperparameters = yaml .safe_load (f )
478+ with open (self .hyperparameters , "r" ) as f :
479+ contents = f .read ()
480+ try :
481+ self .hyperparameters = json .loads (contents )
482+ logger .debug ("Hyperparameters loaded as JSON" )
483+ except json .JSONDecodeError :
484+ try :
485+ self .hyperparameters = yaml .safe_load (contents )
486+ if not isinstance (self .hyperparameters , dict ):
487+ raise ValueError ("YAML content is not a valid mapping." )
488+ logger .debug ("Hyperparameters loaded as YAML" )
489+ except (yaml .YAMLError , ValueError ) as e :
490+ raise ValueError (
491+ f"Invalid hyperparameters file: { self .hyperparameters } . "
492+ "Must be a valid JSON or YAML file."
493+ )
484494
485495 if self .training_mode == Mode .SAGEMAKER_TRAINING_JOB and self .output_data_config is None :
486496 session = self .sagemaker_session
0 commit comments