Skip to content

Commit 062e29b

Browse files
committed
fix validations
1 parent b7b8d3c commit 062e29b

File tree

1 file changed

+8
-17
lines changed

1 file changed

+8
-17
lines changed

src/sagemaker/serve/validations/optimization.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import textwrap
1717
import logging
18-
from typing import Any, Dict, Set
18+
from typing import Any, Dict, Set, Optional
1919
from enum import Enum
2020
from pydantic import BaseModel
2121

@@ -34,10 +34,10 @@ class _OptimizationCombination(BaseModel):
3434
"""Optimization ruleset data structure for comparing input to ruleset"""
3535

3636
optimization_container: _OptimizationContainer = None
37-
compilation: Set[bool | None]
38-
speculative_decoding: Set[bool | None]
39-
sharding: Set[bool | None]
40-
quantization_technique: Set[str | None]
37+
compilation: Set[Optional[bool]]
38+
speculative_decoding: Set[Optional[bool]]
39+
sharding: Set[Optional[bool]]
40+
quantization_technique: Set[Optional[str]]
4141

4242
def validate_against(self, optimization_combination, rule_set: _OptimizationContainer):
4343
"""Validator for optimization containers"""
@@ -66,16 +66,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
6666
is_compiled = optimization_combination.compilation.copy().pop()
6767
is_quantized = optimization_combination.quantization_technique.copy().pop()
6868
if is_compiled and not is_quantized or is_quantized and not is_compiled:
69-
raise ValueError("Compilation must be provided with Quantization")
70-
else:
71-
is_compiled = optimization_combination.compilation.copy().pop()
72-
is_quantization_technique = optimization_combination.quantization_technique.copy().pop()
73-
if (
74-
is_compiled and is_quantization_technique
75-
): # Check that the 2 techniques are not None
76-
raise ValueError(
77-
f"Compilation and Quantization:{optimization_combination.quantization_technique.pop()}"
78-
)
69+
raise ValueError(f"Compilation must be provided with Quantization")
7970

8071

8172
TRUTHY_SET = {None, True}
@@ -95,7 +86,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
9586
"optimization_combination": _OptimizationCombination(
9687
optimization_container=_OptimizationContainer.VLLM,
9788
compilation=FALSY_SET,
98-
quantization_technique={None},
89+
quantization_technique={None, "awq", "fp8"},
9990
speculative_decoding=TRUTHY_SET,
10091
sharding=TRUTHY_SET,
10192
),
@@ -200,7 +191,7 @@ def _validate_optimization_configuration(
200191
)
201192
print("fsdafas")
202193
except ValueError as vllm_compare_error:
203-
if str(trt_compare_error) == "Compilation must be provided with Quantization":
194+
if "Compilation must be provided with Quantization" in str(trt_compare_error):
204195
joint_error_msg = f"""
205196
Optimization cannot be performed for the following reasons:
206197
- Optimizations that use {trt_compare_error} and vice-versa for GPU instances.

0 commit comments

Comments
 (0)