|
10 | 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific
|
12 | 12 | # language governing permissions and limitations under the License.
|
13 |
| -"""Holds the validation logic used for the .optimize() function""" |
| 13 | +"""Holds the validation logic used for the .optimize() function. INTERNAL only""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +import textwrap |
| 17 | +import logging |
14 | 18 | from typing import Any, Dict, Set
|
15 | 19 | from enum import Enum
|
16 | 20 | from pydantic import BaseModel
|
17 |
| -import textwrap |
18 |
| -import logging |
19 | 21 |
|
20 | 22 | logger = logging.getLogger(__name__)
|
21 | 23 |
|
22 | 24 |
|
23 |
| -class OptimizationContainer(Enum): |
| 25 | +class _OptimizationContainer(Enum): |
| 26 | + """Optimization containers""" |
| 27 | + |
24 | 28 | TRT = "trt"
|
25 | 29 | VLLM = "vllm"
|
26 | 30 | NEURON = "neuron"
|
27 | 31 |
|
28 | 32 |
|
29 |
class _OptimizationCombination(BaseModel):
    """Optimization ruleset data structure for comparing input to ruleset.

    An instance either describes a container family's supported combination
    (the ruleset) or the combination derived from the user's .optimize()
    arguments (the input being validated).
    """

    # Which container family this ruleset belongs to; None for the input side.
    optimization_container: _OptimizationContainer = None
    compilation: bool
    speculative_decoding: bool
    sharding: bool
    # NOTE: None acts as the "no quantization requested" sentinel — the
    # caller builds {None} when no quantization config is provided.
    quantization_technique: Set[str | None]

    def validate_against(self, optimization_combination, rule_set: _OptimizationContainer):
        """Validate an input combination against this ruleset.

        Args:
            optimization_combination (_OptimizationCombination): combination
                derived from the user's .optimize() input.
            rule_set (_OptimizationContainer): which container family's
                container-specific rules to apply.

        Raises:
            ValueError: short name of the first unsupported technique, to be
                interpolated into VALIDATION_ERROR_MSG by the caller.
        """
        if optimization_combination.compilation != self.compilation:
            raise ValueError("Compilation")

        # BUG FIX: the original called set.pop(), which both mutated the
        # caller's set and reported an arbitrary element. Compute the actual
        # unsupported techniques without mutation.
        unsupported = (
            optimization_combination.quantization_technique - self.quantization_technique
        )
        if unsupported:
            raise ValueError(f"Quantization:{next(iter(unsupported))}")

        if optimization_combination.speculative_decoding != self.speculative_decoding:
            raise ValueError("Speculative Decoding")
        if optimization_combination.sharding != self.sharding:
            raise ValueError("Sharding")

        # BUG FIX: the original chained comparison
        # ``rule_set == _OptimizationContainer == _OptimizationContainer.TRT``
        # compared the enum *class* to a member, which is always False, so the
        # TRT-specific rule below could never fire.
        if rule_set == _OptimizationContainer.TRT:
            if (
                optimization_combination.compilation
                and optimization_combination.speculative_decoding
            ):
                raise ValueError("Compilation and Speculative Decoding")
        else:
            # BUG FIX: the input set is never empty ({None} means "no
            # quantization"), so a plain truthiness test fired even when no
            # quantization was requested. Only real techniques count.
            requested_techniques = optimization_combination.quantization_technique - {None}
            if optimization_combination.compilation and requested_techniques:
                raise ValueError(
                    f"Compilation and Quantization:{next(iter(requested_techniques))}"
                )
52 | 72 |
|
53 | 73 |
|
# Per-container rulesets. "supported_instance_families" is matched against the
# middle token of an instance type (e.g. "g5" from "ml.g5.xlarge").
# NOTE: None is included in every quantization_technique set because the
# validator builds {None} as the input set when no quantization config is
# provided; without it the issubset check rejected every no-quantization
# request.
TRT_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.TRT,
        compilation=True,
        quantization_technique={None, "awq", "fp8", "smooth_quant"},
        speculative_decoding=False,
        sharding=False,
    ),
}
VLLM_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.VLLM,
        compilation=False,
        quantization_technique={None, "awq", "fp8"},
        speculative_decoding=True,
        sharding=True,
    ),
}
NEURON_CONFIGURATION = {
    "supported_instance_families": {"inf2", "trn1", "trn1n"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.NEURON,
        compilation=True,
        # Only the "no quantization" sentinel: Neuron supports no technique.
        quantization_technique={None},
        speculative_decoding=False,
        sharding=False,
    ),
}

VALIDATION_ERROR_MSG = (
    # BUG FIX: the placeholder was {optimization_technique}, but every
    # .format() call in this module supplies the value via the
    # optimization_container= keyword, which raised KeyError at format time.
    "Optimizations that use {optimization_container} "
    "are not currently supported on {instance_type} instances"
)
|
89 | 109 |
|
90 | 110 |
|
91 |
| -def validate_optimization_configuration( |
| 111 | +def _validate_optimization_configuration( |
92 | 112 | instance_type: str,
|
93 | 113 | quantization_config: Dict[str, Any],
|
94 | 114 | compilation_config: Dict[str, Any],
|
95 | 115 | sharding_config: Dict[str, Any],
|
96 |
| - speculative_decoding_config: Dict[str, Any] |
| 116 | + speculative_decoding_config: Dict[str, Any], |
97 | 117 | ):
|
| 118 | + """Validate .optimize() input off of standard ruleset""" |
| 119 | + |
98 | 120 | split_instance_type = instance_type.split(".")
|
99 | 121 | instance_family = None
|
100 | 122 | if len(split_instance_type) == 3: # invalid instance type will be caught below
|
101 | 123 | instance_family = split_instance_type[1]
|
102 | 124 |
|
103 | 125 | if (
|
104 |
| - not instance_family in TRT_CONFIGURATION["supported_instance_families"] and |
105 |
| - not instance_family in VLLM_CONFIGURATION["supported_instance_families"] and |
106 |
| - not instance_family in NEURON_CONFIGURATION["supported_instance_families"] |
| 126 | + instance_family not in TRT_CONFIGURATION["supported_instance_families"] |
| 127 | + and instance_family not in VLLM_CONFIGURATION["supported_instance_families"] |
| 128 | + and instance_family not in NEURON_CONFIGURATION["supported_instance_families"] |
107 | 129 | ):
|
108 |
| - invalid_instance_type_msg = f""" |
109 |
| - The model cannot be optimized on {instance_type}. Please optimize on the following instance type families: |
110 |
| - - For {OptimizationContainer.TRT} optimized container: {TRT_CONFIGURATION["supported_instance_families"]} |
111 |
| - - For {OptimizationContainer.VLLM} optimized container: {VLLM_CONFIGURATION["supported_instance_families"]} |
112 |
| - - For {OptimizationContainer.NEURON} optimized container: {NEURON_CONFIGURATION["supported_instance_families"]} |
113 |
| - """ |
114 |
| - raise ValueError(textwrap.dedent(invalid_instance_type_msg)) |
115 |
| - |
116 |
| - optimization_combination = OptimizationCombination( |
| 130 | + invalid_instance_type_msg = ( |
| 131 | + f"Optimizations that use {instance_type} are not currently supported" |
| 132 | + ) |
| 133 | + raise ValueError(invalid_instance_type_msg) |
| 134 | + |
| 135 | + optimization_combination = _OptimizationCombination( |
117 | 136 | compilation=not compilation_config,
|
118 | 137 | speculative_decoding=not speculative_decoding_config,
|
119 | 138 | sharding=not sharding_config,
|
120 |
| - quantization_technique={quantization_config.get("OPTION_QUANTIZE") if quantization_config else None} |
| 139 | + quantization_technique={ |
| 140 | + quantization_config.get("OPTION_QUANTIZE") if quantization_config else None |
| 141 | + }, |
121 | 142 | )
|
122 | 143 |
|
123 | 144 | if instance_type in NEURON_CONFIGURATION["supported_instance_families"]:
|
124 | 145 | try:
|
125 | 146 | (
|
126 |
| - NEURON_CONFIGURATION["optimization_combination"] |
127 |
| - .validate_against(optimization_combination, rule_set=OptimizationContainer.VLLM) |
| 147 | + NEURON_CONFIGURATION["optimization_combination"].validate_against( |
| 148 | + optimization_combination, rule_set=_OptimizationContainer.VLLM |
| 149 | + ) |
128 | 150 | )
|
129 | 151 | except ValueError as neuron_compare_error:
|
130 | 152 | raise ValueError(
|
131 | 153 | VALIDATION_ERROR_MSG.format(
|
132 |
| - optimization_container=OptimizationContainer.NEURON.value, |
133 |
| - instance_type=instance_type, |
134 |
| - validation_error=neuron_compare_error |
| 154 | + optimization_container=str(neuron_compare_error), |
| 155 | + instance_type="Neuron", |
135 | 156 | )
|
136 | 157 | )
|
137 | 158 | else:
|
138 | 159 | try:
|
139 | 160 | (
|
140 |
| - TRT_CONFIGURATION["optimization_combination"] |
141 |
| - .validate_against(optimization_combination, rule_set=OptimizationContainer.TRT) |
| 161 | + TRT_CONFIGURATION["optimization_combination"].validate_against( |
| 162 | + optimization_combination, rule_set=_OptimizationContainer.TRT |
| 163 | + ) |
142 | 164 | )
|
143 | 165 | except ValueError as trt_compare_error:
|
144 | 166 | try:
|
145 | 167 | (
|
146 |
| - VLLM_CONFIGURATION["optimization_combination"] |
147 |
| - .validate_against(optimization_combination, rule_set=OptimizationContainer.VLLM) |
| 168 | + VLLM_CONFIGURATION["optimization_combination"].validate_against( |
| 169 | + optimization_combination, rule_set=_OptimizationContainer.VLLM |
| 170 | + ) |
148 | 171 | )
|
149 | 172 | except ValueError as vllm_compare_error:
|
150 | 173 | trt_error_msg = VALIDATION_ERROR_MSG.format(
|
151 |
| - optimization_container=OptimizationContainer.TRT.value, |
152 |
| - instance_type=instance_type, |
153 |
| - validation_error=trt_compare_error |
| 174 | + optimization_container=str(trt_compare_error), instance_type="GPU" |
154 | 175 | )
|
155 | 176 | vllm_error_msg = VALIDATION_ERROR_MSG.format(
|
156 |
| - optimization_container=OptimizationContainer.VLLM.value, |
157 |
| - instance_type=instance_type, |
158 |
| - validation_error=vllm_compare_error |
| 177 | + optimization_container=str(vllm_compare_error), |
| 178 | + instance_type="GPU", |
159 | 179 | )
|
160 | 180 | joint_error_msg = f"""
|
161 |
| - The model cannot be optimized for the following reasons: |
| 181 | + Optimization cannot be performed for the following reasons: |
162 | 182 | - {trt_error_msg}
|
163 | 183 | - {vllm_error_msg}
|
164 | 184 | """
|
|
0 commit comments