1919import os
2020import re
2121import copy
22+ from typing import List , Dict
2223
2324import sagemaker
2425from sagemaker import (
3839from sagemaker .async_inference import AsyncInferenceConfig
3940from sagemaker .predictor_async import AsyncPredictor
4041from sagemaker .workflow import is_pipeline_variable
42+ from sagemaker .workflow .pipeline_context import runnable_by_pipeline , PipelineSession
4143
4244LOGGER = logging .getLogger ("sagemaker" )
4345
@@ -289,6 +291,7 @@ def __init__(
289291 self .uploaded_code = None
290292 self .repacked_model_data = None
291293
294+ @runnable_by_pipeline
292295 def register (
293296 self ,
294297 content_types ,
@@ -310,12 +313,12 @@ def register(
310313 """Creates a model package for creating SageMaker models or listing on Marketplace.
311314
312315 Args:
313- content_types (list): The supported MIME types for the input data (default: None) .
314- response_types (list): The supported MIME types for the output data (default: None) .
316+ content_types (list): The supported MIME types for the input data.
317+ response_types (list): The supported MIME types for the output data.
315318 inference_instances (list): A list of the instance types that are used to
316- generate inferences in real-time (default: None) .
319+ generate inferences in real-time.
317320 transform_instances (list): A list of the instance types on which a transformation
318- job can be run or on which an endpoint can be deployed (default: None) .
321+ job can be run or on which an endpoint can be deployed.
319322 model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
320323 using `model_package_name` makes the Model Package un-versioned (default: None).
321324 model_package_group_name (str): Model Package Group name, exclusive to
@@ -366,12 +369,50 @@ def register(
366369 model_package = self .sagemaker_session .create_model_package_from_containers (
367370 ** model_pkg_args
368371 )
372+ if isinstance (self .sagemaker_session , PipelineSession ):
373+ return None
369374 return ModelPackage (
370375 role = self .role ,
371376 model_data = self .model_data ,
372377 model_package_arn = model_package .get ("ModelPackageArn" ),
373378 )
374379
@runnable_by_pipeline
def create(
    self,
    instance_type: str = None,
    accelerator_type: str = None,
    serverless_inference_config: ServerlessInferenceConfig = None,
    tags: List[Dict[str, str]] = None,
):
    """Create a SageMaker Model Entity.

    This is a thin public wrapper: every argument is forwarded unchanged to
    ``_create_sagemaker_model`` and nothing is returned.

    Args:
        instance_type (str): The EC2 instance type that this Model will be
            used for. It is only used to determine whether the image needs
            GPU support or not (default: None).
        accelerator_type (str): Type of Elastic Inference accelerator to
            attach to an endpoint for model loading and inference, for
            example, 'ml.eia1.medium'. If not specified, no Elastic
            Inference accelerator will be attached to the endpoint
            (default: None).
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance
            type is not provided in serverless inference, so this is used to
            find image URIs (default: None).
        tags (List[Dict[str, str]]): The list of tags to add to the model
            (default: None). Example::

                tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]

            For more information about tags, see
            https://boto3.amazonaws.com/v1/documentation
            /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
    """
    # TODO: we should replace _create_sagemaker_model() with create()
    # Collect the keyword arguments first, then forward them in one call.
    model_kwargs = dict(
        instance_type=instance_type,
        accelerator_type=accelerator_type,
        tags=tags,
        serverless_inference_config=serverless_inference_config,
    )
    self._create_sagemaker_model(**model_kwargs)
415+
375416 def _init_sagemaker_session_if_does_not_exist (self , instance_type = None ):
376417 """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
377418
@@ -455,6 +496,24 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
455496 if repack and self .model_data is not None and self .entry_point is not None :
456497 if is_pipeline_variable (self .model_data ):
457498 # model is not yet there, defer repacking to later during pipeline execution
499+ if not isinstance (self .sagemaker_session , PipelineSession ):
500+ # TODO: link the doc in the warning once ready
501+ logging .warning (
502+ "The model_data is a Pipeline variable of type %s, "
503+ "which should be used under `PipelineSession` and "
504+ "leverage `ModelStep` to create or register model. "
505+ "Otherwise some functionalities e.g. "
506+ "runtime repack may be missing" ,
507+ type (self .model_data ),
508+ )
509+ return
510+ self .sagemaker_session .context .need_runtime_repack .add (id (self ))
511+ # Add the uploaded_code and repacked_model_data to update the container env
512+ self .repacked_model_data = self .model_data
513+ self .uploaded_code = fw_utils .UploadedCode (
514+ s3_prefix = self .repacked_model_data ,
515+ script_name = os .path .basename (self .entry_point ),
516+ )
458517 return
459518 if local_code and self .model_data .startswith ("file://" ):
460519 repacked_model_data = self .model_data
@@ -538,22 +597,29 @@ def _create_sagemaker_model(
538597 serverless_inference_config = serverless_inference_config ,
539598 )
540599
541- self ._ensure_base_name_if_needed (
542- image_uri = container_def ["Image" ], script_uri = self .source_dir , model_uri = self .model_data
543- )
544- self ._set_model_name_if_needed ()
600+ if not isinstance (self .sagemaker_session , PipelineSession ):
601+ # _base_name and model_name are not needed under PipelineSession.
602+ # Moreover, the model_data may be a Pipeline variable,
603+ # which would break the _base_name generation.
604+ self ._ensure_base_name_if_needed (
605+ image_uri = container_def ["Image" ],
606+ script_uri = self .source_dir ,
607+ model_uri = self .model_data ,
608+ )
609+ self ._set_model_name_if_needed ()
545610
546611 enable_network_isolation = self .enable_network_isolation ()
547612
548613 self ._init_sagemaker_session_if_does_not_exist (instance_type )
549- self . sagemaker_session . create_model (
550- self .name ,
551- self .role ,
552- container_def ,
614+ create_model_args = dict (
615+ name = self .name ,
616+ role = self .role ,
617+ container_defs = container_def ,
553618 vpc_config = self .vpc_config ,
554619 enable_network_isolation = enable_network_isolation ,
555620 tags = tags ,
556621 )
622+ self .sagemaker_session .create_model (** create_model_args )
557623
558624 def _ensure_base_name_if_needed (self , image_uri , script_uri , model_uri ):
559625 """Create a base name from the image URI if there is no model name provided.
0 commit comments