@@ -1686,3 +1686,80 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations(
             "OPTION_TENSOR_PARALLEL_DEGREE": "4",  # should be overridden from 8 to 4
             "OPTION_QUANTIZE": "fp8",  # should be added to the env
         }
+
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model",
+        return_value=True,
+    )
+    @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model",
+        return_value=False,
+    )
+    def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_override(
+        self,
+        mock_is_fine_tuned,
+        mock_is_jumpstart_model,
+        mock_js_model,
+        mock_is_gated_model,
+        mock_serve_settings,
+        mock_telemetry,
+    ):
+        mock_sagemaker_session = Mock()
+        mock_sagemaker_session.wait_for_optimization_job.side_effect = (
+            lambda *args: mock_optimization_job_response
+        )
+
+        mock_lmi_js_model = MagicMock()
+        mock_lmi_js_model.image_uri = mock_djl_image_uri
+        mock_lmi_js_model.env = {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "OPTION_ENFORCE_EAGER": "true",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "8",
+        }
+
+        mock_js_model.return_value = mock_lmi_js_model
+
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-1-70b-instruct",
+            schema_builder=SchemaBuilder("test", "test"),
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        optimized_model = model_builder.optimize(
+            accept_eula=True,
+            instance_type="ml.g5.24xlarge",
+            quantization_config={
+                "OverrideEnvironment": {
+                    "OPTION_QUANTIZE": "fp8",
+                },
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
+            "instance_type": "ml.g5.24xlarge",
+            "config_name": "lmi",
+        }
+        assert optimized_model.env == {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "OPTION_ENFORCE_EAGER": "true",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "8",
+            "OPTION_QUANTIZE": "fp8",  # should be added to the env
+        }