13
13
from __future__ import absolute_import
14
14
from unittest .mock import MagicMock , patch , Mock , mock_open
15
15
16
- import pytest
17
-
18
16
import unittest
19
17
from pathlib import Path
20
18
from copy import deepcopy
21
19
20
+ from sagemaker .serve import SchemaBuilder
22
21
from sagemaker .serve .builder .model_builder import ModelBuilder
23
22
from sagemaker .serve .mode .function_pointers import Mode
24
23
from sagemaker .serve .model_format .mlflow .constants import MLFLOW_TRACKING_ARN
@@ -2328,22 +2327,52 @@ def test_build_tensorflow_serving_non_mlflow_case(
2328
2327
mock_session ,
2329
2328
)
2330
2329
2331
- @pytest .mark .skip (reason = "Implementation not completed" )
2330
+ @patch .object (ModelBuilder , "_prepare_for_mode" )
2331
+ @patch .object (ModelBuilder , "_build_for_djl" )
2332
+ @patch .object (ModelBuilder , "_is_jumpstart_model_id" , return_value = False )
2332
2333
@patch .object (ModelBuilder , "_get_serve_setting" , autospec = True )
2333
2334
@patch ("sagemaker.serve.utils.telemetry_logger._send_telemetry" )
2334
- def test_optimize (self , mock_send_telemetry , mock_get_serve_setting ):
2335
+ def test_optimize (
2336
+ self ,
2337
+ mock_send_telemetry ,
2338
+ mock_get_serve_setting ,
2339
+ mock_is_jumpstart_model_id ,
2340
+ mock_build_for_djl ,
2341
+ mock_prepare_for_mode ,
2342
+ ):
2335
2343
mock_sagemaker_session = Mock ()
2336
2344
2337
2345
mock_settings = Mock ()
2338
2346
mock_settings .telemetry_opt_out = False
2339
2347
mock_get_serve_setting .return_value = mock_settings
2340
2348
2349
+ pysdk_model = Mock ()
2350
+ pysdk_model .env = {"key" : "val" }
2351
+ pysdk_model .add_tags .side_effect = lambda * arg , ** kwargs : None
2352
+
2353
+ mock_build_for_djl .side_effect = lambda ** kwargs : pysdk_model
2354
+ mock_prepare_for_mode .side_effect = lambda * args , ** kwargs : (
2355
+ {
2356
+ "S3DataSource" : {
2357
+ "S3Uri" : "s3://uri" ,
2358
+ "S3DataType" : "S3Prefix" ,
2359
+ "CompressionType" : "None" ,
2360
+ }
2361
+ },
2362
+ {"key" : "val" },
2363
+ )
2364
+
2341
2365
builder = ModelBuilder (
2342
- model_path = MODEL_PATH ,
2343
- schema_builder = schema_builder ,
2344
- model = mock_fw_model ,
2366
+ schema_builder = SchemaBuilder (
2367
+ sample_input = {"inputs" : "Hello" , "parameters" : {}},
2368
+ sample_output = [{"generated_text" : "Hello" }],
2369
+ ),
2370
+ model = "meta-llama/Meta-Llama-3-8B" ,
2345
2371
sagemaker_session = mock_sagemaker_session ,
2372
+ env_vars = {"HF_TOKEN" : "token" },
2373
+ model_metadata = {"CUSTOM_MODEL_PATH" : "/tmp/modelbuilders/code" },
2346
2374
)
2375
+ builder .pysdk_model = pysdk_model
2347
2376
2348
2377
job_name = "my-optimization-job"
2349
2378
instance_type = "ml.inf1.xlarge"
@@ -2352,10 +2381,6 @@ def test_optimize(self, mock_send_telemetry, mock_get_serve_setting):
2352
2381
"Image" : "quantization-image-uri" ,
2353
2382
"OverrideEnvironment" : {"ENV_VAR" : "value" },
2354
2383
}
2355
- compilation_config = {
2356
- "Image" : "compilation-image-uri" ,
2357
- "OverrideEnvironment" : {"ENV_VAR" : "value" },
2358
- }
2359
2384
env_vars = {"Var1" : "value" , "Var2" : "value" }
2360
2385
kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id"
2361
2386
max_runtime_in_sec = 3600
@@ -2368,46 +2393,55 @@ def test_optimize(self, mock_send_telemetry, mock_get_serve_setting):
2368
2393
"Subnets" : ["subnet-01234567" , "subnet-89abcdef" ],
2369
2394
}
2370
2395
2371
- expected_create_optimization_job_args = {
2372
- "ModelSource" : {"S3" : {"S3Uri" : MODEL_PATH , "ModelAccessConfig" : {"AcceptEula" : True }}},
2373
- "DeploymentInstanceType" : instance_type ,
2374
- "OptimizationEnvironment" : env_vars ,
2375
- "OptimizationConfigs" : [
2376
- {"ModelQuantizationConfig" : quantization_config },
2377
- {"ModelCompilationConfig" : compilation_config },
2378
- ],
2379
- "OutputConfig" : {"S3OutputLocation" : output_path , "KmsKeyId" : kms_key },
2380
- "RoleArn" : mock_role_arn ,
2381
- "OptimizationJobName" : job_name ,
2382
- "StoppingCondition" : {"MaxRuntimeInSeconds" : max_runtime_in_sec },
2383
- "Tags" : [
2384
- {"Key" : "Project" , "Value" : "my-project" },
2385
- {"Key" : "Environment" , "Value" : "production" },
2386
- ],
2387
- "VpcConfig" : vpc_config ,
2388
- }
2389
-
2390
- mock_sagemaker_session .sagemaker_client .create_optimization_job .return_value = {
2391
- "OptimizationJobArn" : "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job"
2396
+ mock_sagemaker_session .wait_for_optimization_job .side_effect = lambda * args , ** kwargs : {
2397
+ "OptimizationJobArn" : "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job" ,
2398
+ "OptimizationJobName" : "my-optimization-job" ,
2392
2399
}
2393
2400
2394
2401
builder .optimize (
2395
2402
instance_type = instance_type ,
2396
2403
output_path = output_path ,
2397
- role = mock_role_arn ,
2404
+ role_arn = mock_role_arn ,
2398
2405
job_name = job_name ,
2399
2406
quantization_config = quantization_config ,
2400
- compilation_config = compilation_config ,
2401
2407
env_vars = env_vars ,
2402
2408
kms_key = kms_key ,
2403
2409
max_runtime_in_sec = max_runtime_in_sec ,
2404
2410
tags = tags ,
2405
2411
vpc_config = vpc_config ,
2406
2412
)
2407
2413
2414
+ self .assertEqual (builder .env_vars ["HUGGING_FACE_HUB_TOKEN" ], "token" )
2415
+ self .assertEqual (builder .model_server , ModelServer .DJL_SERVING )
2416
+
2408
2417
mock_send_telemetry .assert_called_once ()
2409
2418
mock_sagemaker_session .sagemaker_client .create_optimization_job .assert_called_once_with (
2410
- ** expected_create_optimization_job_args
2419
+ OptimizationJobName = "my-optimization-job" ,
2420
+ DeploymentInstanceType = "ml.inf1.xlarge" ,
2421
+ RoleArn = "arn:aws:iam::123456789012:role/SageMakerRole" ,
2422
+ OptimizationEnvironment = {"Var1" : "value" , "Var2" : "value" },
2423
+ ModelSource = {"S3" : {"S3Uri" : "s3://uri" }},
2424
+ OptimizationConfigs = [
2425
+ {
2426
+ "ModelQuantizationConfig" : {
2427
+ "Image" : "quantization-image-uri" ,
2428
+ "OverrideEnvironment" : {"ENV_VAR" : "value" },
2429
+ }
2430
+ }
2431
+ ],
2432
+ OutputConfig = {
2433
+ "S3OutputLocation" : "s3://my-bucket/output" ,
2434
+ "KmsKeyId" : "arn:aws:kms:us-west-2:123456789012:key/my-key-id" ,
2435
+ },
2436
+ StoppingCondition = {"MaxRuntimeInSeconds" : 3600 },
2437
+ Tags = [
2438
+ {"Key" : "Project" , "Value" : "my-project" },
2439
+ {"Key" : "Environment" , "Value" : "production" },
2440
+ ],
2441
+ VpcConfig = {
2442
+ "SecurityGroupIds" : ["sg-01234567890abcdef" , "sg-fedcba9876543210" ],
2443
+ "Subnets" : ["subnet-01234567" , "subnet-89abcdef" ],
2444
+ },
2411
2445
)
2412
2446
2413
2447
def test_handle_mlflow_input_without_mlflow_model_path (self ):
@@ -2649,7 +2683,7 @@ def test_optimize_for_hf_with_custom_s3_path(
2649
2683
2650
2684
model_builder = ModelBuilder (
2651
2685
model = "meta-llama/Meta-Llama-3-8B-Instruct" ,
2652
- env_vars = {"HUGGING_FACE_HUB_TOKEN " : "token" },
2686
+ env_vars = {"HF_TOKEN " : "token" },
2653
2687
model_metadata = {
2654
2688
"CUSTOM_MODEL_PATH" : "s3://bucket/path/" ,
2655
2689
},
@@ -2667,6 +2701,7 @@ def test_optimize_for_hf_with_custom_s3_path(
2667
2701
output_path = "s3://bucket/code/" ,
2668
2702
)
2669
2703
2704
+ self .assertEqual (model_builder .env_vars ["HF_TOKEN" ], "token" )
2670
2705
self .assertEqual (model_builder .role_arn , "role-arn" )
2671
2706
self .assertEqual (model_builder .instance_type , "ml.g5.2xlarge" )
2672
2707
self .assertEqual (model_builder .pysdk_model .env ["OPTION_QUANTIZE" ], "awq" )
0 commit comments