4949 SPECULATIVE_DRAFT_MODEL ,
5050 _is_inferentia_or_trainium ,
5151 _jumpstart_speculative_decoding ,
52+ _deployment_config_contains_draft_model ,
53+ _is_draft_model_jumpstart_provided ,
5254)
5355from sagemaker .serve .utils .predictors import (
5456 DjlLocalModePredictor ,
@@ -850,7 +852,7 @@ def _set_additional_model_source(
850852
851853 channel_name = _generate_channel_name (self .pysdk_model .additional_model_data_sources )
852854
853- if model_provider == "sagemaker" :
855+ if model_provider in [ "sagemaker" , "auto" ] :
854856 additional_model_data_sources = (
855857 self .pysdk_model .deployment_config .get ("DeploymentArgs" , {}).get (
856858 "AdditionalDataSources"
@@ -863,6 +865,15 @@ def _set_additional_model_source(
863865 speculative_decoding_config
864866 )
865867 if deployment_config :
868+ if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided (
869+ deployment_config
870+ ):
871+ raise ValueError (
872+ "No `Sagemaker` provided draft model was found for "
873+ f"{ self .model } . Try setting `ModelProvider` "
874+ "to `Auto` instead."
875+ )
876+
866877 try :
867878 self .pysdk_model .set_deployment_config (
868879 config_name = deployment_config .get ("DeploymentConfigName" ),
@@ -878,12 +889,21 @@ def _set_additional_model_source(
878889 raise ValueError (
879890 "Cannot find deployment config compatible for optimization job."
880891 )
892+ else :
893+ if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided (
894+ self .pysdk_model .deployment_config
895+ ):
896+ raise ValueError (
897+ "No `Sagemaker` provided draft model was found for "
898+ f"{ self .model } . Try setting `ModelProvider` "
899+ "to `Auto` instead."
900+ )
881901
882902 self .pysdk_model .env .update (
883903 {"OPTION_SPECULATIVE_DRAFT_MODEL" : f"{ SPECULATIVE_DRAFT_MODEL } /{ channel_name } /" }
884904 )
885905 self .pysdk_model .add_tags (
886- {"Key" : Tag .SPECULATIVE_DRAFT_MODEL_PROVIDER , "Value" : "sagemaker" },
906+ {"Key" : Tag .SPECULATIVE_DRAFT_MODEL_PROVIDER , "Value" : model_provider },
887907 )
888908 elif model_provider == "jumpstart" :
889909 _jumpstart_speculative_decoding (
@@ -911,15 +931,17 @@ def _find_compatible_deployment_config(
911931 for deployment_config in self .pysdk_model .list_deployment_configs ():
912932 image_uri = deployment_config .get ("deployment_config" , {}).get ("ImageUri" )
913933
914- if _is_image_compatible_with_optimization_job (image_uri ):
934+ if _is_image_compatible_with_optimization_job (
935+ image_uri
936+ ) and _deployment_config_contains_draft_model (deployment_config ):
915937 if (
916- model_provider == "sagemaker"
938+ model_provider in [ "sagemaker" , "auto" ]
917939 and deployment_config .get ("DeploymentArgs" , {}).get ("AdditionalDataSources" )
918940 ) or model_provider == "custom" :
919941 return deployment_config
920942
921943 # There's no matching config from jumpstart to add sagemaker draft model location
922- if model_provider == "sagemaker" :
944+ if model_provider in [ "sagemaker" , "auto" ] :
923945 return None
924946
925947 # fall back to the default jumpstart model deployment config for optimization job
0 commit comments