@@ -1605,3 +1605,84 @@ def test_optimize_compile_for_jumpstart_without_compilation_config(
         self.assertEqual(optimized_model.env["SAGEMAKER_ENV"], "1")
         self.assertEqual(optimized_model.env["HF_MODEL_ID"], "/opt/ml/model")
         self.assertEqual(optimized_model.env["SAGEMAKER_MODEL_SERVER_WORKERS"], "1")
+
+
+class TestJumpStartModelBuilderOptimizationUseCases(unittest.TestCase):
+
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model",
+        return_value=True,
+    )
+    @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model",
+        return_value=False,
+    )
+    def test_optimize_on_js_model_should_ignore_pre_optimized_configurations(
+        self,
+        mock_is_fine_tuned,
+        mock_is_jumpstart_model,
+        mock_js_model,
+        mock_is_gated_model,
+        mock_serve_settings,
+        mock_telemetry,
+    ):
+        mock_sagemaker_session = Mock()
+        mock_sagemaker_session.wait_for_optimization_job.side_effect = (
+            lambda *args: mock_optimization_job_response
+        )
+
+        mock_lmi_js_model = MagicMock()
+        mock_lmi_js_model.image_uri = mock_djl_image_uri
+        mock_lmi_js_model.env = {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "OPTION_ENFORCE_EAGER": "true",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "8",
+        }
+
+        mock_js_model.return_value = mock_lmi_js_model
+
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-1-70b-instruct",
+            schema_builder=SchemaBuilder("test", "test"),
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        optimized_model = model_builder.optimize(
+            accept_eula=True,
+            instance_type="ml.g5.24xlarge",
+            quantization_config={
+                "OverrideEnvironment": {
+                    "OPTION_QUANTIZE": "fp8",
+                    "OPTION_TENSOR_PARALLEL_DEGREE": "4",
+                },
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
+            "instance_type": "ml.g5.24xlarge",
+            "config_name": "lmi",
+        }
+        assert optimized_model.env == {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "OPTION_ENFORCE_EAGER": "true",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+            "OPTION_TENSOR_PARALLEL_DEGREE": "4",  # should be overridden from 8 to 4
+            "OPTION_QUANTIZE": "fp8",  # should be added to the env
+        }