Skip to content

Commit bf55587

Browse files
committed
add optimization validations
1 parent 5360c94 commit bf55587

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
get_huggingface_model_metadata,
105105
download_huggingface_model_metadata,
106106
)
107+
from sagemaker.serve.validations.optimization import validate_optimization_configuration
107108

108109
logger = logging.getLogger(__name__)
109110

@@ -1160,6 +1161,15 @@ def optimize(
11601161
Model: A deployable ``Model`` object.
11611162
"""
11621163

1164+
# TODO: ideally these dictionaries need to be sagemaker_core shapes
1165+
validate_optimization_configuration(
1166+
instance_type=instance_type,
1167+
quantization_config=quantization_config,
1168+
compilation_config=compilation_config,
1169+
sharding_config=sharding_config,
1170+
speculative_decoding_config=speculative_decoding_config,
1171+
)
1172+
11631173
# need to get telemetry_opt_out info before telemetry decorator is called
11641174
self.serve_settings = self._get_serve_setting()
11651175

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds the validation logic used for the .optimize() function"""
14+
import logging
import textwrap
from enum import Enum
from typing import Any, Dict, Optional, Set

from pydantic import BaseModel
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class OptimizationContainer(Enum):
    """Deployment containers that a model-optimization job can target."""

    TRT = "trt"
    VLLM = "vllm"
    NEURON = "neuron"
27+
28+
29+
class OptimizationCombination(BaseModel):
    """A bundle of optimization options checked against a container rule set.

    Capability instances (the TRT/VLLM/NEURON configurations) describe what a
    container supports; request instances are built from the user's configs in
    ``validate_optimization_configuration`` and compared in
    ``validate_against``.
    """

    # NOTE: ``Set[str | None]`` was evaluated at class creation and raises a
    # TypeError on Python < 3.10; ``Optional`` keeps older runtimes working.
    optimization_container: Optional[OptimizationContainer] = None
    compilation: bool
    speculative_decoding: bool
    sharding: bool
    quantization_technique: Set[Optional[str]]

    def validate_against(self, optimization_combination, rule_set: OptimizationContainer):
        """Raise ``ValueError`` when the requested combination exceeds this rule set.

        A request is valid when every option it actually asks for is supported
        here; options the request leaves off are always allowed.

        Args:
            optimization_combination (OptimizationCombination): The requested
                combination of optimization options.
            rule_set (OptimizationContainer): The container whose cross-option
                rules apply.

        Raises:
            ValueError: Naming the first unsupported option or combination.
        """
        # BUG FIX: the original required exact equality with the capability
        # flags, so e.g. NOT requesting compilation on a compilation-capable
        # container was rejected. Only requested-but-unsupported options fail.
        if optimization_combination.compilation and not self.compilation:
            raise ValueError("model compilation is not supported")
        # ``None`` marks "no technique requested" and is always acceptable.
        requested_techniques = optimization_combination.quantization_technique - {None}
        if not requested_techniques.issubset(self.quantization_technique):
            raise ValueError("model quantization is not supported")
        if optimization_combination.speculative_decoding and not self.speculative_decoding:
            raise ValueError("speculative decoding is not supported")
        if optimization_combination.sharding and not self.sharding:
            raise ValueError("model sharding is not supported")

        # BUG FIX: the original chained comparison
        # ``rule_set == OptimizationContainer == OptimizationContainer.TRT``
        # evaluated ``OptimizationContainer == OptimizationContainer.TRT``
        # (class vs. member, always False), so the TRT rule never ran.
        if rule_set == OptimizationContainer.TRT:
            if optimization_combination.compilation and optimization_combination.speculative_decoding:
                raise ValueError(
                    "model compilation and speculative decoding provided together is not supported"
                )
        else:
            if optimization_combination.compilation and requested_techniques:
                raise ValueError(
                    "model compilation and model quantization provided together is not supported"
                )
52+
53+
54+
# Capability matrices: which instance-type families and which combination of
# optimization options each container accepts.  The ``optimization_combination``
# entries are matched field-by-field by ``OptimizationCombination.validate_against``.
TRT_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": OptimizationCombination(
        optimization_container=OptimizationContainer.TRT,
        compilation=True,
        speculative_decoding=False,
        sharding=False,
        quantization_technique={"awq", "fp8", "smooth_quant"},
    ),
}
VLLM_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": OptimizationCombination(
        optimization_container=OptimizationContainer.VLLM,
        compilation=False,
        speculative_decoding=True,
        sharding=True,
        quantization_technique={"awq", "fp8"},
    ),
}
NEURON_CONFIGURATION = {
    "supported_instance_families": {"inf2", "trn1", "trn1n"},
    "optimization_combination": OptimizationCombination(
        optimization_container=OptimizationContainer.NEURON,
        compilation=True,
        speculative_decoding=False,
        sharding=False,
        quantization_technique=set(),
    ),
}

