Skip to content

Commit 8b73f34

Browse files
author
Joseph Zhang
committed
ModelBuilder speculative decoding unit tests (UTs) and minor fixes.
1 parent 8f0083b commit 8b73f34

File tree

7 files changed

+245
-33
lines changed

7 files changed

+245
-33
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,9 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]:
11221122
class JumpStartModelDataSource(AdditionalModelDataSource):
11231123
"""Data class JumpStart additional model data source."""
11241124

1125-
SERIALIZATION_EXCLUSION_SET = {
1125+
SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union(
11261126
"artifact_version"
1127-
} | AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET
1127+
)
11281128

11291129
__slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__
11301130

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -737,9 +737,7 @@ def _optimize_for_jumpstart(
737737
if not optimization_config:
738738
optimization_config = {}
739739

740-
if (
741-
not optimization_config or not optimization_config.get("ModelCompilationConfig")
742-
) and is_compilation:
740+
if not optimization_config.get("ModelCompilationConfig") and is_compilation:
743741
# Fallback to default if override_env is None or empty
744742
if not compilation_override_env:
745743
compilation_override_env = pysdk_model_env_vars
@@ -907,7 +905,9 @@ def _set_additional_model_source(
907905
)
908906
else:
909907
self.pysdk_model = _custom_speculative_decoding(
910-
self.pysdk_model, speculative_decoding_config, speculative_decoding_config.get("AcceptEula", False)
908+
self.pysdk_model,
909+
speculative_decoding_config,
910+
speculative_decoding_config.get("AcceptEula", False),
911911
)
912912

913913
def _find_compatible_deployment_config(

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def _model_builder_deploy_wrapper(
591591
)
592592

593593
if "endpoint_logging" not in kwargs:
594-
kwargs["endpoint_logging"] = True
594+
kwargs["endpoint_logging"] = False
595595
predictor = self._original_deploy(
596596
*args,
597597
instance_type=instance_type,

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ def _deployment_config_contains_draft_model(deployment_config: Optional[Dict]) -
7373
return False
7474
deployment_args = deployment_config.get("DeploymentArgs", {})
7575
additional_data_sources = deployment_args.get("AdditionalDataSources")
76-
if not additional_data_sources:
77-
return False
78-
return additional_data_sources.get("speculative_decoding", False)
76+
77+
return "speculative_decoding" in additional_data_sources if additional_data_sources else False
7978

8079

8180
def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool:
@@ -207,15 +206,15 @@ def _extract_speculative_draft_model_provider(
207206
if speculative_decoding_config is None:
208207
return None
209208

210-
if speculative_decoding_config.get("ModelProvider").lower() == "jumpstart":
209+
model_provider = speculative_decoding_config.get("ModelProvider", "").lower()
210+
211+
if model_provider == "jumpstart":
211212
return "jumpstart"
212213

213-
if speculative_decoding_config.get(
214-
"ModelProvider"
215-
).lower() == "custom" or speculative_decoding_config.get("ModelSource"):
214+
if model_provider == "custom" or speculative_decoding_config.get("ModelSource"):
216215
return "custom"
217216

218-
if speculative_decoding_config.get("ModelProvider").lower() == "sagemaker":
217+
if model_provider == "sagemaker":
219218
return "sagemaker"
220219

221220
return "auto"
@@ -238,7 +237,7 @@ def _extract_additional_model_data_source_s3_uri(
238237
):
239238
return None
240239

241-
return additional_model_data_source.get("S3DataSource").get("S3Uri", None)
240+
return additional_model_data_source.get("S3DataSource").get("S3Uri")
242241

243242

244243
def _extract_deployment_config_additional_model_data_source_s3_uri(
@@ -272,7 +271,7 @@ def _is_draft_model_gated(
272271
Returns:
273272
bool: Whether the draft model is gated or not.
274273
"""
275-
return draft_model_config.get("hosting_eula_key", None)
274+
return "hosting_eula_key" in draft_model_config if draft_model_config else False
276275

277276

278277
def _extracts_and_validates_speculative_model_source(
@@ -371,7 +370,7 @@ def _extract_optimization_config_and_env(
371370
compilation_config (Optional[Dict]): The compilation config.
372371
373372
Returns:
374-
Optional[Tuple[Optional[Dict], Optional[Dict]]]:
373+
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
375374
The optimization config and environment variables.
376375
"""
377376
optimization_config = {}
@@ -388,7 +387,7 @@ def _extract_optimization_config_and_env(
388387
if compilation_config is not None:
389388
optimization_config["ModelCompilationConfig"] = compilation_config
390389

391-
# Return both dicts and environment variable if either is present
390+
# Return optimization config dict and environment variables if either is present
392391
if optimization_config:
393392
return optimization_config, quantization_override_env, compilation_override_env
394393

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tests.unit.sagemaker.serve.constants import (
2929
DEPLOYMENT_CONFIGS,
3030
OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL,
31+
CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES,
3132
)
3233

3334
mock_model_id = "huggingface-llm-amazon-falconlite"
@@ -1203,19 +1204,34 @@ def test_optimize_quantize_for_jumpstart(
12031204

12041205
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
12051206
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
1206-
def test_optimize_gated_draft_model_for_jumpstart_with_accept_eula_false(
1207+
@patch(
1208+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
1209+
return_value=True,
1210+
)
1211+
@patch(
1212+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
1213+
return_value=MagicMock(),
1214+
)
1215+
@patch(
1216+
"sagemaker.serve.builder.jumpstart_builder._jumpstart_speculative_decoding",
1217+
return_value=True,
1218+
)
1219+
def test_jumpstart_model_provider_calls_jumpstart_speculative_decoding(
12071220
self,
1221+
mock_js_speculative_decoding,
1222+
mock_pretrained_js_model,
1223+
mock_is_js_model,
12081224
mock_serve_settings,
1209-
mock_telemetry,
1225+
mock_capture_telemetry,
12101226
):
12111227
mock_sagemaker_session = Mock()
1212-
12131228
mock_pysdk_model = Mock()
12141229
mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"}
12151230
mock_pysdk_model.model_data = mock_model_data
12161231
mock_pysdk_model.image_uri = mock_tgi_image_uri
12171232
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
12181233
mock_pysdk_model.deployment_config = OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL
1234+
mock_pysdk_model.additional_model_data_sources = CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES
12191235

12201236
sample_input = {
12211237
"inputs": "The diamondback terrapin or simply terrapin is a species "
@@ -1238,14 +1254,17 @@ def test_optimize_gated_draft_model_for_jumpstart_with_accept_eula_false(
12381254

12391255
model_builder.pysdk_model = mock_pysdk_model
12401256

1241-
self.assertRaises(
1242-
ValueError,
1243-
model_builder._optimize_for_jumpstart(
1244-
accept_eula=True,
1245-
speculative_decoding_config={"Provider": "sagemaker", "AcceptEula": False},
1246-
),
1257+
model_builder._optimize_for_jumpstart(
1258+
accept_eula=True,
1259+
speculative_decoding_config={
1260+
"ModelProvider": "JumpStart",
1261+
"ModelID": "meta-textgeneration-llama-3-2-1b",
1262+
"AcceptEula": False,
1263+
},
12471264
)
12481265

1266+
mock_js_speculative_decoding.assert_called_once()
1267+
12491268
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
12501269
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
12511270
def test_optimize_quantize_and_compile_for_jumpstart(

tests/unit/sagemaker/serve/constants.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,43 @@
165165
},
166166
},
167167
]
168+
NON_OPTIMIZED_DEPLOYMENT_CONFIG = {
169+
"ConfigName": "neuron-inference",
170+
"BenchmarkMetrics": [
171+
{"name": "Latency", "value": "100", "unit": "Tokens/S"},
172+
{"name": "Throughput", "value": "1867", "unit": "Tokens/S"},
173+
],
174+
"DeploymentArgs": {
175+
"ModelDataDownloadTimeout": 1200,
176+
"ContainerStartupHealthCheckTimeout": 1200,
177+
"ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
178+
".0-gpu-py310-cu121-ubuntu20.04",
179+
"ModelData": {
180+
"S3DataSource": {
181+
"S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration"
182+
"-llama-2-7b/artifacts/inference-prepack/v1.0.0/",
183+
"S3DataType": "S3Prefix",
184+
"CompressionType": "None",
185+
}
186+
},
187+
"InstanceType": "ml.p2.xlarge",
188+
"Environment": {
189+
"SAGEMAKER_PROGRAM": "inference.py",
190+
"ENDPOINT_SERVER_TIMEOUT": "3600",
191+
"MODEL_CACHE_ROOT": "/opt/ml/model",
192+
"SAGEMAKER_ENV": "1",
193+
"HF_MODEL_ID": "/opt/ml/model",
194+
"MAX_INPUT_LENGTH": "4095",
195+
"MAX_TOTAL_TOKENS": "4096",
196+
"SM_NUM_GPUS": "1",
197+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
198+
},
199+
"ComputeResourceRequirements": {
200+
"MinMemoryRequiredInMb": 16384,
201+
"NumberOfAcceleratorDevicesRequired": 1,
202+
},
203+
},
204+
}
168205
OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL = {
169206
"DeploymentConfigName": "lmi-optimized",
170207
"DeploymentArgs": {
@@ -267,3 +304,14 @@
267304
"sagemaker-speculative-decoding-llama3-small-v3/",
268305
},
269306
}
307+
CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES = [
308+
{
309+
"ChannelName": "draft_model",
310+
"S3DataSource": {
311+
"S3Uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/"
312+
"inference-prepack/v1.0.0/",
313+
"CompressionType": "None",
314+
"S3DataType": "S3Prefix",
315+
},
316+
}
317+
]

0 commit comments

Comments
 (0)