Skip to content

Commit 54e995f

Browse files
author
Joseph Zhang
committed
Fix incorrect assignment of ModelCompilationConfig and add a unit test.
1 parent 70e00fd commit 54e995f

File tree

2 files changed

+110
-4
lines changed

2 files changed

+110
-4
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,11 +733,17 @@ def _optimize_for_jumpstart(
733733
if (
734734
not optimization_config or not optimization_config.get("ModelCompilationConfig")
735735
) and is_compilation:
736-
override_env = override_env or pysdk_model_env_vars
736+
# Ensure optimization_config exists
737+
if not optimization_config:
738+
optimization_config = {}
739+
740+
# Fallback to default if override_env is None or empty
741+
if not override_env:
742+
override_env = pysdk_model_env_vars
743+
744+
# Update optimization_config with ModelCompilationConfig
737745
optimization_config["ModelCompilationConfig"] = {
738-
"ModelCompilationConfig": {
739-
"OverrideEnvironment": override_env,
740-
}
746+
"OverrideEnvironment": override_env,
741747
}
742748

743749
if speculative_decoding_config:

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,3 +1442,103 @@ def test_optimize_compile_for_jumpstart_with_neuron_env(
14421442
self.assertEqual(optimized_model.env["OPTION_ROLLING_BATCH"], "auto")
14431443
self.assertEqual(optimized_model.env["OPTION_MAX_ROLLING_BATCH_SIZE"], "4")
14441444
self.assertEqual(optimized_model.env["OPTION_NEURON_OPTIMIZE_LEVEL"], "2")
1445+
1446+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
1447+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
1448+
@patch(
1449+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model",
1450+
return_value=True,
1451+
)
1452+
@patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
1453+
@patch(
1454+
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
1455+
return_value=True,
1456+
)
1457+
@patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model")
1458+
@patch(
1459+
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
1460+
return_value=({"model_type": "t5", "n_head": 71}, True),
1461+
)
1462+
def test_optimize_compile_for_jumpstart_without_compilation_config(
1463+
self,
1464+
mock_prepare_for_tgi,
1465+
mock_pre_trained_model,
1466+
mock_is_jumpstart_model,
1467+
mock_js_model,
1468+
mock_is_gated_model,
1469+
mock_serve_settings,
1470+
mock_telemetry,
1471+
):
1472+
mock_sagemaker_session = Mock()
1473+
mock_metadata_config = Mock()
1474+
mock_sagemaker_session.wait_for_optimization_job.side_effect = (
1475+
lambda *args: mock_optimization_job_response
1476+
)
1477+
1478+
mock_metadata_config.resolved_config = {
1479+
"supported_inference_instance_types": ["ml.inf2.48xlarge"],
1480+
"hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b",
1481+
}
1482+
1483+
mock_js_model.return_value = MagicMock()
1484+
mock_js_model.return_value.env = {
1485+
"SAGEMAKER_PROGRAM": "inference.py",
1486+
"ENDPOINT_SERVER_TIMEOUT": "3600",
1487+
"MODEL_CACHE_ROOT": "/opt/ml/model",
1488+
"SAGEMAKER_ENV": "1",
1489+
"HF_MODEL_ID": "/opt/ml/model",
1490+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
1491+
}
1492+
1493+
mock_pre_trained_model.return_value = MagicMock()
1494+
mock_pre_trained_model.return_value.env = dict()
1495+
mock_pre_trained_model.return_value.config_name = "config_name"
1496+
mock_pre_trained_model.return_value.model_data = mock_model_data
1497+
mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
1498+
mock_pre_trained_model.return_value.list_deployment_configs.return_value = (
1499+
DEPLOYMENT_CONFIGS
1500+
)
1501+
mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0]
1502+
mock_pre_trained_model.return_value._metadata_configs = {
1503+
"config_name": mock_metadata_config
1504+
}
1505+
1506+
sample_input = {
1507+
"inputs": "The diamondback terrapin or simply terrapin is a species "
1508+
"of turtle native to the brackish coastal tidal marshes of the",
1509+
"parameters": {"max_new_tokens": 1024},
1510+
}
1511+
sample_output = [
1512+
{
1513+
"generated_text": "The diamondback terrapin or simply terrapin is a "
1514+
"species of turtle native to the brackish coastal "
1515+
"tidal marshes of the east coast."
1516+
}
1517+
]
1518+
1519+
model_builder = ModelBuilder(
1520+
model="meta-textgeneration-llama-3-70b",
1521+
schema_builder=SchemaBuilder(sample_input, sample_output),
1522+
sagemaker_session=mock_sagemaker_session,
1523+
)
1524+
1525+
optimized_model = model_builder.optimize(
1526+
accept_eula=True,
1527+
instance_type="ml.inf2.24xlarge",
1528+
output_path="s3://bucket/code/",
1529+
)
1530+
1531+
self.assertEqual(
1532+
optimized_model.image_uri,
1533+
mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"],
1534+
)
1535+
self.assertEqual(
1536+
optimized_model.model_data["S3DataSource"]["S3Uri"],
1537+
mock_optimization_job_response["OutputConfig"]["S3OutputLocation"],
1538+
)
1539+
self.assertEqual(optimized_model.env["SAGEMAKER_PROGRAM"], "inference.py")
1540+
self.assertEqual(optimized_model.env["ENDPOINT_SERVER_TIMEOUT"], "3600")
1541+
self.assertEqual(optimized_model.env["MODEL_CACHE_ROOT"], "/opt/ml/model")
1542+
self.assertEqual(optimized_model.env["SAGEMAKER_ENV"], "1")
1543+
self.assertEqual(optimized_model.env["HF_MODEL_ID"], "/opt/ml/model")
1544+
self.assertEqual(optimized_model.env["SAGEMAKER_MODEL_SERVER_WORKERS"], "1")

0 commit comments

Comments (0)