2323from sagemaker import model_uris
2424from sagemaker .serve .model_server .djl_serving .prepare import prepare_djl_js_resources
2525from sagemaker .serve .model_server .djl_serving .utils import _get_admissible_tensor_parallel_degrees
26+ from sagemaker .serve .model_server .multi_model_server .prepare import prepare_mms_js_resources
2627from sagemaker .serve .model_server .tgi .prepare import prepare_tgi_js_resources , _create_dir_structure
2728from sagemaker .serve .mode .function_pointers import Mode
2829from sagemaker .serve .utils .exceptions import (
3536from sagemaker .serve .utils .predictors import (
3637 DjlLocalModePredictor ,
3738 TgiLocalModePredictor ,
39+ TransformersLocalModePredictor ,
3840)
3941from sagemaker .serve .utils .local_hardware import (
4042 _get_nb_instance ,
@@ -90,6 +92,7 @@ def __init__(self):
9092 self .existing_properties = None
9193 self .prepared_for_tgi = None
9294 self .prepared_for_djl = None
95+ self .prepared_for_mms = None
9396 self .schema_builder = None
9497 self .nb_instance_type = None
9598 self .ram_usage_model_load = None
@@ -137,7 +140,11 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
137140
138141 if overwrite_mode == Mode .SAGEMAKER_ENDPOINT :
139142 self .mode = self .pysdk_model .mode = Mode .SAGEMAKER_ENDPOINT
140- if not hasattr (self , "prepared_for_djl" ) or not hasattr (self , "prepared_for_tgi" ):
143+ if (
144+ not hasattr (self , "prepared_for_djl" )
145+ or not hasattr (self , "prepared_for_tgi" )
146+ or not hasattr (self , "prepared_for_mms" )
147+ ):
141148 self .pysdk_model .model_data , env = self ._prepare_for_mode ()
142149 elif overwrite_mode == Mode .LOCAL_CONTAINER :
143150 self .mode = self .pysdk_model .mode = Mode .LOCAL_CONTAINER
@@ -160,6 +167,13 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
160167 dependencies = self .dependencies ,
161168 model_data = self .pysdk_model .model_data ,
162169 )
170+ elif not hasattr (self , "prepared_for_mms" ):
171+ self .js_model_config , self .prepared_for_mms = prepare_mms_js_resources (
172+ model_path = self .model_path ,
173+ js_id = self .model ,
174+ dependencies = self .dependencies ,
175+ model_data = self .pysdk_model .model_data ,
176+ )
163177
164178 self ._prepare_for_mode ()
165179 env = {}
@@ -179,6 +193,10 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
179193 predictor = TgiLocalModePredictor (
180194 self .modes [str (Mode .LOCAL_CONTAINER )], serializer , deserializer
181195 )
196+ elif self .model_server == ModelServer .MMS :
197+ predictor = TransformersLocalModePredictor (
198+ self .modes [str (Mode .LOCAL_CONTAINER )], serializer , deserializer
199+ )
182200
183201 ram_usage_before = _get_ram_usage_mb ()
184202 self .modes [str (Mode .LOCAL_CONTAINER )].create_server (
@@ -254,6 +272,24 @@ def _build_for_tgi_jumpstart(self):
254272
255273 self .pysdk_model .env .update (env )
256274
275+ def _build_for_mms_jumpstart (self ):
276+ """Placeholder docstring"""
277+
278+ env = {}
279+ if self .mode == Mode .LOCAL_CONTAINER :
280+ if not hasattr (self , "prepared_for_mms" ):
281+ self .js_model_config , self .prepared_for_mms = prepare_mms_js_resources (
282+ model_path = self .model_path ,
283+ js_id = self .model ,
284+ dependencies = self .dependencies ,
285+ model_data = self .pysdk_model .model_data ,
286+ )
287+ self ._prepare_for_mode ()
288+ elif self .mode == Mode .SAGEMAKER_ENDPOINT and hasattr (self , "prepared_for_mms" ):
289+ self .pysdk_model .model_data , env = self ._prepare_for_mode ()
290+
291+ self .pysdk_model .env .update (env )
292+
257293 def _tune_for_js (self , sharded_supported : bool , max_tuning_duration : int = 1800 ):
258294 """Tune for Jumpstart Models in Local Mode.
259295
@@ -264,11 +300,6 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
264300 returns:
265301 Tuned Model.
266302 """
267- if self .mode != Mode .LOCAL_CONTAINER :
268- logger .warning (
269- "Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
270- )
271- return self .pysdk_model
272303
273304 num_shard_env_var_name = "SM_NUM_GPUS"
274305 if "OPTION_TENSOR_PARALLEL_DEGREE" in self .pysdk_model .env .keys ():
@@ -437,42 +468,58 @@ def _build_for_jumpstart(self):
437468 self .secret_key = None
438469 self .jumpstart = True
439470
440- pysdk_model = self ._create_pre_trained_js_model ()
471+ self .pysdk_model = self ._create_pre_trained_js_model ()
472+ self .pysdk_model .tune = lambda * args , ** kwargs : self ._default_tune ()
441473
442- image_uri = pysdk_model .image_uri
474+ logger .info (
475+ "JumpStart ID %s is packaged with Image URI: %s" , self .model , self .pysdk_model .image_uri
476+ )
443477
444- logger .info ("JumpStart ID %s is packaged with Image URI: %s" , self .model , image_uri )
478+ if self .mode != Mode .SAGEMAKER_ENDPOINT :
479+ if self ._is_gated_model (self .pysdk_model ):
480+ raise ValueError (
481+ "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
482+ )
445483
446- if self . _is_gated_model ( pysdk_model ) and self .mode != Mode . SAGEMAKER_ENDPOINT :
447- raise ValueError (
448- "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
449- )
484+ if "djl-inference" in self .pysdk_model . image_uri :
485+ logger . info ( "Building for DJL JumpStart Model ID..." )
486+ self . model_server = ModelServer . DJL_SERVING
487+ self . image_uri = self . pysdk_model . image_uri
450488
451- if "djl-inference" in image_uri :
452- logger .info ("Building for DJL JumpStart Model ID..." )
453- self .model_server = ModelServer .DJL_SERVING
489+ self ._build_for_djl_jumpstart ()
454490
455- self .pysdk_model = pysdk_model
456- self .image_uri = self .pysdk_model .image_uri
491+ self .pysdk_model .tune = self .tune_for_djl_jumpstart
492+ elif "tgi-inference" in self .pysdk_model .image_uri :
493+ logger .info ("Building for TGI JumpStart Model ID..." )
494+ self .model_server = ModelServer .TGI
495+ self .image_uri = self .pysdk_model .image_uri
457496
458- self ._build_for_djl_jumpstart ()
497+ self ._build_for_tgi_jumpstart ()
459498
460- self .pysdk_model .tune = self .tune_for_djl_jumpstart
461- elif "tgi-inference" in image_uri :
462- logger .info ("Building for TGI JumpStart Model ID..." )
463- self .model_server = ModelServer .TGI
499+ self .pysdk_model .tune = self .tune_for_tgi_jumpstart
500+ elif "huggingface-pytorch-inference:" in self .pysdk_model .image_uri :
501+ logger .info ("Building for MMS JumpStart Model ID..." )
502+ self .model_server = ModelServer .MMS
503+ self .image_uri = self .pysdk_model .image_uri
464504
465- self .pysdk_model = pysdk_model
466- self .image_uri = self .pysdk_model .image_uri
505+ self ._build_for_mms_jumpstart ()
506+ else :
507+ raise ValueError (
508+ "JumpStart Model ID was not packaged "
509+ "with djl-inference, tgi-inference, or mms-inference container."
510+ )
467511
468- self ._build_for_tgi_jumpstart ()
512+ return self .pysdk_model
469513
470- self .pysdk_model .tune = self .tune_for_tgi_jumpstart
471- else :
472- raise ValueError (
473- "JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
474- )
514+ def _default_tune (self ):
515+ """Logs a warning message if tune is invoked on endpoint mode.
475516
517+ Returns:
518+ Jumpstart Model: ``This`` model
519+ """
520+ logger .warning (
521+ "Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
522+ )
476523 return self .pysdk_model
477524
478525 def _is_gated_model (self , model ) -> bool :
0 commit comments