@@ -1686,3 +1686,80 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations(
16861686 "OPTION_TENSOR_PARALLEL_DEGREE" : "4" , # should be overridden from 8 to 4
16871687 "OPTION_QUANTIZE" : "fp8" , # should be added to the env
16881688 }
+
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model",
+        return_value=True,
+    )
+    @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model",
+        return_value=False,
+    )
+    def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_override(
+        self,
+        mock_is_fine_tuned,
+        mock_is_jumpstart_model,
+        mock_js_model,
+        mock_is_gated_model,
+        mock_serve_settings,
+        mock_telemetry,
+    ):
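+        """Quantization config that only sets OPTION_QUANTIZE should keep the
+        pre-optimized deployment config: OPTION_TENSOR_PARALLEL_DEGREE stays at 8
+        and OPTION_QUANTIZE is added to the model env."""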
+        mock_sagemaker_session = Mock()
+        mock_sagemaker_session.wait_for_optimization_job.side_effect = (
+            lambda *args: mock_optimization_job_response
+        )
+
+        mock_lmi_js_model = MagicMock()
+        mock_lmi_js_model.image_uri = mock_djl_image_uri
+        mock_lmi_js_model.env = {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "OPTION_ENFORCE_EAGER": "true",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "8",
+        }
+
+        mock_js_model.return_value = mock_lmi_js_model
+
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-1-70b-instruct",
+            schema_builder=SchemaBuilder("test", "test"),
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        optimized_model = model_builder.optimize(
+            accept_eula=True,
+            instance_type="ml.g5.24xlarge",
+            quantization_config={
+                "OverrideEnvironment": {
+                    "OPTION_QUANTIZE": "fp8",
+                },
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
+            "instance_type": "ml.g5.24xlarge",
+            "config_name": "lmi",
+        }
+        assert optimized_model.env == {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "OPTION_ENFORCE_EAGER": "true",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "8",
+            "OPTION_QUANTIZE": "fp8",  # should be added to the env
+        }