Skip to content

Commit 22fdc37

Browse files
committed
fixing validation bugs
1 parent 3d04384 commit 22fdc37

File tree

2 files changed

+161
-43
lines changed

2 files changed

+161
-43
lines changed

src/sagemaker/serve/validations/optimization.py

Lines changed: 79 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,85 +25,107 @@
2525
class _OptimizationContainer(Enum):
2626
"""Optimization containers"""
2727

28-
TRT = "trt"
29-
VLLM = "vllm"
30-
NEURON = "neuron"
28+
TRT = "TRT"
29+
VLLM = "vLLM"
30+
NEURON = "Neuron"
3131

3232

3333
class _OptimizationCombination(BaseModel):
3434
"""Optimization ruleset data structure for comparing input to ruleset"""
3535

3636
optimization_container: _OptimizationContainer = None
37-
compilation: bool
38-
speculative_decoding: bool
39-
sharding: bool
37+
compilation: Set[bool | None]
38+
speculative_decoding: Set[bool | None]
39+
sharding: Set[bool | None]
4040
quantization_technique: Set[str | None]
4141

4242
def validate_against(self, optimization_combination, rule_set: _OptimizationContainer):
4343
"""Validator for optimization containers"""
4444

45-
if not optimization_combination.compilation == self.compilation:
45+
# check the case where no optimization combination is provided
46+
if (
47+
optimization_combination.compilation == {None}
48+
and optimization_combination.quantization_technique == {None}
49+
and optimization_combination.speculative_decoding == {None}
50+
and optimization_combination.sharding == {None}
51+
):
52+
raise ValueError("Optimizations are not currently supported without optimization configurations.")
53+
54+
# check the validity of each individual field
55+
if not optimization_combination.compilation.issubset(self.compilation):
4656
raise ValueError("Compilation")
4757
if not optimization_combination.quantization_technique.issubset(
4858
self.quantization_technique
4959
):
60+
copy_quantization_technique = optimization_combination.quantization_technique.copy()
5061
raise ValueError(
51-
f"Quantization:{optimization_combination.quantization_technique.pop()}"
62+
f"Quantization:{copy_quantization_technique.pop()}"
5263
)
53-
if not optimization_combination.speculative_decoding == self.speculative_decoding:
64+
if not optimization_combination.speculative_decoding.issubset(self.speculative_decoding):
5465
raise ValueError("Speculative Decoding")
55-
if not optimization_combination.sharding == self.sharding:
66+
if not optimization_combination.sharding.issubset(self.sharding):
5667
raise ValueError("Sharding")
5768

58-
if rule_set == _OptimizationContainer == _OptimizationContainer.TRT:
69+
# optimization technique combinations that need to be validated
70+
if rule_set == _OptimizationContainer.TRT:
5971
if (
6072
optimization_combination.compilation
6173
and optimization_combination.speculative_decoding
6274
):
63-
raise ValueError("Compilation and Speculative Decoding")
75+
copy_compilation = optimization_combination.compilation.copy()
76+
copy_speculative_decoding = optimization_combination.speculative_decoding.copy()
77+
if (
78+
copy_compilation.pop()
79+
and copy_speculative_decoding.pop()
80+
): # Check that the 2 techniques are not None
81+
raise ValueError("Compilation and Speculative Decoding")
6482
else:
83+
copy_compilation = optimization_combination.compilation.copy()
84+
copy_quantization_technique = optimization_combination.quantization_technique.copy()
6585
if (
66-
optimization_combination.compilation
67-
and optimization_combination.quantization_technique
68-
):
86+
copy_compilation.pop()
87+
and copy_quantization_technique.pop()
88+
): # Check that the 2 techniques are not None
6989
raise ValueError(
7090
f"Compilation and Quantization:{optimization_combination.quantization_technique.pop()}"
7191
)
7292

7393

94+
TRUTHY_SET = {None, True}
95+
FALSY_SET = {None, False}
7496
TRT_CONFIGURATION = {
7597
"supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
7698
"optimization_combination": _OptimizationCombination(
7799
optimization_container=_OptimizationContainer.TRT,
78-
compilation=True,
79-
quantization_technique={"awq", "fp8", "smooth_quant"},
80-
speculative_decoding=False,
81-
sharding=False,
100+
compilation=TRUTHY_SET,
101+
quantization_technique={None, "awq", "fp8", "smooth_quant"},
102+
speculative_decoding=FALSY_SET,
103+
sharding=FALSY_SET,
82104
),
83105
}
84106
VLLM_CONFIGURATION = {
85107
"supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
86108
"optimization_combination": _OptimizationCombination(
87109
optimization_container=_OptimizationContainer.VLLM,
88-
compilation=False,
89-
quantization_technique={"awq", "fp8"},
90-
speculative_decoding=True,
91-
sharding=True,
110+
compilation=FALSY_SET,
111+
quantization_technique={None, "awq", "fp8"},
112+
speculative_decoding=TRUTHY_SET,
113+
sharding=TRUTHY_SET,
92114
),
93115
}
94116
NEURON_CONFIGURATION = {
95117
"supported_instance_families": {"inf2", "trn1", "trn1n"},
96118
"optimization_combination": _OptimizationCombination(
97119
optimization_container=_OptimizationContainer.NEURON,
98-
compilation=True,
99-
quantization_technique=set(),
100-
speculative_decoding=False,
101-
sharding=False,
120+
compilation=TRUTHY_SET,
121+
quantization_technique={None},
122+
speculative_decoding=FALSY_SET,
123+
sharding=FALSY_SET,
102124
),
103125
}
104126

