Skip to content

Commit abd3e92

Browse files
committed
update validation logic
1 parent c727efb commit abd3e92

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

src/sagemaker/serve/validations/optimization.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,36 +176,49 @@ def _validate_optimization_configuration(
176176
)
177177
)
178178
else:
179-
try:
180-
(
179+
if optimization_combination.compilation.copy().pop(): # Compilation is only enabled for TRT
180+
try:
181181
TRT_CONFIGURATION["optimization_combination"].validate_against(
182182
optimization_combination, rule_set=_OptimizationContainer.TRT
183183
)
184-
)
185-
except ValueError as trt_compare_error:
184+
except ValueError as trt_compare_error:
185+
raise ValueError(
186+
(
187+
f"Optimizations that use Compilation and {trt_compare_error} "
188+
"are not supported for GPU instances."
189+
)
190+
)
191+
else:
186192
try:
187193
(
188194
VLLM_CONFIGURATION["optimization_combination"].validate_against(
189195
optimization_combination, rule_set=_OptimizationContainer.VLLM
190196
)
191197
)
192198
except ValueError as vllm_compare_error:
193-
if "Quantization must be provided with Compilation" in str(trt_compare_error):
194-
joint_error_msg = f"""
199+
try: # try both VLLM and TRT to cover both rule sets
200+
(
201+
TRT_CONFIGURATION["optimization_combination"].validate_against(
202+
optimization_combination, rule_set=_OptimizationContainer.TRT
203+
)
204+
)
205+
except ValueError as trt_compare_error:
206+
if "Quantization must be provided with Compilation" in str(trt_compare_error):
207+
joint_error_msg = f"""
195208
Optimization cannot be performed for the following reasons:
196209
- Optimizations that use {trt_compare_error} for GPU instances.
197210
- Optimizations that use {vllm_compare_error} are not supported for GPU instances.
198211
"""
199-
else:
200-
if str(trt_compare_error) == str(vllm_compare_error):
201-
joint_error_msg = f"""
212+
else:
213+
if str(trt_compare_error) == str(vllm_compare_error):
214+
joint_error_msg = f"""
202215
Optimization cannot be performed for the following reasons:
203216
- Optimizations that use {trt_compare_error} are not supported for GPU instances.
204217
"""
205-
else:
206-
joint_error_msg = f"""
218+
else:
219+
joint_error_msg = f"""
207220
Optimization cannot be performed for the following reasons:
208221
- Optimizations that use {trt_compare_error} are not supported for GPU instances.
209222
- Optimizations that use {vllm_compare_error} are not supported for GPU instances.
210223
"""
211-
raise ValueError(textwrap.dedent(joint_error_msg))
224+
raise ValueError(textwrap.dedent(joint_error_msg))

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2680,15 +2680,9 @@ def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
26802680
sagemaker_session=mock_sagemaker_session,
26812681
)
26822682

2683-
expected_error_message = """
2684-
Optimization cannot be performed for the following reasons:
2685-
- Optimizations that use Sharding are not supported for GPU instances.
2686-
- Optimizations that use Compilation are not supported for GPU instances.
2687-
"""
2688-
26892683
self.assertRaisesRegex(
26902684
ValueError,
2691-
textwrap.dedent(expected_error_message),
2685+
"Optimizations that use Compilation and Sharding are not supported for GPU instances.",
26922686
lambda: model_builder.optimize(
26932687
instance_type="ml.g5.24xlarge",
26942688
quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},

0 commit comments

Comments
 (0)