@@ -1442,3 +1442,103 @@ def test_optimize_compile_for_jumpstart_with_neuron_env(
14421442 self .assertEqual (optimized_model .env ["OPTION_ROLLING_BATCH" ], "auto" )
14431443 self .assertEqual (optimized_model .env ["OPTION_MAX_ROLLING_BATCH_SIZE" ], "4" )
14441444 self .assertEqual (optimized_model .env ["OPTION_NEURON_OPTIMIZE_LEVEL" ], "2" )
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model",
    return_value=True,
)
@patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model")
@patch(
    "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
    return_value=({"model_type": "t5", "n_head": 71}, True),
)
def test_optimize_compile_for_jumpstart_without_compilation_config(
    self,
    mock_prepare_for_tgi,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_js_model,
    mock_is_gated_model,
    mock_serve_settings,
    mock_telemetry,
):
    """Compiling a gated JumpStart model via optimize() with no explicit
    compilation config should pick up the optimization job's recommended
    image/artifacts and pass the JumpStart model's env through unchanged.
    """
    # Session mock: completing the optimization job yields the canned response.
    mock_sagemaker_session = Mock()
    mock_sagemaker_session.wait_for_optimization_job.side_effect = (
        lambda *args: mock_optimization_job_response
    )

    # Metadata config advertising a neuron-compatible hosting setup.
    mock_metadata_config = Mock()
    mock_metadata_config.resolved_config = {
        "supported_inference_instance_types": ["ml.inf2.48xlarge"],
        "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b",
    }

    # The JumpStartModel constructed during optimize() carries this env.
    js_model_instance = MagicMock()
    js_model_instance.env = {
        "SAGEMAKER_PROGRAM": "inference.py",
        "ENDPOINT_SERVER_TIMEOUT": "3600",
        "MODEL_CACHE_ROOT": "/opt/ml/model",
        "SAGEMAKER_ENV": "1",
        "HF_MODEL_ID": "/opt/ml/model",
        "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
    }
    mock_js_model.return_value = js_model_instance

    # Pre-trained model mock with deployment configs and metadata wired in.
    pre_trained_instance = MagicMock()
    pre_trained_instance.env = {}
    pre_trained_instance.config_name = "config_name"
    pre_trained_instance.model_data = mock_model_data
    pre_trained_instance.image_uri = mock_tgi_image_uri
    pre_trained_instance.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
    pre_trained_instance.deployment_config = DEPLOYMENT_CONFIGS[0]
    pre_trained_instance._metadata_configs = {"config_name": mock_metadata_config}
    mock_pre_trained_model.return_value = pre_trained_instance

    sample_input = {
        "inputs": (
            "The diamondback terrapin or simply terrapin is a species "
            "of turtle native to the brackish coastal tidal marshes of the"
        ),
        "parameters": {"max_new_tokens": 1024},
    }
    sample_output = [
        {
            "generated_text": (
                "The diamondback terrapin or simply terrapin is a "
                "species of turtle native to the brackish coastal "
                "tidal marshes of the east coast."
            )
        }
    ]

    model_builder = ModelBuilder(
        model="meta-textgeneration-llama-3-70b",
        schema_builder=SchemaBuilder(sample_input, sample_output),
        sagemaker_session=mock_sagemaker_session,
    )

    optimized_model = model_builder.optimize(
        accept_eula=True,
        instance_type="ml.inf2.24xlarge",
        output_path="s3://bucket/code/",
    )

    # Image and artifacts come from the optimization job's output.
    self.assertEqual(
        optimized_model.image_uri,
        mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"],
    )
    self.assertEqual(
        optimized_model.model_data["S3DataSource"]["S3Uri"],
        mock_optimization_job_response["OutputConfig"]["S3OutputLocation"],
    )

    # The JumpStart model's environment must survive optimization untouched.
    expected_env = {
        "SAGEMAKER_PROGRAM": "inference.py",
        "ENDPOINT_SERVER_TIMEOUT": "3600",
        "MODEL_CACHE_ROOT": "/opt/ml/model",
        "SAGEMAKER_ENV": "1",
        "HF_MODEL_ID": "/opt/ml/model",
        "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
    }
    for env_key, env_value in expected_env.items():
        self.assertEqual(optimized_model.env[env_key], env_value)
0 commit comments