Skip to content

Commit 76a4102

Browse files
committed
fix UTs
1 parent 955479a commit 76a4102

File tree

2 files changed

+45
-42
lines changed

2 files changed

+45
-42
lines changed

src/sagemaker/serve/validations/optimization.py

Lines changed: 13 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -57,26 +57,21 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
5757

5858
# optimization technique combinations that need to be validated
5959
if optimization_combination.compilation and optimization_combination.speculative_decoding:
60-
copy_compilation = optimization_combination.compilation.copy()
61-
copy_speculative_decoding = optimization_combination.speculative_decoding.copy()
62-
if (
63-
copy_compilation.pop() and copy_speculative_decoding.pop()
64-
): # Check that the 2 techniques are not None
60+
is_compiled = optimization_combination.compilation.copy().pop()
61+
is_speculative_decoding = optimization_combination.speculative_decoding.copy().pop()
62+
if is_compiled and is_speculative_decoding:
6563
raise ValueError("Compilation and Speculative Decoding together")
6664

6765
if rule_set == _OptimizationContainer.TRT:
68-
if (
69-
optimization_combination.compilation
70-
and not optimization_combination.quantization_technique
71-
or not optimization_combination.compilation
72-
and optimization_combination.quantization_technique
73-
):
66+
is_compiled = optimization_combination.compilation.copy().pop()
67+
is_quantized = optimization_combination.quantization_technique.copy().pop()
68+
if is_compiled and not is_quantized or is_quantized and not is_compiled:
7469
raise ValueError("Compilation must be provided with Quantization")
7570
else:
76-
copy_compilation = optimization_combination.compilation.copy()
77-
copy_quantization_technique = optimization_combination.quantization_technique.copy()
71+
is_compiled = optimization_combination.compilation.copy().pop()
72+
is_quantization_technique = optimization_combination.quantization_technique.copy().pop()
7873
if (
79-
copy_compilation.pop() and copy_quantization_technique.pop()
74+
is_compiled and is_quantization_technique
8075
): # Check that the 2 techniques are not None
8176
raise ValueError(
8277
f"Compilation and Quantization:{optimization_combination.quantization_technique.pop()}"
@@ -99,8 +94,8 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
9994
"supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
10095
"optimization_combination": _OptimizationCombination(
10196
optimization_container=_OptimizationContainer.VLLM,
102-
compilation=TRUTHY_SET,
103-
quantization_technique={None, "awq", "fp8"},
97+
compilation=FALSY_SET,
98+
quantization_technique={None},
10499
speculative_decoding=TRUTHY_SET,
105100
sharding=TRUTHY_SET,
106101
),
@@ -203,8 +198,9 @@ def _validate_optimization_configuration(
203198
optimization_combination, rule_set=_OptimizationContainer.VLLM
204199
)
205200
)
201+
print("fsdafas")
206202
except ValueError as vllm_compare_error:
207-
if trt_compare_error == "Compilation must be provided with Quantization":
203+
if str(trt_compare_error) == "Compilation must be provided with Quantization":
208204
joint_error_msg = f"""
209205
Optimization cannot be performed for the following reasons:
210206
- Optimizations that use {trt_compare_error} and vice-versa for GPU instances.

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 32 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -2912,37 +2912,44 @@ def test_neuron_configurations_throw_errors_for_rule_set(self):
29122912

29132913
def test_trt_configurations_rule_set(self):
29142914
# Can be quantized
2915-
_validate_optimization_configuration(
2916-
instance_type="ml.g5.24xlarge",
2917-
quantization_config={
2918-
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
2919-
},
2920-
sharding_config=None,
2921-
speculative_decoding_config=None,
2922-
compilation_config=None,
2915+
expected_compilation_quantization_error_message = """
2916+
Optimization cannot be performed for the following reasons:
2917+
- Optimizations that use Compilation must be provided with Quantization and vice-versa for GPU instances.
2918+
- Optimizations that use Quantization:awq are not supported for GPU instances.
2919+
"""
2920+
self.assertRaisesRegex(
2921+
ValueError,
2922+
textwrap.dedent(expected_compilation_quantization_error_message),
2923+
lambda: _validate_optimization_configuration(
2924+
instance_type="ml.g5.24xlarge",
2925+
quantization_config={
2926+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
2927+
},
2928+
sharding_config=None,
2929+
speculative_decoding_config=None,
2930+
compilation_config=None,
2931+
),
29232932
)
29242933

29252934
# Can be compiled
2926-
_validate_optimization_configuration(
2927-
instance_type="ml.g5.24xlarge",
2928-
quantization_config=None,
2929-
sharding_config=None,
2930-
speculative_decoding_config=None,
2931-
compilation_config={"key": "value"},
2935+
expected_compilation_quantization_error_message = """
2936+
Optimization cannot be performed for the following reasons:
2937+
- Optimizations that use Compilation must be provided with Quantization and vice-versa for GPU instances.
2938+
- Optimizations that use Compilation are not supported for GPU instances.
2939+
"""
2940+
self.assertRaisesRegex(
2941+
ValueError,
2942+
textwrap.dedent(expected_compilation_quantization_error_message),
2943+
lambda: _validate_optimization_configuration(
2944+
instance_type="ml.g5.24xlarge",
2945+
quantization_config=None,
2946+
sharding_config=None,
2947+
speculative_decoding_config=None,
2948+
compilation_config={"key": "value"},
2949+
),
29322950
)
29332951

29342952
def test_vllm_configurations_rule_set(self):
2935-
# Can be quantized
2936-
_validate_optimization_configuration(
2937-
instance_type="ml.g5.24xlarge",
2938-
quantization_config={
2939-
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
2940-
},
2941-
sharding_config=None,
2942-
speculative_decoding_config=None,
2943-
compilation_config=None,
2944-
)
2945-
29462953
# Can use speculative decoding
29472954
_validate_optimization_configuration(
29482955
instance_type="ml.g5.24xlarge",

0 commit comments

Comments
 (0)