1717import logging
1818from typing import Dict , Any , Optional , Union , List , Tuple
1919
20- from sagemaker import Model
20+ from sagemaker import Model , Session
2121from sagemaker .enums import Tag
22+ from sagemaker .jumpstart .utils import accessors , get_eula_message
23+
2224
2325logger = logging .getLogger (__name__ )
2426
@@ -164,6 +166,9 @@ def _extract_speculative_draft_model_provider(
164166 if speculative_decoding_config is None :
165167 return None
166168
169+ if speculative_decoding_config .get ("ModelProvider" ) == "JumpStart" :
170+ return "jumpstart"
171+
167172 if speculative_decoding_config .get (
168173 "ModelProvider"
169174 ) == "Custom" or speculative_decoding_config .get ("ModelSource" ):
@@ -292,7 +297,7 @@ def _generate_additional_model_data_sources(
292297 },
293298 }
294299 if accept_eula :
295- additional_model_data_source ["S3DataSource" ]["ModelAccessConfig" ] = {"ACCEPT_EULA " : True }
300+ additional_model_data_source ["S3DataSource" ]["ModelAccessConfig" ] = {"AcceptEula " : True }
296301
297302 return [additional_model_data_source ]
298303
@@ -327,10 +332,10 @@ def _extract_optimization_config_and_env(
327332 """
328333 optimization_config = {}
329334 quantization_override_env = (
330- quantization_config .get ("OverrideEnvironment" , {} ) if quantization_config else None
335+ quantization_config .get ("OverrideEnvironment" ) if quantization_config else None
331336 )
332337 compilation_override_env = (
333- compilation_config .get ("OverrideEnvironment" , {} ) if compilation_config else None
338+ compilation_config .get ("OverrideEnvironment" ) if compilation_config else None
334339 )
335340
336341 if quantization_config is not None :
@@ -343,7 +348,7 @@ def _extract_optimization_config_and_env(
343348 if optimization_config :
344349 return optimization_config , quantization_override_env , compilation_override_env
345350
346- return {} , None , None
351+ return None , None , None
347352
348353
349354def _custom_speculative_decoding (
@@ -364,7 +369,7 @@ def _custom_speculative_decoding(
364369 speculative_decoding_config
365370 )
366371
367- accept_eula = speculative_decoding_config .get ("AcceptEula" , False )
372+ accept_eula = speculative_decoding_config .get ("AcceptEula" , accept_eula )
368373
369374 if _is_s3_uri (additional_model_source ):
370375 channel_name = _generate_channel_name (model .additional_model_data_sources )
@@ -384,6 +389,65 @@ def _custom_speculative_decoding(
384389 return model
385390
386391
def _jumpstart_speculative_decoding(
    model=Model,
    speculative_decoding_config: Optional[Dict[str, Any]] = None,
    sagemaker_session: Optional[Session] = None,
):
    """Modifies the given model for speculative decoding config with JumpStart provider.

    Looks up the JumpStart draft model named in ``speculative_decoding_config``,
    attaches its prepacked hosting artifact to ``model`` as an additional model
    data source, points ``OPTION_SPECULATIVE_DRAFT_MODEL`` at the generated
    channel and tags the model with the ``jumpstart`` draft model provider.
    The model is mutated in place; nothing is returned. When
    ``speculative_decoding_config`` is ``None`` or empty the function is a no-op.

    NOTE(review): ``model=Model`` is a default value (the ``Model`` class itself),
    not a type annotation. It is kept as-is so the signature stays unchanged;
    callers are expected to always pass a ``Model`` instance.

    Args:
        model (Model): The model to modify in place.
        speculative_decoding_config (Optional[Dict]): The speculative decoding config.
            Reads ``ModelID`` (required), ``ModelVersion`` (defaults to ``"*"``)
            and ``AcceptEula`` (defaults to ``False``).
        sagemaker_session (Optional[Session]): Sagemaker session for execution.
            Required whenever ``speculative_decoding_config`` is provided, since
            its region is used to resolve the JumpStart model specs.

    Raises:
        ValueError: If ``ModelID`` is missing, if ``sagemaker_session`` is not
            provided, if the draft model is gated and ``AcceptEula`` is not true,
            or if the model specs carry no ``hosting_prepacked_artifact_key``.
    """
    # Guard clause: nothing to configure without a speculative decoding config.
    if not speculative_decoding_config:
        return

    js_id = speculative_decoding_config.get("ModelID")
    if not js_id:
        raise ValueError(
            "`ModelID` is a required field in `speculative_decoding_config` when "
            "using JumpStart as draft model provider."
        )

    # The session supplies the region for the specs lookup; fail with a clear
    # message instead of an AttributeError on `None.boto_region_name`.
    if sagemaker_session is None:
        raise ValueError(
            "`sagemaker_session` is required when using JumpStart as draft model provider."
        )
    region = sagemaker_session.boto_region_name

    model_version = speculative_decoding_config.get("ModelVersion", "*")
    accept_eula = speculative_decoding_config.get("AcceptEula", False)
    channel_name = _generate_channel_name(model.additional_model_data_sources)

    model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
        model_id=js_id,
        version=model_version,
        region=region,
        sagemaker_session=sagemaker_session,
    )
    model_spec_json = model_specs.to_json()

    js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket()

    # Gated models live in a separate bucket and require explicit EULA acceptance.
    if model_spec_json.get("gated_bucket", False):
        if not accept_eula:
            eula_message = get_eula_message(model_specs=model_specs, region=region)
            raise ValueError(
                f"{eula_message} Please set `AcceptEula` to True in "
                f"speculative_decoding_config once acknowledged."
            )
        js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()

    key_prefix = model_spec_json.get("hosting_prepacked_artifact_key")
    if not key_prefix:
        # Without this guard the data source URI would become "s3://<bucket>/None".
        raise ValueError(
            f"JumpStart model '{js_id}' does not provide a prepacked hosting artifact "
            "and cannot be used as a speculative decoding draft model."
        )

    model.additional_model_data_sources = _generate_additional_model_data_sources(
        f"s3://{js_bucket}/{key_prefix}",
        channel_name,
        accept_eula,
    )

    model.env.update(
        {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
    )
    model.add_tags(
        {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"},
    )
450+
387451def _validate_and_set_eula_for_draft_model_sources (
388452 pysdk_model : Model ,
389453 accept_eula : bool = False ,
0 commit comments