Skip to content

Commit 54acaf7

Browse files
committed
add optimization validations
1 parent 5833143 commit 54acaf7

File tree

2 files changed

+172
-0
lines changed

2 files changed

+172
-0
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
get_huggingface_model_metadata,
105105
download_huggingface_model_metadata,
106106
)
107+
from sagemaker.serve.validations.optimization import validate_optimization_configuration
107108

108109
logger = logging.getLogger(__name__)
109110

@@ -1160,6 +1161,15 @@ def optimize(
11601161
Model: A deployable ``Model`` object.
11611162
"""
11621163

1164+
# TODO: ideally these dictionaries need to be sagemaker_core shapes
1165+
validate_optimization_configuration(
1166+
instance_type=instance_type,
1167+
quantization_config=quantization_config,
1168+
compilation_config=compilation_config,
1169+
sharding_config=sharding_config,
1170+
speculative_decoding_config=speculative_decoding_config,
1171+
)
1172+
11631173
# need to get telemetry_opt_out info before telemetry decorator is called
11641174
self.serve_settings = self._get_serve_setting()
11651175

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds the validation logic used for the .optimize() function"""
14+
from typing import Any, Dict, Set
15+
from enum import Enum
16+
from pydantic import BaseModel
17+
import textwrap
18+
import logging
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class OptimizationContainer(Enum):
    """Optimization container families that ``.optimize()`` can target."""

    TRT = "trt"  # TensorRT-LLM container
    VLLM = "vllm"  # vLLM container
    NEURON = "neuron"  # AWS Neuron container (inf2/trn1 instance families)
27+
28+
29+
class OptimizationCombination(BaseModel):
    """Feature-support matrix for one optimization container.

    The boolean fields flag whether the container supports that feature;
    ``quantization_technique`` lists the quantization methods it supports.
    The same class is also used for the *requested* combination, where each
    flag means "the caller asked for this feature".
    """

    # Optional: the requested combination is built without a container.
    optimization_container: Optional[OptimizationContainer] = None
    compilation: bool
    speculative_decoding: bool
    sharding: bool
    # Optional[str] (not `str | None`) so the annotation works on Python < 3.10.
    quantization_technique: Set[Optional[str]]

    def validate_against(self, optimization_combination, rule_set: OptimizationContainer):
        """Raise ``ValueError`` if a requested feature is unsupported here.

        Args:
            optimization_combination (OptimizationCombination): The caller's
                requested combination (flags are True when requested).
            rule_set (OptimizationContainer): The container these rules describe,
                used for the compilation+quantization special case.

        Raises:
            ValueError: When a requested feature is not supported by this rule
                set, or when compilation and quantization are combined outside
                the TRT container.
        """
        # A combination is invalid only when a feature is requested but not
        # supported — NOT when the flags merely differ (a supported-but-unused
        # feature is fine).
        if optimization_combination.compilation and not self.compilation:
            raise ValueError("model compilation is not supported")
        # Drop the "no quantization requested" sentinel before the subset test;
        # otherwise {None} can never be a subset of any rule set's techniques.
        requested_techniques = optimization_combination.quantization_technique - {None}
        if not requested_techniques.issubset(self.quantization_technique):
            raise ValueError("model quantization is not supported")
        if optimization_combination.speculative_decoding and not self.speculative_decoding:
            raise ValueError("speculative decoding is not supported")
        if optimization_combination.sharding and not self.sharding:
            raise ValueError("model sharding is not supported")

        # Compiling an already-quantized model is only valid on the TRT container.
        if optimization_combination.compilation and requested_techniques:
            if rule_set != OptimizationContainer.TRT:
                raise ValueError(
                    "model compilation and model quantization provided together is not supported"
                )
49+
50+
51+
# Per-container rule sets: the instance families each optimization container
# runs on, plus the feature-support matrix for that container.
TRT_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": OptimizationCombination(
        optimization_container=OptimizationContainer.TRT,
        compilation=True,
        quantization_technique={"awq", "fp8", "smooth_quant"},
        speculative_decoding=False,
        sharding=False,
    ),
}
VLLM_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": OptimizationCombination(
        optimization_container=OptimizationContainer.VLLM,
        compilation=False,
        quantization_technique={"awq", "fp8"},
        speculative_decoding=True,
        sharding=True,
    ),
}
NEURON_CONFIGURATION = {
    "supported_instance_families": {"inf2", "trn1", "trn1n"},
    "optimization_combination": OptimizationCombination(
        optimization_container=OptimizationContainer.NEURON,
        compilation=True,
        quantization_technique=set(),
        speculative_decoding=False,
        sharding=False,
    ),
}

