Skip to content

Commit 57123c9

Browse files
committed
Add unit tests
1 parent 22fdc37 commit 57123c9

File tree

2 files changed

+107
-31
lines changed

2 files changed

+107
-31
lines changed

src/sagemaker/serve/validations/optimization.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ class _OptimizationCombination(BaseModel):
4141

4242
def validate_against(self, optimization_combination, rule_set: _OptimizationContainer):
4343
"""Validator for optimization containers"""
44+
print(optimization_combination)
45+
print(rule_set)
46+
print(optimization_combination.speculative_decoding.issubset(self.speculative_decoding))
4447

4548
# check the case where no optimization combination is provided
4649
if (
@@ -49,7 +52,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
4952
and optimization_combination.speculative_decoding == {None}
5053
and optimization_combination.sharding == {None}
5154
):
52-
raise ValueError("Optimizations are not currently supported without optimization configurations.")
55+
raise ValueError("no optimization configurations")
5356

5457
# check the validity of each individual field
5558
if not optimization_combination.compilation.issubset(self.compilation):
@@ -58,9 +61,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
5861
self.quantization_technique
5962
):
6063
copy_quantization_technique = optimization_combination.quantization_technique.copy()
61-
raise ValueError(
62-
f"Quantization:{copy_quantization_technique.pop()}"
63-
)
64+
raise ValueError(f"Quantization:{copy_quantization_technique.pop()}")
6465
if not optimization_combination.speculative_decoding.issubset(self.speculative_decoding):
6566
raise ValueError("Speculative Decoding")
6667
if not optimization_combination.sharding.issubset(self.sharding):
@@ -75,16 +76,14 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
7576
copy_compilation = optimization_combination.compilation.copy()
7677
copy_speculative_decoding = optimization_combination.speculative_decoding.copy()
7778
if (
78-
copy_compilation.pop()
79-
and copy_speculative_decoding.pop()
79+
copy_compilation.pop() and copy_speculative_decoding.pop()
8080
): # Check that the 2 techniques are not None
8181
raise ValueError("Compilation and Speculative Decoding")
8282
else:
8383
copy_compilation = optimization_combination.compilation.copy()
8484
copy_quantization_technique = optimization_combination.quantization_technique.copy()
8585
if (
86-
copy_compilation.pop()
87-
and copy_quantization_technique.pop()
86+
copy_compilation.pop() and copy_quantization_technique.pop()
8887
): # Check that the 2 techniques are not None
8988
raise ValueError(
9089
f"Compilation and Quantization:{optimization_combination.quantization_technique.pop()}"
@@ -161,26 +160,24 @@ def _validate_optimization_configuration(
161160
and quantization_config.get("OverrideEnvironment")
162161
and quantization_config.get("OverrideEnvironment").get("OPTION_QUANTIZE")
163162
):
164-
quantization_technique = quantization_config.get("OverrideEnvironment").get("OPTION_QUANTIZE")
163+
quantization_technique = quantization_config.get("OverrideEnvironment").get(
164+
"OPTION_QUANTIZE"
165+
)
165166

166167
optimization_combination = _OptimizationCombination(
167-
compilation={
168-
None if compilation_config is None else bool(compilation_config)
169-
},
168+
compilation={None if compilation_config is None else bool(compilation_config)},
170169
speculative_decoding={
171170
None if speculative_decoding_config is None else bool(speculative_decoding_config)
172171
},
173-
sharding={
174-
None if sharding_config is None else bool(sharding_config)
175-
},
172+
sharding={None if sharding_config is None else bool(sharding_config)},
176173
quantization_technique={quantization_technique},
177174
)
178175

179-
if instance_type in NEURON_CONFIGURATION["supported_instance_families"]:
176+
if instance_family in NEURON_CONFIGURATION["supported_instance_families"]:
180177
try:
181178
(
182179
NEURON_CONFIGURATION["optimization_combination"].validate_against(
183-
optimization_combination, rule_set=_OptimizationContainer.VLLM
180+
optimization_combination, rule_set=_OptimizationContainer.NEURON
184181
)
185182
)
186183
except ValueError as neuron_compare_error:
@@ -209,7 +206,7 @@ def _validate_optimization_configuration(
209206
trt_error_msg = VALIDATION_ERROR_MSG.format(
210207
optimization_container=_OptimizationContainer.TRT.value,
211208
optimization_technique=str(trt_compare_error),
212-
instance_type="GPU"
209+
instance_type="GPU",
213210
)
214211
vllm_error_msg = VALIDATION_ERROR_MSG.format(
215212
optimization_container=_OptimizationContainer.VLLM.value,

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 92 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,11 +2386,11 @@ def test_optimize(
23862386
builder.pysdk_model = pysdk_model
23872387

23882388
job_name = "my-optimization-job"
2389-
instance_type = "ml.inf2.xlarge"
2389+
instance_type = "ml.g5.24xlarge"
23902390
output_path = "s3://my-bucket/output"
23912391
quantization_config = {
23922392
"Image": "quantization-image-uri",
2393-
"OverrideEnvironment": {"ENV_VAR": "value"},
2393+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
23942394
}
23952395
env_vars = {"Var1": "value", "Var2": "value"}
23962396
kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id"
@@ -2428,15 +2428,15 @@ def test_optimize(
24282428
mock_send_telemetry.assert_called_once()
24292429
mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with(
24302430
OptimizationJobName="my-optimization-job",
2431-
DeploymentInstanceType="ml.inf2.xlarge",
2431+
DeploymentInstanceType="ml.g5.24xlarge",
24322432
RoleArn="arn:aws:iam::123456789012:role/SageMakerRole",
24332433
OptimizationEnvironment={"Var1": "value", "Var2": "value"},
24342434
ModelSource={"S3": {"S3Uri": "s3://uri"}},
24352435
OptimizationConfigs=[
24362436
{
24372437
"ModelQuantizationConfig": {
24382438
"Image": "quantization-image-uri",
2439-
"OverrideEnvironment": {"ENV_VAR": "value"},
2439+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
24402440
}
24412441
}
24422442
],
@@ -2650,7 +2650,7 @@ def test_optimize_local_mode(self, mock_get_serve_setting):
26502650
"Model optimization is only supported in Sagemaker Endpoint Mode.",
26512651
lambda: model_builder.optimize(
26522652
instance_type="ml.g5.24xlarge",
2653-
quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}
2653+
quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
26542654
),
26552655
)
26562656

@@ -2842,16 +2842,22 @@ def test_corner_cases_throw_errors(self):
28422842
ValueError,
28432843
"Optimizations that uses None instance type are not currently supported",
28442844
lambda: _validate_optimization_configuration(
2845-
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2845+
sharding_config={"key": "value"},
28462846
instance_type=None,
28472847
quantization_config=None,
28482848
speculative_decoding_config=None,
28492849
compilation_config=None,
28502850
),
28512851
)
2852+
2853+
expected_missing_optimization_configs_error_message = """
2854+
Optimization cannot be performed for the following reasons:
2855+
- Optimizations for TRT that use no optimization configurations are not currently supported on GPU instances
2856+
- Optimizations for vLLM that use no optimization configurations are not currently supported on GPU instances
2857+
"""
28522858
self.assertRaisesRegex(
28532859
ValueError,
2854-
"Optimizations are not currently supported without optimization configurations.",
2860+
textwrap.dedent(expected_missing_optimization_configs_error_message),
28552861
lambda: _validate_optimization_configuration(
28562862
instance_type="ml.g5.24xlarge",
28572863
quantization_config=None,
@@ -2881,11 +2887,39 @@ def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self):
28812887
),
28822888
)
28832889

2884-
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2885-
def test_neuron_configurations_throw_errors_for_rule_set(self, mock_get_serve_setting):
2886-
pass
2890+
def test_neuron_configurations_throw_errors_for_rule_set(self):
2891+
self.assertRaisesRegex(
2892+
ValueError,
2893+
(
2894+
"Optimizations for Neuron that use Speculative Decoding "
2895+
"are not currently supported on Neuron instances"
2896+
),
2897+
lambda: _validate_optimization_configuration(
2898+
instance_type="ml.inf2.xlarge",
2899+
quantization_config=None,
2900+
speculative_decoding_config={"key": "value"},
2901+
compilation_config=None,
2902+
sharding_config=None,
2903+
),
2904+
)
2905+
2906+
self.assertRaisesRegex(
2907+
ValueError,
2908+
(
2909+
"Optimizations for Neuron that use Sharding "
2910+
"are not currently supported on Neuron instances"
2911+
),
2912+
lambda: _validate_optimization_configuration(
2913+
instance_type="ml.inf2.xlarge",
2914+
quantization_config=None,
2915+
speculative_decoding_config=None,
2916+
compilation_config=None,
2917+
sharding_config={"key": "value"},
2918+
),
2919+
)
28872920

28882921
def test_trt_configurations_rule_set(self):
2922+
# Can be quantized
28892923
_validate_optimization_configuration(
28902924
instance_type="ml.g5.24xlarge",
28912925
quantization_config={
@@ -2896,6 +2930,51 @@ def test_trt_configurations_rule_set(self):
28962930
compilation_config=None,
28972931
)
28982932

2899-
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2900-
def test_vllm_configurations_rule_set(self, mock_get_serve_setting):
2901-
pass
2933+
# Can be compiled
2934+
_validate_optimization_configuration(
2935+
instance_type="ml.g5.24xlarge",
2936+
quantization_config=None,
2937+
sharding_config=None,
2938+
speculative_decoding_config=None,
2939+
compilation_config={"key": "value"},
2940+
)
2941+
2942+
def test_vllm_configurations_rule_set(self):
2943+
# Can be quantized
2944+
_validate_optimization_configuration(
2945+
instance_type="ml.g5.24xlarge",
2946+
quantization_config={
2947+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
2948+
},
2949+
sharding_config=None,
2950+
speculative_decoding_config=None,
2951+
compilation_config=None,
2952+
)
2953+
2954+
# Can use speculative decoding
2955+
_validate_optimization_configuration(
2956+
instance_type="ml.g5.24xlarge",
2957+
quantization_config=None,
2958+
sharding_config=None,
2959+
speculative_decoding_config={"key": "value"},
2960+
compilation_config=None,
2961+
)
2962+
2963+
# Can be sharded
2964+
_validate_optimization_configuration(
2965+
instance_type="ml.g5.24xlarge",
2966+
quantization_config=None,
2967+
sharding_config={"key": "value"},
2968+
speculative_decoding_config=None,
2969+
compilation_config=None,
2970+
)
2971+
2972+
def test_neuron_configurations_rule_set(self):
2973+
# Can be compiled
2974+
_validate_optimization_configuration(
2975+
instance_type="ml.inf2.xlarge",
2976+
quantization_config=None,
2977+
sharding_config=None,
2978+
speculative_decoding_config=None,
2979+
compilation_config={"key": "value"},
2980+
)

0 commit comments

Comments (0)