@@ -1605,3 +1605,84 @@ def test_optimize_compile_for_jumpstart_without_compilation_config(
16051605 self .assertEqual (optimized_model .env ["SAGEMAKER_ENV" ], "1" )
16061606 self .assertEqual (optimized_model .env ["HF_MODEL_ID" ], "/opt/ml/model" )
16071607 self .assertEqual (optimized_model .env ["SAGEMAKER_MODEL_SERVER_WORKERS" ], "1" )
1608+
1609+
class TestJumpStartModelBuilderOptimizationUseCases(unittest.TestCase):
    """Optimization use-case tests for ModelBuilder on JumpStart models.

    Verifies that ``ModelBuilder.optimize()`` applies user-supplied optimization
    overrides on top of a JumpStart model's pre-optimized deployment config,
    rather than keeping the pre-optimized values verbatim.
    """

    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
    @patch(
        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model",
        return_value=True,
    )
    @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
    @patch(
        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
        return_value=True,
    )
    @patch(
        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model",
        return_value=False,
    )
    def test_optimize_on_js_model_should_ignore_pre_optimized_configurations(
        self,
        mock_is_fine_tuned,
        mock_is_jumpstart_model,
        mock_js_model,
        mock_is_gated_model,
        mock_serve_settings,
        mock_telemetry,
    ):
        """quantization_config overrides must win over pre-optimized env values."""
        mock_sagemaker_session = Mock()
        # The optimization job result is canned; we only exercise env merging.
        mock_sagemaker_session.wait_for_optimization_job.side_effect = (
            lambda *args: mock_optimization_job_response
        )

        # Simulate a JumpStart LMI model that already ships pre-optimized
        # settings (eager mode, tensor-parallel degree 8).
        mock_lmi_js_model = MagicMock()
        mock_lmi_js_model.image_uri = mock_djl_image_uri
        mock_lmi_js_model.env = {
            "SAGEMAKER_PROGRAM": "inference.py",
            "ENDPOINT_SERVER_TIMEOUT": "3600",
            "MODEL_CACHE_ROOT": "/opt/ml/model",
            "SAGEMAKER_ENV": "1",
            "HF_MODEL_ID": "/opt/ml/model",
            "OPTION_ENFORCE_EAGER": "true",
            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
            "OPTION_TENSOR_PARALLEL_DEGREE": "8",
        }

        mock_js_model.return_value = mock_lmi_js_model

        model_builder = ModelBuilder(
            model="meta-textgeneration-llama-3-1-70b-instruct",
            schema_builder=SchemaBuilder("test", "test"),
            sagemaker_session=mock_sagemaker_session,
        )

        optimized_model = model_builder.optimize(
            accept_eula=True,
            instance_type="ml.g5.24xlarge",
            quantization_config={
                "OverrideEnvironment": {
                    "OPTION_QUANTIZE": "fp8",
                    "OPTION_TENSOR_PARALLEL_DEGREE": "4",
                },
            },
            output_path="s3://bucket/code/",
        )

        # The builder should pin the deployment config for the requested
        # instance type before applying overrides.
        assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
            "instance_type": "ml.g5.24xlarge",
            "config_name": "lmi",
        }
        assert optimized_model.env == {
            "SAGEMAKER_PROGRAM": "inference.py",
            "ENDPOINT_SERVER_TIMEOUT": "3600",
            "MODEL_CACHE_ROOT": "/opt/ml/model",
            "SAGEMAKER_ENV": "1",
            "HF_MODEL_ID": "/opt/ml/model",
            "OPTION_ENFORCE_EAGER": "true",
            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
            "OPTION_TENSOR_PARALLEL_DEGREE": "4",  # overridden from 8 to 4
            "OPTION_QUANTIZE": "fp8",  # added by the override environment
        }
0 commit comments