Skip to content

Commit 06ed0e3

Browse files
committed
update ruleset
1 parent 062e29b commit 06ed0e3

File tree

2 files changed

+56
-50
lines changed

2 files changed

+56
-50
lines changed

src/sagemaker/serve/validations/optimization.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
6565
if rule_set == _OptimizationContainer.TRT:
6666
is_compiled = optimization_combination.compilation.copy().pop()
6767
is_quantized = optimization_combination.quantization_technique.copy().pop()
68-
if is_compiled and not is_quantized or is_quantized and not is_compiled:
69-
raise ValueError(f"Compilation must be provided with Quantization")
68+
if is_quantized and not is_compiled:
69+
raise ValueError(f"Quantization must be provided with Compilation")
7070

7171

7272
TRUTHY_SET = {None, True}
@@ -76,7 +76,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
7676
"optimization_combination": _OptimizationCombination(
7777
optimization_container=_OptimizationContainer.TRT,
7878
compilation=TRUTHY_SET,
79-
quantization_technique={None, "awq", "fp8", "smooth_quant"},
79+
quantization_technique={None, "awq", "fp8", "smoothquant"},
8080
speculative_decoding=FALSY_SET,
8181
sharding=FALSY_SET,
8282
),
@@ -189,18 +189,23 @@ def _validate_optimization_configuration(
189189
optimization_combination, rule_set=_OptimizationContainer.VLLM
190190
)
191191
)
192-
print("fsdafas")
193192
except ValueError as vllm_compare_error:
194-
if "Compilation must be provided with Quantization" in str(trt_compare_error):
193+
if "Quantization must be provided with Compilation" in str(trt_compare_error):
195194
joint_error_msg = f"""
196195
Optimization cannot be performed for the following reasons:
197-
- Optimizations that use {trt_compare_error} and vice-versa for GPU instances.
196+
- Optimizations that use {trt_compare_error} for GPU instances.
198197
- Optimizations that use {vllm_compare_error} are not supported for GPU instances.
199198
"""
200199
else:
201-
joint_error_msg = f"""
202-
Optimization cannot be performed for the following reasons:
203-
- Optimizations that use {trt_compare_error} are not supported for GPU instances.
204-
- Optimizations that use {vllm_compare_error} are not supported for GPU instances.
205-
"""
200+
if str(trt_compare_error) == str(vllm_compare_error):
201+
joint_error_msg = f"""
202+
Optimization cannot be performed for the following reasons:
203+
- Optimizations that use {trt_compare_error} are not supported for GPU instances.
204+
"""
205+
else:
206+
joint_error_msg = f"""
207+
Optimization cannot be performed for the following reasons:
208+
- Optimizations that use {trt_compare_error} are not supported for GPU instances.
209+
- Optimizations that use {vllm_compare_error} are not supported for GPU instances.
210+
"""
206211
raise ValueError(textwrap.dedent(joint_error_msg))

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2683,7 +2683,7 @@ def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
26832683
expected_error_message = """
26842684
Optimization cannot be performed for the following reasons:
26852685
- Optimizations that use Sharding are not supported for GPU instances.
2686-
- Optimizations that use Compilation and Quantization:awq are not supported for GPU instances.
2686+
- Optimizations that use Compilation are not supported for GPU instances.
26872687
"""
26882688

26892689
self.assertRaisesRegex(
@@ -2866,10 +2866,28 @@ def test_corner_cases_throw_errors(self):
28662866
)
28672867

28682868
def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self):
2869+
expected_compilation_quantization_error_message = """
2870+
Optimization cannot be performed for the following reasons:
2871+
- Optimizations that use Quantization must be provided with Compilation for GPU instances.
2872+
- Optimizations that use Quantization:smoothquant are not supported for GPU instances.
2873+
"""
2874+
self.assertRaisesRegex(
2875+
ValueError,
2876+
textwrap.dedent(expected_compilation_quantization_error_message),
2877+
lambda: _validate_optimization_configuration(
2878+
instance_type="ml.g5.24xlarge",
2879+
quantization_config={
2880+
"OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"},
2881+
},
2882+
sharding_config=None,
2883+
speculative_decoding_config=None,
2884+
compilation_config=None,
2885+
),
2886+
)
2887+
28692888
expected_quantization_error_message = """
28702889
Optimization cannot be performed for the following reasons:
28712890
- Optimizations that use Quantization:test are not supported for GPU instances.
2872-
- Optimizations that use Quantization:test are not supported for GPU instances.
28732891
"""
28742892
self.assertRaisesRegex(
28752893
ValueError,
@@ -2910,43 +2928,6 @@ def test_neuron_configurations_throw_errors_for_rule_set(self):
29102928
),
29112929
)
29122930

2913-
def test_trt_configurations_throw_errors_for_rule_se(self):
2914-
expected_compilation_quantization_error_message = """
2915-
Optimization cannot be performed for the following reasons:
2916-
- Optimizations that use Compilation must be provided with Quantization and vice-versa for GPU instances.
2917-
- Optimizations that use Quantization:awq are not supported for GPU instances.
2918-
"""
2919-
self.assertRaisesRegex(
2920-
ValueError,
2921-
textwrap.dedent(expected_compilation_quantization_error_message),
2922-
lambda: _validate_optimization_configuration(
2923-
instance_type="ml.g5.24xlarge",
2924-
quantization_config={
2925-
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
2926-
},
2927-
sharding_config=None,
2928-
speculative_decoding_config=None,
2929-
compilation_config=None,
2930-
),
2931-
)
2932-
2933-
expected_compilation_quantization_error_message = """
2934-
Optimization cannot be performed for the following reasons:
2935-
- Optimizations that use Compilation must be provided with Quantization and vice-versa for GPU instances.
2936-
- Optimizations that use Compilation are not supported for GPU instances.
2937-
"""
2938-
self.assertRaisesRegex(
2939-
ValueError,
2940-
textwrap.dedent(expected_compilation_quantization_error_message),
2941-
lambda: _validate_optimization_configuration(
2942-
instance_type="ml.g5.24xlarge",
2943-
quantization_config=None,
2944-
sharding_config=None,
2945-
speculative_decoding_config=None,
2946-
compilation_config={"key": "value"},
2947-
),
2948-
)
2949-
29502931
def test_trt_configurations_rule_set(self):
29512932
# Can be compiled with quantization
29522933
_validate_optimization_configuration(
@@ -2959,6 +2940,15 @@ def test_trt_configurations_rule_set(self):
29592940
compilation_config={"key": "value"},
29602941
),
29612942

2943+
# Can be just compiled
2944+
_validate_optimization_configuration(
2945+
instance_type="ml.g5.24xlarge",
2946+
quantization_config=None,
2947+
sharding_config=None,
2948+
speculative_decoding_config=None,
2949+
compilation_config={"key": "value"},
2950+
)
2951+
29622952
def test_vllm_configurations_rule_set(self):
29632953
# Can use speculative decoding
29642954
_validate_optimization_configuration(
@@ -2969,6 +2959,17 @@ def test_vllm_configurations_rule_set(self):
29692959
compilation_config=None,
29702960
)
29712961

2962+
# Can be quantized
2963+
_validate_optimization_configuration(
2964+
instance_type="ml.g5.24xlarge",
2965+
quantization_config={
2966+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
2967+
},
2968+
sharding_config=None,
2969+
speculative_decoding_config=None,
2970+
compilation_config=None,
2971+
)
2972+
29722973
# Can be sharded
29732974
_validate_optimization_configuration(
29742975
instance_type="ml.g5.24xlarge",

0 commit comments

Comments
 (0)