@@ -300,6 +300,11 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
300300 returns:
301301 Tuned Model.
302302 """
303+ if self .mode == Mode .SAGEMAKER_ENDPOINT :
304+ logger .warning (
305+ "Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
306+ )
307+ return self .pysdk_model
303308
304309 num_shard_env_var_name = "SM_NUM_GPUS"
305310 if "OPTION_TENSOR_PARALLEL_DEGREE" in self .pysdk_model .env .keys ():
@@ -468,58 +473,47 @@ def _build_for_jumpstart(self):
468473 self .secret_key = None
469474 self .jumpstart = True
470475
471- self .pysdk_model = self ._create_pre_trained_js_model ()
472- self .pysdk_model .tune = lambda * args , ** kwargs : self ._default_tune ()
473-
474- logger .info (
475- "JumpStart ID %s is packaged with Image URI: %s" , self .model , self .pysdk_model .image_uri
476- )
477-
478- if self .mode != Mode .SAGEMAKER_ENDPOINT :
479- if self ._is_gated_model (self .pysdk_model ):
480- raise ValueError (
481- "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
482- )
483-
484- if "djl-inference" in self .pysdk_model .image_uri :
485- logger .info ("Building for DJL JumpStart Model ID..." )
486- self .model_server = ModelServer .DJL_SERVING
487- self .image_uri = self .pysdk_model .image_uri
488-
489- self ._build_for_djl_jumpstart ()
490-
491- self .pysdk_model .tune = self .tune_for_djl_jumpstart
492- elif "tgi-inference" in self .pysdk_model .image_uri :
493- logger .info ("Building for TGI JumpStart Model ID..." )
494- self .model_server = ModelServer .TGI
495- self .image_uri = self .pysdk_model .image_uri
496-
497- self ._build_for_tgi_jumpstart ()
476+ pysdk_model = self ._create_pre_trained_js_model ()
477+ image_uri = pysdk_model .image_uri
498478
499- self .pysdk_model .tune = self .tune_for_tgi_jumpstart
500- elif "huggingface-pytorch-inference:" in self .pysdk_model .image_uri :
501- logger .info ("Building for MMS JumpStart Model ID..." )
502- self .model_server = ModelServer .MMS
503- self .image_uri = self .pysdk_model .image_uri
479+ logger .info ("JumpStart ID %s is packaged with Image URI: %s" , self .model , image_uri )
504480
505- self ._build_for_mms_jumpstart ()
506- else :
507- raise ValueError (
508- "JumpStart Model ID was not packaged "
509- "with djl-inference, tgi-inference, or mms-inference container."
510- )
511-
512- return self .pysdk_model
481+ if self ._is_gated_model (pysdk_model ) and self .mode != Mode .SAGEMAKER_ENDPOINT :
482+ raise ValueError (
483+ "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
484+ )
513485
514- def _default_tune (self ):
515- """Logs a warning message if tune is invoked on endpoint mode.
486+ if "djl-inference" in image_uri :
487+ logger .info ("Building for DJL JumpStart Model ID..." )
488+ self .model_server = ModelServer .DJL_SERVING
489+ self .pysdk_model = pysdk_model
490+ self .image_uri = self .pysdk_model .image_uri
491+
492+ self ._build_for_djl_jumpstart ()
493+
494+ self .pysdk_model .tune = self .tune_for_djl_jumpstart
495+ elif "tgi-inference" in image_uri :
496+ logger .info ("Building for TGI JumpStart Model ID..." )
497+ self .model_server = ModelServer .TGI
498+ self .pysdk_model = pysdk_model
499+ self .image_uri = self .pysdk_model .image_uri
500+
501+ self ._build_for_tgi_jumpstart ()
502+
503+ self .pysdk_model .tune = self .tune_for_tgi_jumpstart
504+ elif "huggingface-pytorch-inference:" in image_uri :
505+ logger .info ("Building for MMS JumpStart Model ID..." )
506+ self .model_server = ModelServer .MMS
507+ self .pysdk_model = pysdk_model
508+ self .image_uri = self .pysdk_model .image_uri
509+
510+ self ._build_for_mms_jumpstart ()
511+ elif self .mode != Mode .SAGEMAKER_ENDPOINT :
512+ raise ValueError (
513+ "JumpStart Model ID was not packaged "
514+ "with djl-inference, tgi-inference, or mms-inference container."
515+ )
516516
517- Returns:
518- Jumpstart Model: ``This`` model
519- """
520- logger .warning (
521- "Tuning is only a %s capability. Returning original model." , Mode .LOCAL_CONTAINER
522- )
523517 return self .pysdk_model
524518
525519 def _is_gated_model (self , model ) -> bool :
0 commit comments