Skip to content

Commit cf70f59

Browse files
author
Joseph Zhang
committed
Require EULA acceptance when using a gated first-party (1P) draft model via ModelBuilder.
1 parent 7ec16e6 commit cf70f59

File tree

6 files changed

+321
-10
lines changed

6 files changed

+321
-10
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
_custom_speculative_decoding,
4949
SPECULATIVE_DRAFT_MODEL,
5050
_is_inferentia_or_trainium,
51+
_validate_and_set_eula_for_draft_model_sources,
5152
)
5253
from sagemaker.serve.utils.predictors import (
5354
DjlLocalModePredictor,
@@ -733,10 +734,6 @@ def _optimize_for_jumpstart(
733734
if (
734735
not optimization_config or not optimization_config.get("ModelCompilationConfig")
735736
) and is_compilation:
736-
# Ensure optimization_config exists
737-
if not optimization_config:
738-
optimization_config = {}
739-
740737
# Fallback to default if override_env is None or empty
741738
if not compilation_override_env:
742739
compilation_override_env = pysdk_model_env_vars
@@ -867,6 +864,11 @@ def _set_additional_model_source(
867864
"Cannot find deployment config compatible for optimization job."
868865
)
869866

867+
_validate_and_set_eula_for_draft_model_sources(
868+
pysdk_model=self.pysdk_model,
869+
accept_eula=speculative_decoding_config.get("AcceptEula"),
870+
)
871+
870872
self.pysdk_model.env.update(
871873
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"}
872874
)

src/sagemaker/serve/builder/model_builder.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
validate_image_uri_and_hardware,
100100
)
101101
from sagemaker.utils import Tags
102+
from sagemaker.serve.utils.optimize_utils import _validate_and_set_eula_for_draft_model_sources
102103
from sagemaker.workflow.entities import PipelineVariable
103104
from sagemaker.huggingface.llm_utils import (
104105
get_huggingface_model_metadata,
@@ -589,6 +590,21 @@ def _model_builder_deploy_wrapper(
589590
model_server=self.model_server,
590591
)
591592

593+
if self.deployment_config:
594+
accept_draft_model_eula = kwargs.get("accept_draft_model_eula", False)
595+
try:
596+
_validate_and_set_eula_for_draft_model_sources(
597+
pysdk_model=self,
598+
accept_eula=accept_draft_model_eula,
599+
)
600+
except ValueError as e:
601+
logger.error(
602+
"This deployment tried to use a gated draft model but the EULA was not "
603+
"accepted. Please review the EULA, set accept_draft_model_eula to True, "
604+
"and try again."
605+
)
606+
raise e
607+
592608
if "endpoint_logging" not in kwargs:
593609
kwargs["endpoint_logging"] = True
594610
predictor = self._original_deploy(

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,60 @@ def _extract_speculative_draft_model_provider(
172172
return "sagemaker"
173173

174174

175+
def _extract_additional_model_data_source_s3_uri(
176+
additional_model_data_source: Optional[Dict] = None,
177+
) -> Optional[str]:
178+
"""Extracts model data source s3 uri from a model data source in Pascal case.
179+
180+
Args:
181+
additional_model_data_source (Optional[Dict]): A model data source.
182+
183+
Returns:
184+
str: S3 uri of the model resources.
185+
"""
186+
if (
187+
additional_model_data_source is None
188+
or additional_model_data_source.get("S3DataSource", None) is None
189+
):
190+
return None
191+
192+
return additional_model_data_source.get("S3DataSource").get("S3Uri", None)
193+
194+
195+
def _extract_deployment_config_additional_model_data_source_s3_uri(
196+
additional_model_data_source: Optional[Dict] = None,
197+
) -> Optional[str]:
198+
"""Extracts model data source s3 uri from a model data source in snake case.
199+
200+
Args:
201+
additional_model_data_source (Optional[Dict]): A model data source.
202+
203+
Returns:
204+
str: S3 uri of the model resources.
205+
"""
206+
if (
207+
additional_model_data_source is None
208+
or additional_model_data_source.get("s3_data_source", None) is None
209+
):
210+
return None
211+
212+
return additional_model_data_source.get("s3_data_source").get("s3_uri", None)
213+
214+
215+
def _is_draft_model_gated(
216+
draft_model_config: Optional[Dict] = None,
217+
) -> bool:
218+
"""Extracts model gated-ness from draft model data source.
219+
220+
Args:
221+
draft_model_config (Optional[Dict]): A model data source.
222+
223+
Returns:
224+
bool: Whether the draft model is gated or not.
225+
"""
226+
return draft_model_config.get("hosting_eula_key", None)
227+
228+
175229
def _extracts_and_validates_speculative_model_source(
176230
speculative_decoding_config: Dict,
177231
) -> str:
@@ -289,7 +343,7 @@ def _extract_optimization_config_and_env(
289343
if optimization_config:
290344
return optimization_config, quantization_override_env, compilation_override_env
291345

292-
return None, None, None
346+
return {}, None, None
293347

294348

295349
def _custom_speculative_decoding(
@@ -310,6 +364,8 @@ def _custom_speculative_decoding(
310364
speculative_decoding_config
311365
)
312366

367+
accept_eula = speculative_decoding_config.get("AcceptEula", False)
368+
313369
if _is_s3_uri(additional_model_source):
314370
channel_name = _generate_channel_name(model.additional_model_data_sources)
315371
speculative_draft_model = f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"
@@ -326,3 +382,78 @@ def _custom_speculative_decoding(
326382
)
327383

328384
return model
385+
386+
387+
def _validate_and_set_eula_for_draft_model_sources(
388+
pysdk_model: Model,
389+
accept_eula: bool = False,
390+
):
391+
"""Validates whether the EULA has been accepted for gated additional draft model sources.
392+
393+
If accepted, updates the model data source's model access config.
394+
395+
Args:
396+
pysdk_model (Model): The model whose additional model data sources to check.
397+
accept_eula (bool): EULA acceptance for the draft model.
398+
"""
399+
if not pysdk_model:
400+
return
401+
402+
deployment_config_draft_model_sources = (
403+
pysdk_model.deployment_config.get("DeploymentArgs", {})
404+
.get("AdditionalDataSources", {})
405+
.get("speculative_decoding", [])
406+
if pysdk_model.deployment_config
407+
else None
408+
)
409+
pysdk_model_additional_model_sources = pysdk_model.additional_model_data_sources
410+
411+
if not deployment_config_draft_model_sources or not pysdk_model_additional_model_sources:
412+
return
413+
414+
# Gated/ungated classification is only available through deployment_config.
415+
# Thus we must check each draft model in the deployment_config and see if it is set
416+
# as an additional model data source on the PySDK model itself.
417+
model_access_config_updated = False
418+
for source in deployment_config_draft_model_sources:
419+
if source.get("channel_name") != "draft_model":
420+
continue
421+
422+
if not _is_draft_model_gated(source):
423+
continue
424+
425+
deployment_config_draft_model_source_s3_uri = (
426+
_extract_deployment_config_additional_model_data_source_s3_uri(source)
427+
)
428+
429+
# If EULA is accepted, proceed with modifying the draft model data source
430+
for additional_source in pysdk_model_additional_model_sources:
431+
if additional_source.get("ChannelName") != "draft_model":
432+
continue
433+
434+
# Verify the pysdk model source and deployment config model source match
435+
pysdk_model_source_s3_uri = _extract_additional_model_data_source_s3_uri(
436+
additional_source
437+
)
438+
if deployment_config_draft_model_source_s3_uri not in pysdk_model_source_s3_uri:
439+
continue
440+
441+
if not accept_eula:
442+
raise ValueError(
443+
"Gated draft model requires accepting end-user license agreement (EULA)."
444+
)
445+
446+
# Set ModelAccessConfig.AcceptEula to True
447+
updated_source = additional_source.copy()
448+
updated_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True}
449+
450+
index = pysdk_model.additional_model_data_sources.index(additional_source)
451+
pysdk_model.additional_model_data_sources[index] = updated_source
452+
453+
model_access_config_updated = True
454+
break
455+
456+
if model_access_config_updated:
457+
break
458+
459+
return

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
LocalModelOutOfMemoryException,
2626
LocalModelInvocationException,
2727
)
28-
from tests.unit.sagemaker.serve.constants import DEPLOYMENT_CONFIGS
28+
from tests.unit.sagemaker.serve.constants import (
29+
DEPLOYMENT_CONFIGS,
30+
OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL,
31+
)
2932

3033
mock_model_id = "huggingface-llm-amazon-falconlite"
3134
mock_t5_model_id = "google/flan-t5-xxl"
@@ -1198,6 +1201,51 @@ def test_optimize_quantize_for_jumpstart(
11981201

11991202
self.assertIsNotNone(out_put)
12001203

1204+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
1205+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
1206+
def test_optimize_gated_draft_model_for_jumpstart_with_accept_eula_false(
1207+
self,
1208+
mock_serve_settings,
1209+
mock_telemetry,
1210+
):
1211+
mock_sagemaker_session = Mock()
1212+
1213+
mock_pysdk_model = Mock()
1214+
mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"}
1215+
mock_pysdk_model.model_data = mock_model_data
1216+
mock_pysdk_model.image_uri = mock_tgi_image_uri
1217+
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
1218+
mock_pysdk_model.deployment_config = OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL
1219+
1220+
sample_input = {
1221+
"inputs": "The diamondback terrapin or simply terrapin is a species "
1222+
"of turtle native to the brackish coastal tidal marshes of the",
1223+
"parameters": {"max_new_tokens": 1024},
1224+
}
1225+
sample_output = [
1226+
{
1227+
"generated_text": "The diamondback terrapin or simply terrapin is a "
1228+
"species of turtle native to the brackish coastal "
1229+
"tidal marshes of the east coast."
1230+
}
1231+
]
1232+
1233+
model_builder = ModelBuilder(
1234+
model="meta-textgeneration-llama-3-70b",
1235+
schema_builder=SchemaBuilder(sample_input, sample_output),
1236+
sagemaker_session=mock_sagemaker_session,
1237+
)
1238+
1239+
model_builder.pysdk_model = mock_pysdk_model
1240+
1241+
self.assertRaises(
1242+
ValueError,
1243+
model_builder._optimize_for_jumpstart(
1244+
accept_eula=True,
1245+
speculative_decoding_config={"Provider": "sagemaker", "AcceptEula": False},
1246+
),
1247+
)
1248+
12011249
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
12021250
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
12031251
def test_optimize_quantize_and_compile_for_jumpstart(
@@ -1248,10 +1296,6 @@ def test_optimize_quantize_and_compile_for_jumpstart(
12481296
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
12491297
},
12501298
compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
1251-
env_vars={
1252-
"OPTION_TENSOR_PARALLEL_DEGREE": "1",
1253-
"OPTION_MAX_ROLLING_BATCH_SIZE": "2",
1254-
},
12551299
output_path="s3://bucket/code/",
12561300
)
12571301

tests/unit/sagemaker/serve/constants.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,105 @@
165165
},
166166
},
167167
]
168+
OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL = {
169+
"DeploymentConfigName": "lmi-optimized",
170+
"DeploymentArgs": {
171+
"ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
172+
"djl-inference:0.29.0-lmi11.0.0-cu124",
173+
"ModelData": {
174+
"S3DataSource": {
175+
"S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/"
176+
"meta-textgeneration-llama-3-1-70b/artifacts/inference-prepack/v2.0.0/",
177+
"S3DataType": "S3Prefix",
178+
"CompressionType": "None",
179+
}
180+
},
181+
"ModelPackageArn": None,
182+
"Environment": {
183+
"SAGEMAKER_PROGRAM": "inference.py",
184+
"ENDPOINT_SERVER_TIMEOUT": "3600",
185+
"MODEL_CACHE_ROOT": "/opt/ml/model",
186+
"SAGEMAKER_ENV": "1",
187+
"HF_MODEL_ID": "/opt/ml/model",
188+
"OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model",
189+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
190+
},
191+
"InstanceType": "ml.g6.2xlarge",
192+
"ComputeResourceRequirements": {
193+
"MinMemoryRequiredInMb": 131072,
194+
"NumberOfAcceleratorDevicesRequired": 1,
195+
},
196+
"ModelDataDownloadTimeout": 1200,
197+
"ContainerStartupHealthCheckTimeout": 1200,
198+
"AdditionalDataSources": {
199+
"speculative_decoding": [
200+
{
201+
"channel_name": "draft_model",
202+
"provider": {"name": "JumpStart", "classification": "gated"},
203+
"artifact_version": "v1",
204+
"hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt",
205+
"s3_data_source": {
206+
"s3_uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/"
207+
"inference-prepack/v1.0.0/",
208+
"compression_type": "None",
209+
"s3_data_type": "S3Prefix",
210+
},
211+
}
212+
]
213+
},
214+
},
215+
"AccelerationConfigs": [
216+
{
217+
"type": "Compilation",
218+
"enabled": False,
219+
"diy_workflow_overrides": {
220+
"gpu-lmi-trt": {
221+
"enabled": False,
222+
"reason": "TRT-LLM 0.11.0 in LMI v11 does not support llama 3.1",
223+
}
224+
},
225+
},
226+
{
227+
"type": "Speculative-Decoding",
228+
"enabled": True,
229+
"diy_workflow_overrides": {
230+
"gpu-lmi-trt": {
231+
"enabled": False,
232+
"reason": "LMI v11 does not support Speculative Decoding for TRT",
233+
}
234+
},
235+
},
236+
{
237+
"type": "Quantization",
238+
"enabled": False,
239+
"diy_workflow_overrides": {
240+
"gpu-lmi-trt": {
241+
"enabled": False,
242+
"reason": "TRT-LLM 0.11.0 in LMI v11 does not support llama 3.1",
243+
}
244+
},
245+
},
246+
],
247+
"BenchmarkMetrics": {"ml.g6.2xlarge": None},
248+
}
249+
GATED_DRAFT_MODEL_CONFIG = {
250+
"channel_name": "draft_model",
251+
"provider": {"name": "JumpStart", "classification": "gated"},
252+
"artifact_version": "v1",
253+
"hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt",
254+
"s3_data_source": {
255+
"s3_uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/"
256+
"inference-prepack/v1.0.0/",
257+
"compression_type": "None",
258+
"s3_data_type": "S3Prefix",
259+
},
260+
}
261+
NON_GATED_DRAFT_MODEL_CONFIG = {
262+
"channel_name": "draft_model",
263+
"s3_data_source": {
264+
"compression_type": "None",
265+
"s3_data_type": "S3Prefix",
266+
"s3_uri": "s3://sagemaker-sd-models-beta-us-west-2/"
267+
"sagemaker-speculative-decoding-llama3-small-v3/",
268+
},
269+
}

0 commit comments

Comments (0)