105127
VALIDATION_ERROR_MSG = (
106-
"Optimizations that use {optimization_technique} "
128+
"Optimizations for {optimization_container} that use {optimization_technique} "
107129
"are not currently supported on {instance_type} instances"
108130
)
109131

@@ -117,28 +139,41 @@ def _validate_optimization_configuration(
117139
):
118140
"""Validate .optimize() input off of standard ruleset"""
119141

120-
split_instance_type = instance_type.split(".")
121142
instance_family = None
122-
if len(split_instance_type) == 3: # invalid instance type will be caught below
123-
instance_family = split_instance_type[1]
143+
if instance_type:
144+
split_instance_type = instance_type.split(".")
145+
if len(split_instance_type) == 3:
146+
instance_family = split_instance_type[1]
124147

125148
if (
126149
instance_family not in TRT_CONFIGURATION["supported_instance_families"]
127150
and instance_family not in VLLM_CONFIGURATION["supported_instance_families"]
128151
and instance_family not in NEURON_CONFIGURATION["supported_instance_families"]
129152
):
130153
invalid_instance_type_msg = (
131-
f"Optimizations that use {instance_type} are not currently supported"
154+
f"Optimizations that uses {instance_type} instance type are not currently supported"
132155
)
133156
raise ValueError(invalid_instance_type_msg)
134157

