
Commit 37af43a

add e2e UT for lmi + .optimize()
1 parent 33dcf96

1 file changed

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 81 additions & 0 deletions
@@ -1605,3 +1605,84 @@ def test_optimize_compile_for_jumpstart_without_compilation_config(
         self.assertEqual(optimized_model.env["SAGEMAKER_ENV"], "1")
         self.assertEqual(optimized_model.env["HF_MODEL_ID"], "/opt/ml/model")
         self.assertEqual(optimized_model.env["SAGEMAKER_MODEL_SERVER_WORKERS"], "1")
+
+
+class TestJumpStartModelBuilderOptimizationUseCases(unittest.TestCase):
+
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model",
+        return_value=True,
+    )
+    @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model",
+        return_value=False,
+    )
+    def test_optimize_on_js_model_should_ignore_pre_optimized_configurations(
+        self,
+        mock_is_fine_tuned,
+        mock_is_jumpstart_model,
+        mock_js_model,
+        mock_is_gated_model,
+        mock_serve_settings,
+        mock_telemetry,
+    ):
+        mock_sagemaker_session = Mock()
+        mock_sagemaker_session.wait_for_optimization_job.side_effect = (
+            lambda *args: mock_optimization_job_response
+        )
+
+        mock_lmi_js_model = MagicMock()
+        mock_lmi_js_model.image_uri = mock_djl_image_uri
+        mock_lmi_js_model.env = {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "OPTION_ENFORCE_EAGER": "true",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "8",
+        }
+
+        mock_js_model.return_value = mock_lmi_js_model
+
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-1-70b-instruct",
+            schema_builder=SchemaBuilder("test", "test"),
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        optimized_model = model_builder.optimize(
+            accept_eula=True,
+            instance_type="ml.g5.24xlarge",
+            quantization_config={
+                "OverrideEnvironment": {
+                    "OPTION_QUANTIZE": "fp8",
+                    "OPTION_TENSOR_PARALLEL_DEGREE": "4",
+                },
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
+            "instance_type": "ml.g5.24xlarge",
+            "config_name": "lmi",
+        }
+        assert optimized_model.env == {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "OPTION_ENFORCE_EAGER": "true",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "4",  # should be overridden from 8 to 4
+            "OPTION_QUANTIZE": "fp8",  # should be added to the env
+        }
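
For context, here is a minimal sketch of the flow this test exercises, invoked directly rather than through mocks. It assumes the ModelBuilder and SchemaBuilder import paths from the SageMaker Python SDK's sagemaker.serve package (the same classes the test file imports); the S3 bucket and sample payloads are hypothetical placeholders, and running it for real requires AWS credentials plus EULA acceptance for the gated Llama model.

# Minimal sketch of the flow exercised by the test above, without mocks.
# The S3 bucket and sample payloads are hypothetical placeholders.
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder

model_builder = ModelBuilder(
    model="meta-textgeneration-llama-3-1-70b-instruct",  # JumpStart model ID
    schema_builder=SchemaBuilder("sample prompt", "sample response"),
)

# Per the test's assertions, entries under OverrideEnvironment replace the
# model's pre-optimized LMI defaults (OPTION_TENSOR_PARALLEL_DEGREE: 8 -> 4)
# and new keys such as OPTION_QUANTIZE are merged into the model env.
optimized_model = model_builder.optimize(
    accept_eula=True,
    instance_type="ml.g5.24xlarge",
    quantization_config={
        "OverrideEnvironment": {
            "OPTION_QUANTIZE": "fp8",
            "OPTION_TENSOR_PARALLEL_DEGREE": "4",
        },
    },
    output_path="s3://my-bucket/optimized-llama/",  # hypothetical location
)

Note that the test also asserts set_deployment_config is called with config_name "lmi", i.e. the builder starts from the model's pre-optimized LMI deployment configuration before applying the quantization overrides.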
