@@ -905,6 +905,30 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A
905905 }
906906 return hyperparameters
907907
@staticmethod
def _nova_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
    """Applies JSON encoding for Nova job hyperparameters, preserving string values.

    For Nova jobs, string values should not be JSON-encoded.

    Args:
        hyperparameters (dict): Dictionary of hyperparameters.

    Returns:
        dict: Dictionary with encoded hyperparameters, or ``None`` if the
        input was ``None``.
    """
    if hyperparameters is None:
        return hyperparameters

    def _encode_value(value):
        # Pipeline variables are resolved at execution time; render their
        # placeholder string rather than JSON-encoding the variable object.
        if is_pipeline_variable(value):
            return value.to_string()
        # Nova jobs expect raw strings, so leave them unescaped.
        if isinstance(value, str):
            return value
        return json.dumps(value)

    return {str(key): _encode_value(value) for key, value in hyperparameters.items()}
931+
908932 def _prepare_for_training (self , job_name = None ):
909933 """Set any values in the estimator that need to be set before training.
910934
@@ -938,7 +962,11 @@ def _prepare_for_training(self, job_name=None):
938962 self .source_dir = updated_paths ["source_dir" ]
939963 self .dependencies = updated_paths ["dependencies" ]
940964
941- if self .source_dir or self .entry_point or self .dependencies :
965+ if (
966+ self .source_dir
967+ or self .entry_point
968+ or (self .dependencies and len (self .dependencies ) > 0 )
969+ ):
942970 # validate source dir will raise a ValueError if there is something wrong with
943971 # the source directory. We are intentionally not handling it because this is a
944972 # critical error.
@@ -3579,7 +3607,11 @@ def __init__(
35793607 git_config = git_config ,
35803608 enable_network_isolation = enable_network_isolation ,
35813609 )
3582- if not is_pipeline_variable (entry_point ) and entry_point .startswith ("s3://" ):
3610+ if (
3611+ not is_pipeline_variable (entry_point )
3612+ and entry_point is not None
3613+ and entry_point .startswith ("s3://" )
3614+ ):
35833615 raise ValueError (
35843616 "Invalid entry point script: {}. Must be a path to a local file." .format (
35853617 entry_point
@@ -3599,6 +3631,7 @@ def __init__(
35993631 self .checkpoint_s3_uri = checkpoint_s3_uri
36003632 self .checkpoint_local_path = checkpoint_local_path
36013633 self .enable_sagemaker_metrics = enable_sagemaker_metrics
3634+ self .is_nova_job = kwargs .get ("is_nova_job" , False )
36023635
36033636 def _prepare_for_training (self , job_name = None ):
36043637 """Set hyperparameters needed for training. This method will also validate ``source_dir``.
@@ -3713,7 +3746,10 @@ def _model_entry_point(self):
37133746
def set_hyperparameters(self, **kwargs):
    """Escapes the dict argument as JSON, updates the private hyperparameter attribute."""
    # Nova jobs keep string values unescaped; all other jobs JSON-encode every value.
    encoder = (
        EstimatorBase._nova_encode_hyperparameters
        if self.is_nova_job
        else EstimatorBase._json_encode_hyperparameters
    )
    self._hyperparameters.update(encoder(kwargs))
37173753
def hyperparameters(self):
    """Returns the hyperparameters as a dictionary to use for training.

    Returns:
        dict[str, str]: The hyperparameters.
    """
    # Nova jobs keep string values unescaped; all other jobs JSON-encode every value.
    encoder = (
        EstimatorBase._nova_encode_hyperparameters
        if self.is_nova_job
        else EstimatorBase._json_encode_hyperparameters
    )
    return encoder(self._hyperparameters)
37283767
37293768 @classmethod
37303769 def _prepare_init_params_from_job_description (cls , job_details , model_channel_name = None ):
0 commit comments