@@ -866,16 +866,10 @@ def _create_sagemaker_model(
866866 # _base_name, model_name are not needed under PipelineSession.
867867 # the model_data may be Pipeline variable
868868 # which may break the _base_name generation
869- model_uri = None
870- if isinstance (self .model_data , (str , PipelineVariable )):
871- model_uri = self .model_data
872- elif isinstance (self .model_data , dict ):
873- model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
874-
875869 self ._ensure_base_name_if_needed (
876870 image_uri = container_def ["Image" ],
877871 script_uri = self .source_dir ,
878- model_uri = model_uri ,
872+ model_uri = self . _get_model_uri () ,
879873 )
880874 self ._set_model_name_if_needed ()
881875
@@ -912,6 +906,14 @@ def _create_sagemaker_model(
912906 )
913907 self .sagemaker_session .create_model (** create_model_args )
914908
909+ def _get_model_uri (self ):
910+ model_uri = None
911+ if isinstance (self .model_data , (str , PipelineVariable )):
912+ model_uri = self .model_data
913+ elif isinstance (self .model_data , dict ):
914+ model_uri = self .model_data .get ("S3DataSource" , {}).get ("S3Uri" , None )
915+ return model_uri
916+
915917 def _ensure_base_name_if_needed (self , image_uri , script_uri , model_uri ):
916918 """Create a base name from the image URI if there is no model name provided.
917919
@@ -1496,7 +1498,7 @@ def deploy(
14961498 self ._ensure_base_name_if_needed (
14971499 image_uri = self .image_uri ,
14981500 script_uri = self .source_dir ,
1499- model_uri = self .model_data ,
1501+ model_uri = self ._get_model_uri () ,
15001502 )
15011503 if self ._base_name is not None :
15021504 self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
0 commit comments