19 | 19 |   from sagemaker.enums import Tag
20 | 20 |   from sagemaker.serve.utils.optimize_utils import (
21 | 21 |       _generate_optimized_model,
22 |    | -     _is_inferentia_or_trainium,
23 | 22 |       _update_environment_variables,
24 | 23 |       _is_image_compatible_with_optimization_job,
25 | 24 |       _extract_speculative_draft_model_provider,
26 | 25 |       _validate_optimization_inputs,
27 | 26 |       _extracts_and_validates_speculative_model_source,
   | 27 | +     _is_s3_uri,
   | 28 | +     _generate_additional_model_data_sources,
   | 29 | +     _generate_channel_name,
28 | 30 |   )
29 | 31 |
30 | 32 |   mock_optimization_job_output = {
31 |    | -     "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:"
32 |    | -     "optimization-job/modelbuilderjob-6b09ffebeb0741b8a28b85623fd9c968",
   | 33 | +     "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:optimization-job/"
   | 34 | +     "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691",
33 | 35 |       "OptimizationJobStatus": "COMPLETED",
34 |    | -     "OptimizationJobName": "modelbuilderjob-6b09ffebeb0741b8a28b85623fd9c968",
   | 36 | +     "OptimizationJobName": "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691",
35 | 37 |       "ModelSource": {
36 | 38 |           "S3": {
37 | 39 |               "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/"
…
46 | 48 |           "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
47 | 49 |           "SAGEMAKER_PROGRAM": "inference.py",
48 | 50 |       },
49 |    | -     "DeploymentInstanceType": "ml.g5.48xlarge",
   | 51 | +     "DeploymentInstanceType": "ml.g5.2xlarge",
50 | 52 |       "OptimizationConfigs": [
51 | 53 |           {
52 | 54 |               "ModelQuantizationConfig": {
…
55 | 57 |               }
56 | 58 |           }
57 | 59 |       ],
58 |    | -     "OutputConfig": {
59 |    | -         "S3OutputLocation": "s3://dont-delete-ss-jarvis-integ-test-312206380606-us-west-2/"
60 |    | -     },
   | 60 | +     "OutputConfig": {"S3OutputLocation": "s3://quicksilver-model-data/llama-3-8b/quantized-1/"},
61 | 61 |       "OptimizationOutput": {
62 | 62 |           "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
63 | 63 |       },
64 |    | -     "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628",
   | 64 | +     "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20240116T151132",
65 | 65 |       "StoppingCondition": {"MaxRuntimeInSeconds": 36000},
66 | 66 |       "ResponseMetadata": {
67 |    | -         "RequestId": "17ae151f-b51d-4194-8ba9-edbba068c90b",
   | 67 | +         "RequestId": "a95253d5-c045-4708-8aac-9f0d327515f7",
68 | 68 |           "HTTPStatusCode": 200,
69 | 69 |           "HTTPHeaders": {
70 |    | -             "x-amzn-requestid": "17ae151f-b51d-4194-8ba9-edbba068c90b",
   | 70 | +             "x-amzn-requestid": "a95253d5-c045-4708-8aac-9f0d327515f7",
71 | 71 |               "content-type": "application/x-amz-json-1.1",
72 |    | -             "content-length": "1380",
73 |    | -             "date": "Thu, 20 Jun 2024 19:25:53 GMT",
   | 72 | +             "content-length": "1371",
   | 73 | +             "date": "Fri, 21 Jun 2024 04:27:42 GMT",
74 | 74 |           },
75 | 75 |           "RetryAttempts": 0,
76 | 76 |       },
77 | 77 |   }
78 | 78 |
79 | 79 |
80 |    | - @pytest.mark.parametrize(
81 |    | -     "instance, expected",
82 |    | -     [
83 |    | -         ("ml.trn1.2xlarge", True),
84 |    | -         ("ml.inf2.xlarge", True),
85 |    | -         ("ml.c7gd.4xlarge", False),
86 |    | -     ],
87 |    | - )
88 |    | - def test_is_inferentia_or_trainium(instance, expected):
89 |    | -     assert _is_inferentia_or_trainium(instance) == expected
90 |    | -
91 |    | -
92 | 80 |   @pytest.mark.parametrize(
93 | 81 |       "image_uri, expected",
94 | 82 |       [
@@ -124,17 +112,21 @@ def test_generate_optimized_model():
124 | 112 |               "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/"
125 | 113 |           }
126 | 114 |       }
    | 115 | +     pysdk_model.env = {"OPTION_QUANTIZE": "awq"}
127 | 116 |
128 | 117 |       optimized_model = _generate_optimized_model(pysdk_model, mock_optimization_job_output)
129 | 118 |
130 | 119 |       assert (
131 | 120 |           optimized_model.image_uri
132 | 121 |           == mock_optimization_job_output["OptimizationOutput"]["RecommendedInferenceImage"]
133 | 122 |       )
134 |     | -     assert optimized_model.env == mock_optimization_job_output["OptimizationEnvironment"]
    | 123 | +     assert optimized_model.env == {
    | 124 | +         "OPTION_QUANTIZE": "awq",
    | 125 | +         **mock_optimization_job_output["OptimizationEnvironment"],
    | 126 | +     }
135 | 127 |       assert (
136 | 128 |           optimized_model.model_data["S3DataSource"]["S3Uri"]
137 |     | -         == mock_optimization_job_output["ModelSource"]["S3"]
    | 129 | +         == mock_optimization_job_output["OutputConfig"]["S3OutputLocation"]
138 | 130 |       )
139 | 131 |       assert optimized_model.instance_type == mock_optimization_job_output["DeploymentInstanceType"]
140 | 132 |       pysdk_model.add_tags.assert_called_once_with(
@@ -209,3 +201,61 @@ def test_extract_speculative_draft_model_s3_uri():
209 | 201 |   def test_extract_speculative_draft_model_s3_uri_ex():
210 | 202 |       with pytest.raises(ValueError):
211 | 203 |           _extracts_and_validates_speculative_model_source({"ModelSource": None})
    | 204 | +
    | 205 | +
    | 206 | + def test_generate_channel_name():
    | 207 | +     assert _generate_channel_name(None) is not None
    | 208 | +
    | 209 | +     additional_model_data_sources = _generate_additional_model_data_sources(
    | 210 | +         "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True
    | 211 | +     )
    | 212 | +
    | 213 | +     assert _generate_channel_name(additional_model_data_sources) == "channel_name"
    | 214 | +
    | 215 | +
    | 216 | + def test_generate_additional_model_data_sources():
    | 217 | +     model_source = _generate_additional_model_data_sources(
    | 218 | +         "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True
    | 219 | +     )
    | 220 | +
    | 221 | +     assert model_source == [
    | 222 | +         {
    | 223 | +             "ChannelName": "channel_name",
    | 224 | +             "S3DataSource": {
    | 225 | +                 "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/",
    | 226 | +                 "S3DataType": "S3Prefix",
    | 227 | +                 "CompressionType": "None",
    | 228 | +                 "ModelAccessConfig": {"ACCEPT_EULA": True},
    | 229 | +             },
    | 230 | +         }
    | 231 | +     ]
    | 232 | +
    | 233 | +     model_source = _generate_additional_model_data_sources(
    | 234 | +         "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", False
    | 235 | +     )
    | 236 | +
    | 237 | +     assert model_source == [
    | 238 | +         {
    | 239 | +             "ChannelName": "channel_name",
    | 240 | +             "S3DataSource": {
    | 241 | +                 "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/",
    | 242 | +                 "S3DataType": "S3Prefix",
    | 243 | +                 "CompressionType": "None",
    | 244 | +             },
    | 245 | +         }
    | 246 | +     ]
    | 247 | +
    | 248 | +
    | 249 | + @pytest.mark.parametrize(
    | 250 | +     "s3_uri, expected",
    | 251 | +     [
    | 252 | +         (
    | 253 | +             "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/"
    | 254 | +             "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/",
    | 255 | +             True,
    | 256 | +         ),
    | 257 | +         ("invalid://", False),
    | 258 | +     ],
    | 259 | + )
    | 260 | + def test_is_s3_uri(s3_uri, expected):
    | 261 | +     assert _is_s3_uri(s3_uri) == expected
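
For reference, here is a minimal sketch of the three helpers these new tests exercise, reconstructed only from the values asserted above. It is not the actual sagemaker.serve.utils.optimize_utils implementation; the scheme check and the default channel name returned for `None` are assumptions.

```python
# Hypothetical sketches inferred from the new tests; not the library's real code.
from typing import Dict, List, Optional
from urllib.parse import urlparse


def _is_s3_uri(s3_uri: Optional[str]) -> bool:
    """Treat any URI whose scheme parses as 's3' as an S3 URI."""
    if not s3_uri:
        return False
    return urlparse(s3_uri).scheme == "s3"


def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str:
    """Reuse the first source's channel name, else fall back to a default."""
    if additional_model_data_sources:
        return additional_model_data_sources[0]["ChannelName"]
    return "draft_model"  # assumed default; the test only requires a non-None value


def _generate_additional_model_data_sources(
    model_source: str, channel_name: str, accept_eula: bool
) -> List[Dict]:
    """Build the AdditionalModelDataSources entry shape the tests assert on."""
    s3_data_source = {
        "S3Uri": model_source,
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
    if accept_eula:
        s3_data_source["ModelAccessConfig"] = {"ACCEPT_EULA": True}
    return [{"ChannelName": channel_name, "S3DataSource": s3_data_source}]
```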