# Template for single-container validation failures.
VALIDATION_ERROR_MSG = (
    "The model cannot be optimized with the provided configurations on "
    "{optimization_container} supported {instance_type} because {validation_error}."
)
89+
90+
91+
def validate_optimization_configuration(
    instance_type: str,
    quantization_config: Dict[str, Any],
    compilation_config: Dict[str, Any],
    sharding_config: Dict[str, Any],
    speculative_decoding_config: Dict[str, Any],
):
    """Validate an ``optimize()`` request against per-container capability rules.

    Args:
        instance_type (str): Target instance type, e.g. ``"ml.g5.xlarge"``.
        quantization_config (Dict[str, Any]): Quantization settings, or ``None``.
        compilation_config (Dict[str, Any]): Compilation settings, or ``None``.
        sharding_config (Dict[str, Any]): Sharding settings, or ``None``.
        speculative_decoding_config (Dict[str, Any]): Speculative decoding
            settings, or ``None``.

    Raises:
        ValueError: If the instance type is unsupported by every optimization
            container, or the requested option combination is invalid for
            every container applicable to the instance family.
    """
    # "ml.g5.xlarge" -> family "g5"; any other shape leaves the family None
    # and is rejected by the membership check below.
    split_instance_type = instance_type.split(".")
    instance_family = None
    if len(split_instance_type) == 3:
        instance_family = split_instance_type[1]

    if (
        instance_family not in TRT_CONFIGURATION["supported_instance_families"]
        and instance_family not in VLLM_CONFIGURATION["supported_instance_families"]
        and instance_family not in NEURON_CONFIGURATION["supported_instance_families"]
    ):
        invalid_instance_type_msg = f"""
        The model cannot be optimized on {instance_type}. Please optimize on the following instance type families:
        - For {OptimizationContainer.TRT} optimized container: {TRT_CONFIGURATION["supported_instance_families"]}
        - For {OptimizationContainer.VLLM} optimized container: {VLLM_CONFIGURATION["supported_instance_families"]}
        - For {OptimizationContainer.NEURON} optimized container: {NEURON_CONFIGURATION["supported_instance_families"]}
        """
        raise ValueError(textwrap.dedent(invalid_instance_type_msg))

    # BUG FIX: the flags were built with ``not <config>`` (inverted — a
    # provided option became False), and an absent quantization config
    # contributed ``{None}`` to the technique set, which could never satisfy
    # the capability subset check.  Flags are now true exactly when the
    # corresponding option is requested.
    quantization_technique = (
        quantization_config.get("OPTION_QUANTIZE") if quantization_config else None
    )
    optimization_combination = OptimizationCombination(
        compilation=bool(compilation_config),
        speculative_decoding=bool(speculative_decoding_config),
        sharding=bool(sharding_config),
        quantization_technique={quantization_technique} if quantization_technique else set(),
    )

    # BUG FIX: the Neuron branch compared the full instance type against the
    # set of instance *families* (never a member, so the branch was dead) and
    # then validated with the VLLM rule set instead of NEURON.
    if instance_family in NEURON_CONFIGURATION["supported_instance_families"]:
        try:
            NEURON_CONFIGURATION["optimization_combination"].validate_against(
                optimization_combination, rule_set=OptimizationContainer.NEURON
            )
        except ValueError as neuron_compare_error:
            raise ValueError(
                VALIDATION_ERROR_MSG.format(
                    optimization_container=OptimizationContainer.NEURON.value,
                    instance_type=instance_type,
                    validation_error=neuron_compare_error,
                )
            )
    else:
        # GPU families: accept the request if it fits either the TRT or the
        # VLLM container; report both failures when neither fits.
        try:
            TRT_CONFIGURATION["optimization_combination"].validate_against(
                optimization_combination, rule_set=OptimizationContainer.TRT
            )
        except ValueError as trt_compare_error:
            try:
                VLLM_CONFIGURATION["optimization_combination"].validate_against(
                    optimization_combination, rule_set=OptimizationContainer.VLLM
                )
            except ValueError as vllm_compare_error:
                trt_error_msg = VALIDATION_ERROR_MSG.format(
                    optimization_container=OptimizationContainer.TRT.value,
                    instance_type=instance_type,
                    validation_error=trt_compare_error,
                )
                vllm_error_msg = VALIDATION_ERROR_MSG.format(
                    optimization_container=OptimizationContainer.VLLM.value,
                    instance_type=instance_type,
                    validation_error=vllm_compare_error,
                )
                joint_error_msg = f"""
                The model cannot be optimized for the following reasons:
                - {trt_error_msg}
                - {vllm_error_msg}
                """
                raise ValueError(textwrap.dedent(joint_error_msg))

0 commit comments

Comments
 (0)