17
17
import logging
18
18
from typing import Dict , Any , Optional , Union , List , Tuple
19
19
20
- from sagemaker import Model
20
+ from sagemaker import Model , Session
21
21
from sagemaker .enums import Tag
22
+ from sagemaker .jumpstart .utils import accessors , get_eula_message
23
+
22
24
23
25
logger = logging .getLogger (__name__ )
24
26
@@ -164,6 +166,9 @@ def _extract_speculative_draft_model_provider(
164
166
if speculative_decoding_config is None :
165
167
return None
166
168
169
+ if speculative_decoding_config .get ("ModelProvider" ) == "JumpStart" :
170
+ return "jumpstart"
171
+
167
172
if speculative_decoding_config .get (
168
173
"ModelProvider"
169
174
) == "Custom" or speculative_decoding_config .get ("ModelSource" ):
@@ -292,7 +297,7 @@ def _generate_additional_model_data_sources(
292
297
},
293
298
}
294
299
if accept_eula :
295
- additional_model_data_source ["S3DataSource" ]["ModelAccessConfig" ] = {"ACCEPT_EULA " : True }
300
+ additional_model_data_source ["S3DataSource" ]["ModelAccessConfig" ] = {"AcceptEula " : True }
296
301
297
302
return [additional_model_data_source ]
298
303
@@ -327,10 +332,10 @@ def _extract_optimization_config_and_env(
327
332
"""
328
333
optimization_config = {}
329
334
quantization_override_env = (
330
- quantization_config .get ("OverrideEnvironment" , {} ) if quantization_config else None
335
+ quantization_config .get ("OverrideEnvironment" ) if quantization_config else None
331
336
)
332
337
compilation_override_env = (
333
- compilation_config .get ("OverrideEnvironment" , {} ) if compilation_config else None
338
+ compilation_config .get ("OverrideEnvironment" ) if compilation_config else None
334
339
)
335
340
336
341
if quantization_config is not None :
@@ -343,7 +348,7 @@ def _extract_optimization_config_and_env(
343
348
if optimization_config :
344
349
return optimization_config , quantization_override_env , compilation_override_env
345
350
346
- return {} , None , None
351
+ return None , None , None
347
352
348
353
349
354
def _custom_speculative_decoding (
@@ -364,7 +369,7 @@ def _custom_speculative_decoding(
364
369
speculative_decoding_config
365
370
)
366
371
367
- accept_eula = speculative_decoding_config .get ("AcceptEula" , False )
372
+ accept_eula = speculative_decoding_config .get ("AcceptEula" , accept_eula )
368
373
369
374
if _is_s3_uri (additional_model_source ):
370
375
channel_name = _generate_channel_name (model .additional_model_data_sources )
@@ -384,6 +389,65 @@ def _custom_speculative_decoding(
384
389
return model
385
390
386
391
392
def _jumpstart_speculative_decoding(
    model=Model,
    speculative_decoding_config: Optional[Dict[str, Any]] = None,
    sagemaker_session: Optional[Session] = None,
):
    """Modifies the given model for speculative decoding config with JumpStart provider.

    Looks up the JumpStart draft model named by ``ModelID`` (and optional
    ``ModelVersion``) in ``speculative_decoding_config``, attaches its
    prepacked hosting artifact to ``model.additional_model_data_sources``,
    points ``OPTION_SPECULATIVE_DRAFT_MODEL`` in ``model.env`` at the new
    channel and tags the model with the ``jumpstart`` draft model provider.
    The model is modified in place; nothing is returned.

    NOTE(review): the default value of ``model`` is the ``Model`` class
    itself, which looks like a typo for a type annotation. It is kept so the
    signature stays unchanged; callers are expected to pass an instance.

    Args:
        model (Model): The model.
        speculative_decoding_config (Optional[Dict]): The speculative decoding config.
            When empty or ``None`` the function does nothing.
        sagemaker_session (Optional[Session]): Sagemaker session for execution.
            Needed whenever a non-empty config is given, because its region is
            used to resolve the JumpStart model specs.

    Raises:
        ValueError: If ``ModelID`` is missing from the config, if no
            ``sagemaker_session`` is supplied, if the draft model is gated and
            ``AcceptEula`` is not truthy, or if the model specs carry no
            ``hosting_prepacked_artifact_key``.
    """
    # Nothing to configure without a speculative decoding config.
    if not speculative_decoding_config:
        return

    js_id = speculative_decoding_config.get("ModelID")
    if not js_id:
        raise ValueError(
            "`ModelID` is a required field in `speculative_decoding_config` when "
            "using JumpStart as draft model provider."
        )

    # The session is dereferenced below; fail with a clear message instead of
    # an AttributeError on ``None``.
    if sagemaker_session is None:
        raise ValueError(
            "`sagemaker_session` is required when using JumpStart as draft model provider."
        )

    model_version = speculative_decoding_config.get("ModelVersion", "*")
    accept_eula = speculative_decoding_config.get("AcceptEula", False)
    channel_name = _generate_channel_name(model.additional_model_data_sources)
    region = sagemaker_session.boto_region_name

    model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
        model_id=js_id,
        version=model_version,
        region=region,
        sagemaker_session=sagemaker_session,
    )
    model_spec_json = model_specs.to_json()

    js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket()

    if model_spec_json.get("gated_bucket", False):
        # Gated models need explicit EULA acceptance before their artifacts
        # are referenced from the gated content bucket.
        if not accept_eula:
            eula_message = get_eula_message(model_specs=model_specs, region=region)
            raise ValueError(
                f"{eula_message} Please set `AcceptEula` to True in "
                f"speculative_decoding_config once acknowledged."
            )
        js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()

    key_prefix = model_spec_json.get("hosting_prepacked_artifact_key")
    if not key_prefix:
        # Without this guard the URI below would silently become "s3://<bucket>/None".
        raise ValueError(
            f"JumpStart model `{js_id}` has no `hosting_prepacked_artifact_key` in its "
            "model specs and cannot be used as a speculative decoding draft model."
        )

    model.additional_model_data_sources = _generate_additional_model_data_sources(
        f"s3://{js_bucket}/{key_prefix}",
        channel_name,
        accept_eula,
    )

    model.env.update(
        {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
    )
    model.add_tags(
        {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"},
    )
449
+
450
+
387
451
def _validate_and_set_eula_for_draft_model_sources (
388
452
pysdk_model : Model ,
389
453
accept_eula : bool = False ,
0 commit comments