# Template for per-container validation failures; filled in by
# validate_optimization_configuration.
VALIDATION_ERROR_MSG = (
    "The model cannot be optimized with the provided configurations on "
    "{optimization_container} supported {instance_type} because {validation_error}."
)
86+
87+
88+
def validate_optimization_configuration(
    instance_type: str,
    quantization_config: Optional[Dict[str, Any]],
    compilation_config: Optional[Dict[str, Any]],
    sharding_config: Optional[Dict[str, Any]],
    speculative_decoding_config: Optional[Dict[str, Any]],
):
    """Validate that the requested optimizations can run on ``instance_type``.

    Args:
        instance_type (str): Instance type, e.g. ``"ml.g5.12xlarge"``.
        quantization_config (Optional[Dict[str, Any]]): Quantization config, or None.
        compilation_config (Optional[Dict[str, Any]]): Compilation config, or None.
        sharding_config (Optional[Dict[str, Any]]): Sharding config, or None.
        speculative_decoding_config (Optional[Dict[str, Any]]): Speculative
            decoding config, or None.

    Raises:
        ValueError: If the instance family is unsupported, or if the requested
            combination is not supported by any applicable container.
    """
    # Instance types look like "ml.g5.12xlarge"; the middle token is the family.
    split_instance_type = instance_type.split(".")
    instance_family = None
    if len(split_instance_type) == 3:  # invalid instance type will be caught below
        instance_family = split_instance_type[1]

    if (
        instance_family not in TRT_CONFIGURATION["supported_instance_families"]
        and instance_family not in VLLM_CONFIGURATION["supported_instance_families"]
        and instance_family not in NEURON_CONFIGURATION["supported_instance_families"]
    ):
        # Use .value so the message shows "trt"/"vllm"/"neuron" rather than the
        # enum repr "OptimizationContainer.TRT".
        invalid_instance_type_msg = f"""
        The model cannot be optimized on {instance_type}. Please optimize on the following instance type families:
        - For {OptimizationContainer.TRT.value} optimized container: {TRT_CONFIGURATION["supported_instance_families"]}
        - For {OptimizationContainer.VLLM.value} optimized container: {VLLM_CONFIGURATION["supported_instance_families"]}
        - For {OptimizationContainer.NEURON.value} optimized container: {NEURON_CONFIGURATION["supported_instance_families"]}
        """
        raise ValueError(textwrap.dedent(invalid_instance_type_msg))

    # Flags are True exactly when the caller requested that feature. (The
    # original negated these with `not config`, inverting every check.)
    optimization_combination = OptimizationCombination(
        compilation=bool(compilation_config),
        speculative_decoding=bool(speculative_decoding_config),
        sharding=bool(sharding_config),
        # NOTE(review): assumes the technique is stored under "OPTION_QUANTIZE"
        # at the top level of quantization_config — confirm against callers.
        quantization_technique=(
            {quantization_config.get("OPTION_QUANTIZE")} if quantization_config else set()
        ),
    )

    # Compare the FAMILY (e.g. "inf2"), not the full instance type
    # ("ml.inf2.xlarge"), which can never be a member of the family set.
    if instance_family in NEURON_CONFIGURATION["supported_instance_families"]:
        try:
            NEURON_CONFIGURATION["optimization_combination"].validate_against(
                optimization_combination, rule_set=OptimizationContainer.NEURON
            )
        except ValueError as neuron_compare_error:
            raise ValueError(
                VALIDATION_ERROR_MSG.format(
                    optimization_container=OptimizationContainer.NEURON.value,
                    instance_type=instance_type,
                    validation_error=neuron_compare_error,
                )
            )
    else:
        # GPU instances: accept the combination if EITHER the TRT or the VLLM
        # container supports it; only fail when both reject it.
        try:
            TRT_CONFIGURATION["optimization_combination"].validate_against(
                optimization_combination, rule_set=OptimizationContainer.TRT
            )
        except ValueError as trt_compare_error:
            try:
                VLLM_CONFIGURATION["optimization_combination"].validate_against(
                    optimization_combination, rule_set=OptimizationContainer.VLLM
                )
            except ValueError as vllm_compare_error:
                trt_error_msg = VALIDATION_ERROR_MSG.format(
                    optimization_container=OptimizationContainer.TRT.value,
                    instance_type=instance_type,
                    validation_error=trt_compare_error,
                )
                vllm_error_msg = VALIDATION_ERROR_MSG.format(
                    optimization_container=OptimizationContainer.VLLM.value,
                    instance_type=instance_type,
                    validation_error=vllm_compare_error,
                )
                joint_error_msg = f"""
                The model cannot be optimized for the following reasons:
                - {trt_error_msg}
                - {vllm_error_msg}
                """
                raise ValueError(textwrap.dedent(joint_error_msg))

0 commit comments

Comments
 (0)