Skip to content

Commit f68b71b

Browse files
author
Jonathan Makunga
committed
Unit tests
1 parent 2d5cc56 commit f68b71b

File tree

1 file changed

+71
-36
lines changed

1 file changed

+71
-36
lines changed

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
from __future__ import absolute_import
1414
from unittest.mock import MagicMock, patch, Mock, mock_open
1515

16-
import pytest
17-
1816
import unittest
1917
from pathlib import Path
2018
from copy import deepcopy
2119

20+
from sagemaker.serve import SchemaBuilder
2221
from sagemaker.serve.builder.model_builder import ModelBuilder
2322
from sagemaker.serve.mode.function_pointers import Mode
2423
from sagemaker.serve.model_format.mlflow.constants import MLFLOW_TRACKING_ARN
@@ -2328,22 +2327,52 @@ def test_build_tensorflow_serving_non_mlflow_case(
23282327
mock_session,
23292328
)
23302329

2331-
@pytest.mark.skip(reason="Implementation not completed")
2330+
@patch.object(ModelBuilder, "_prepare_for_mode")
2331+
@patch.object(ModelBuilder, "_build_for_djl")
2332+
@patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False)
23322333
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
23332334
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
2334-
def test_optimize(self, mock_send_telemetry, mock_get_serve_setting):
2335+
def test_optimize(
2336+
self,
2337+
mock_send_telemetry,
2338+
mock_get_serve_setting,
2339+
mock_is_jumpstart_model_id,
2340+
mock_build_for_djl,
2341+
mock_prepare_for_mode,
2342+
):
23352343
mock_sagemaker_session = Mock()
23362344

23372345
mock_settings = Mock()
23382346
mock_settings.telemetry_opt_out = False
23392347
mock_get_serve_setting.return_value = mock_settings
23402348

2349+
pysdk_model = Mock()
2350+
pysdk_model.env = {"key": "val"}
2351+
pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None
2352+
2353+
mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model
2354+
mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
2355+
{
2356+
"S3DataSource": {
2357+
"S3Uri": "s3://uri",
2358+
"S3DataType": "S3Prefix",
2359+
"CompressionType": "None",
2360+
}
2361+
},
2362+
{"key": "val"},
2363+
)
2364+
23412365
builder = ModelBuilder(
2342-
model_path=MODEL_PATH,
2343-
schema_builder=schema_builder,
2344-
model=mock_fw_model,
2366+
schema_builder=SchemaBuilder(
2367+
sample_input={"inputs": "Hello", "parameters": {}},
2368+
sample_output=[{"generated_text": "Hello"}],
2369+
),
2370+
model="meta-llama/Meta-Llama-3-8B",
23452371
sagemaker_session=mock_sagemaker_session,
2372+
env_vars={"HF_TOKEN": "token"},
2373+
model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"},
23462374
)
2375+
builder.pysdk_model = pysdk_model
23472376

23482377
job_name = "my-optimization-job"
23492378
instance_type = "ml.inf1.xlarge"
@@ -2352,10 +2381,6 @@ def test_optimize(self, mock_send_telemetry, mock_get_serve_setting):
23522381
"Image": "quantization-image-uri",
23532382
"OverrideEnvironment": {"ENV_VAR": "value"},
23542383
}
2355-
compilation_config = {
2356-
"Image": "compilation-image-uri",
2357-
"OverrideEnvironment": {"ENV_VAR": "value"},
2358-
}
23592384
env_vars = {"Var1": "value", "Var2": "value"}
23602385
kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id"
23612386
max_runtime_in_sec = 3600
@@ -2368,46 +2393,55 @@ def test_optimize(self, mock_send_telemetry, mock_get_serve_setting):
23682393
"Subnets": ["subnet-01234567", "subnet-89abcdef"],
23692394
}
23702395