158+
quantization_technique = None
159+
if (
160+
quantization_config
161+
and quantization_config.get("OverrideEnvironment")
162+
and quantization_config.get("OverrideEnvironment").get("OPTION_QUANTIZE")
163+
):
164+
quantization_technique = quantization_config.get("OverrideEnvironment").get("OPTION_QUANTIZE")
165+
135166
optimization_combination = _OptimizationCombination(
136-
compilation=not compilation_config,
137-
speculative_decoding=not speculative_decoding_config,
138-
sharding=not sharding_config,
139-
quantization_technique={
140-
quantization_config.get("OPTION_QUANTIZE") if quantization_config else None
167+
compilation={
168+
None if compilation_config is None else bool(compilation_config)
169+
},
170+
speculative_decoding={
171+
None if speculative_decoding_config is None else bool(speculative_decoding_config)
172+
},
173+
sharding={
174+
None if sharding_config is None else bool(sharding_config)
141175
},
176+
quantization_technique={quantization_technique},
142177
)
143178

144179
if instance_type in NEURON_CONFIGURATION["supported_instance_families"]:
@@ -151,7 +186,8 @@ def _validate_optimization_configuration(
151186
except ValueError as neuron_compare_error:
152187
raise ValueError(
153188
VALIDATION_ERROR_MSG.format(
154-
optimization_container=str(neuron_compare_error),
189+
optimization_container=_OptimizationContainer.NEURON.value,
190+
optimization_technique=str(neuron_compare_error),
155191
instance_type="Neuron",
156192
)
157193
)
@@ -171,10 +207,13 @@ def _validate_optimization_configuration(
171207
)
172208
except ValueError as vllm_compare_error:
173209
trt_error_msg = VALIDATION_ERROR_MSG.format(
174-
optimization_container=str(trt_compare_error), instance_type="GPU"
210+
optimization_container=_OptimizationContainer.TRT.value,
211+
optimization_technique=str(trt_compare_error),
212+
instance_type="GPU"
175213
)
176214
vllm_error_msg = VALIDATION_ERROR_MSG.format(
177-
optimization_container=str(vllm_compare_error),
215+
optimization_container=_OptimizationContainer.VLLM.value,
216+
optimization_technique=str(vllm_compare_error),
178217
instance_type="GPU",
179218
)
180219
joint_error_msg = f"""

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
15+
import textwrap
1416
from unittest.mock import MagicMock, patch, Mock, mock_open
1517

1618
import unittest
@@ -25,6 +27,7 @@
2527
from sagemaker.serve.utils.exceptions import TaskNotFoundException
2628
from sagemaker.serve.utils.predictors import TensorflowServingLocalPredictor
2729
from sagemaker.serve.utils.types import ModelServer
30+
from sagemaker.serve.validations.optimization import _validate_optimization_configuration
2831
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
2932

3033
schema_builder = MagicMock()
@@ -2383,7 +2386,7 @@ def test_optimize(
23832386
builder.pysdk_model = pysdk_model
23842387

23852388
job_name = "my-optimization-job"
2386-
instance_type = "ml.inf1.xlarge"
2389+
instance_type = "ml.inf2.xlarge"
23872390
output_path = "s3://my-bucket/output"
23882391
quantization_config = {
23892392
"Image": "quantization-image-uri",
@@ -2425,7 +2428,7 @@ def test_optimize(
24252428
mock_send_telemetry.assert_called_once()
24262429
mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with(
24272430
OptimizationJobName="my-optimization-job",
2428-
DeploymentInstanceType="ml.inf1.xlarge",
2431+
DeploymentInstanceType="ml.inf2.xlarge",
24292432
RoleArn="arn:aws:iam::123456789012:role/SageMakerRole",
24302433
OptimizationEnvironment={"Var1": "value", "Var2": "value"},
24312434
ModelSource={"S3": {"S3Uri": "s3://uri"}},
@@ -2646,6 +2649,7 @@ def test_optimize_local_mode(self, mock_get_serve_setting):
26462649
ValueError,
26472650
"Model optimization is only supported in Sagemaker Endpoint Mode.",
26482651
lambda: model_builder.optimize(
2652+
instance_type="ml.g5.24xlarge",
26492653
quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}
26502654
),
26512655
)
@@ -2662,6 +2666,7 @@ def test_optimize_exclusive_args(self, mock_get_serve_setting):
26622666
ValueError,
26632667
"Quantization config and compilation config are mutually exclusive.",
26642668
lambda: model_builder.optimize(
2669+
instance_type="ml.g5.24xlarge",
26652670
quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
26662671
compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
26672672
),
@@ -2675,10 +2680,17 @@ def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
26752680
sagemaker_session=mock_sagemaker_session,
26762681
)
26772682

2683+
expected_error_message = """
2684+
Optimization cannot be performed for the following reasons:
2685+
- Optimizations for TRT that use Sharding are not currently supported on GPU instances
2686+
- Optimizations for vLLM that use Compilation are not currently supported on GPU instances
2687+
"""
2688+
26782689
self.assertRaisesRegex(
26792690
ValueError,
2680-
"Sharding config is mutually exclusive and cannot be combined with any other optimization.",
2691+
textwrap.dedent(expected_error_message),
26812692
lambda: model_builder.optimize(
2693+
instance_type="ml.g5.24xlarge",
26822694
quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
26832695
compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
26842696
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
@@ -2697,6 +2709,7 @@ def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
26972709
ValueError,
26982710
"OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.",
26992711
lambda: model_builder.optimize(
2712+
instance_type="ml.g5.24xlarge",
27002713
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
27012714
),
27022715
)
@@ -2820,3 +2833,69 @@ def test_optimize_for_hf_without_custom_s3_path(
28202833
"OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
28212834
},
28222835
)
2836+
2837+
2838+
class TestModelBuilderOptimizeValidations(unittest.TestCase):
2839+
2840+
def test_corner_cases_throw_errors(self):
2841+
self.assertRaisesRegex(
2842+
ValueError,
2843+
"Optimizations that uses None instance type are not currently supported",
2844+
lambda: _validate_optimization_configuration(
2845+
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2846+
instance_type=None,
2847+
quantization_config=None,
2848+
speculative_decoding_config=None,
2849+
compilation_config=None,
2850+
),
2851+
)
2852+
self.assertRaisesRegex(
2853+
ValueError,
2854+
"Optimizations are not currently supported without optimization configurations.",
2855+
lambda: _validate_optimization_configuration(
2856+
instance_type="ml.g5.24xlarge",
2857+
quantization_config=None,
2858+
speculative_decoding_config=None,
2859+
compilation_config=None,
2860+
sharding_config=None,
2861+
),
2862+
)
2863+
2864+
def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self):
2865+
expected_quantization_error_message = """
2866+
Optimization cannot be performed for the following reasons:
2867+
- Optimizations for TRT that use Quantization:test are not currently supported on GPU instances
2868+
- Optimizations for vLLM that use Quantization:test are not currently supported on GPU instances
2869+
"""
2870+
self.assertRaisesRegex(
2871+
ValueError,
2872+
textwrap.dedent(expected_quantization_error_message),
2873+
lambda: _validate_optimization_configuration(
2874+
instance_type="ml.g5.24xlarge",
2875+
quantization_config={
2876+
"OverrideEnvironment": {"OPTION_QUANTIZE": "test"},
2877+
},
2878+
sharding_config=None,
2879+
speculative_decoding_config=None,
2880+
compilation_config=None,
2881+
),
2882+
)
2883+
2884+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2885+
def test_neuron_configurations_throw_errors_for_rule_set(self, mock_get_serve_setting):
2886+
pass
2887+
2888+
def test_trt_configurations_rule_set(self):
2889+
_validate_optimization_configuration(
2890+
instance_type="ml.g5.24xlarge",
2891+
quantization_config={
2892+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
2893+
},
2894+
sharding_config=None,
2895+
speculative_decoding_config=None,
2896+
compilation_config=None,
2897+
)
2898+
2899+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2900+
def test_vllm_configurations_rule_set(self, mock_get_serve_setting):
2901+
pass

0 commit comments

Comments
 (0)