@@ -250,7 +250,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
250250 default = None , metadata = {"help" : "Define sagemaker session for execution" }
251251 )
252252 name : Optional [str ] = field (
253- default = "model-name-" + uuid .uuid1 ().hex ,
253+ default_factory = lambda : "model-name-" + uuid .uuid1 ().hex ,
254254 metadata = {"help" : "Define the model name" },
255255 )
256256 mode : Optional [Mode ] = field (
@@ -1130,7 +1130,7 @@ def build(
11301130 def _get_processing_unit (self ):
11311131 """Detects if the resource requirements are intended for a CPU or GPU instance."""
11321132 # Assume custom orchestrator will be deployed as an endpoint to a CPU instance
1133- if not self .resource_requirements :
1133+ if not self .resource_requirements or not self . resource_requirements . num_accelerators :
11341134 return "cpu"
11351135 for ic in self .modelbuilder_list or []:
11361136 if ic .resource_requirements .num_accelerators > 0 :
@@ -1171,10 +1171,10 @@ def _get_ic_resource_requirements(self, mb: ModelBuilder = None) -> ModelBuilder
11711171
11721172 @_capture_telemetry ("build_custom_orchestrator" )
11731173 def _get_smd_image_uri (self , processing_unit : str = None ) -> str :
1174- """Gets the SMD Inference URI.
1174+ """Gets the SMD Inference Image URI.
11751175
11761176 Returns:
1177- str: Pytorch DLC URI.
1177+ str: SMD Inference Image URI.
11781178 """
11791179 from sagemaker import image_uris
11801180 import sys
@@ -1183,10 +1183,10 @@ def _get_smd_image_uri(self, processing_unit: str = None) -> str:
11831183 from packaging .version import Version
11841184
11851185 formatted_py_version = f"py{ sys .version_info .major } { sys .version_info .minor } "
1186- if Version (f"{ sys .version_info .major } { sys .version_info .minor } " ) < Version ("3.11.11 " ):
1186+ if Version (f"{ sys .version_info .major } { sys .version_info .minor } " ) < Version ("3.12 " ):
11871187 raise ValueError (
11881188 f"Found Python version { formatted_py_version } but"
1189- f"Custom orchestrator deployment requires Python version >= 3.11.11 ."
1189+ f"Custom orchestrator deployment requires Python version >= 3.12 ."
11901190 )
11911191
11921192 INSTANCE_TYPES = {"cpu" : "ml.c5.xlarge" , "gpu" : "ml.g5.4xlarge" }
@@ -1957,7 +1957,7 @@ def deploy(
19571957 ] = None ,
19581958 update_endpoint : Optional [bool ] = False ,
19591959 custom_orchestrator_instance_type : str = None ,
1960- custom_orchestrator_initial_instance_count : int = 1 ,
1960+ custom_orchestrator_initial_instance_count : int = None ,
19611961 ** kwargs ,
19621962 ) -> Union [Predictor , Transformer , List [Predictor ]]:
19631963 """Deploys the built Model.
@@ -2054,13 +2054,14 @@ def deploy(
20542054 )
20552055 if self ._deployables .get ("CustomOrchestrator" , None ):
20562056 custom_orchestrator = self ._deployables .get ("CustomOrchestrator" )
2057+ if not custom_orchestrator_instance_type and not instance_type :
2058+ logger .warning (
2059+ "Deploying custom orchestrator as an endpoint but no instance type was "
2060+ "set. Defaulting to `ml.c5.xlarge`."
2061+ )
2062+ custom_orchestrator_instance_type = "ml.c5.xlarge"
2063+ custom_orchestrator_initial_instance_count = 1
20572064 if custom_orchestrator ["Mode" ] == "Endpoint" :
2058- if not custom_orchestrator_instance_type :
2059- logger .warning (
2060- "Deploying custom orchestrator as an endpoint but no instance type was "
2061- "set. Defaulting to `ml.c5.xlarge`."
2062- )
2063- custom_orchestrator_instance_type = "ml.c5.xlarge"
20642065 logger .info (
20652066 "Deploying custom orchestrator on instance type %s." ,
20662067 custom_orchestrator_instance_type ,
@@ -2073,13 +2074,18 @@ def deploy(
20732074 )
20742075 )
20752076 elif custom_orchestrator ["Mode" ] == "InferenceComponent" :
2077+ logger .info (
2078+ "Deploying custom orchestrator as an inference component "
2079+ f"to endpoint { endpoint_name } "
2080+ )
20762081 predictors .append (
20772082 self ._deploy_for_ic (
20782083 ic_data = custom_orchestrator ,
20792084 container_timeout_in_seconds = container_timeout_in_second ,
20802085 instance_type = custom_orchestrator_instance_type or instance_type ,
20812086 initial_instance_count = custom_orchestrator_initial_instance_count
20822087 or initial_instance_count ,
2088+ endpoint_name = endpoint_name ,
20832089 ** kwargs ,
20842090 )
20852091 )
0 commit comments