2371-
expected_create_optimization_job_args = {
2372-
"ModelSource": {"S3": {"S3Uri": MODEL_PATH, "ModelAccessConfig": {"AcceptEula": True}}},
2373-
"DeploymentInstanceType": instance_type,
2374-
"OptimizationEnvironment": env_vars,
2375-
"OptimizationConfigs": [
2376-
{"ModelQuantizationConfig": quantization_config},
2377-
{"ModelCompilationConfig": compilation_config},
2378-
],
2379-
"OutputConfig": {"S3OutputLocation": output_path, "KmsKeyId": kms_key},
2380-
"RoleArn": mock_role_arn,
2381-
"OptimizationJobName": job_name,
2382-
"StoppingCondition": {"MaxRuntimeInSeconds": max_runtime_in_sec},
2383-
"Tags": [
2384-
{"Key": "Project", "Value": "my-project"},
2385-
{"Key": "Environment", "Value": "production"},
2386-
],
2387-
"VpcConfig": vpc_config,
2388-
}
2389-
2390-
mock_sagemaker_session.sagemaker_client.create_optimization_job.return_value = {
2391-
"OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job"
2396+
mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: {
2397+
"OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job",
2398+
"OptimizationJobName": "my-optimization-job",
23922399
}
23932400

23942401
builder.optimize(
23952402
instance_type=instance_type,
23962403
output_path=output_path,
2397-
role=mock_role_arn,
2404+
role_arn=mock_role_arn,
23982405
job_name=job_name,
23992406
quantization_config=quantization_config,
2400-
compilation_config=compilation_config,
24012407
env_vars=env_vars,
24022408
kms_key=kms_key,
24032409
max_runtime_in_sec=max_runtime_in_sec,
24042410
tags=tags,
24052411
vpc_config=vpc_config,
24062412
)
24072413

2414+
self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token")
2415+
self.assertEqual(builder.model_server, ModelServer.DJL_SERVING)
2416+
24082417
mock_send_telemetry.assert_called_once()
24092418
mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with(
2410-
**expected_create_optimization_job_args
2419+
OptimizationJobName="my-optimization-job",
2420+
DeploymentInstanceType="ml.inf1.xlarge",
2421+
RoleArn="arn:aws:iam::123456789012:role/SageMakerRole",
2422+
OptimizationEnvironment={"Var1": "value", "Var2": "value"},
2423+
ModelSource={"S3": {"S3Uri": "s3://uri"}},
2424+
OptimizationConfigs=[
2425+
{
2426+
"ModelQuantizationConfig": {
2427+
"Image": "quantization-image-uri",
2428+
"OverrideEnvironment": {"ENV_VAR": "value"},
2429+
}
2430+
}
2431+
],
2432+
OutputConfig={
2433+
"S3OutputLocation": "s3://my-bucket/output",
2434+
"KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id",
2435+
},
2436+
StoppingCondition={"MaxRuntimeInSeconds": 3600},
2437+
Tags=[
2438+
{"Key": "Project", "Value": "my-project"},
2439+
{"Key": "Environment", "Value": "production"},
2440+
],
2441+
VpcConfig={
2442+
"SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"],
2443+
"Subnets": ["subnet-01234567", "subnet-89abcdef"],
2444+
},
24112445
)
24122446

24132447
def test_handle_mlflow_input_without_mlflow_model_path(self):
@@ -2649,7 +2683,7 @@ def test_optimize_for_hf_with_custom_s3_path(
26492683

26502684
model_builder = ModelBuilder(
26512685
model="meta-llama/Meta-Llama-3-8B-Instruct",
2652-
env_vars={"HUGGING_FACE_HUB_TOKEN": "token"},
2686+
env_vars={"HF_TOKEN": "token"},
26532687
model_metadata={
26542688
"CUSTOM_MODEL_PATH": "s3://bucket/path/",
26552689
},
@@ -2667,6 +2701,7 @@ def test_optimize_for_hf_with_custom_s3_path(
26672701
output_path="s3://bucket/code/",
26682702
)
26692703

2704+
self.assertEqual(model_builder.env_vars["HF_TOKEN"], "token")
26702705
self.assertEqual(model_builder.role_arn, "role-arn")
26712706
self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge")
26722707
self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq")

0 commit comments

Comments
 (0)