|
10 | 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
11 | 11 | # ANY KIND, either express or implied. See the License for the specific |
12 | 12 | # language governing permissions and limitations under the License. |
13 | | -"""Holds the validation logic used for the .optimize() function""" |
| 13 | +"""Holds the validation logic used for the .optimize() function. INTERNAL only""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +import textwrap |
| 17 | +import logging |
14 | 18 | from typing import Any, Dict, Set |
15 | 19 | from enum import Enum |
16 | 20 | from pydantic import BaseModel |
17 | | -import textwrap |
18 | | -import logging |
19 | 21 |
|
20 | 22 | logger = logging.getLogger(__name__) |
21 | 23 |
|
22 | 24 |
|
23 | | -class OptimizationContainer(Enum): |
| 25 | +class _OptimizationContainer(Enum): |
| 26 | + """Optimization containers""" |
| 27 | + |
24 | 28 | TRT = "trt" |
25 | 29 | VLLM = "vllm" |
26 | 30 | NEURON = "neuron" |
27 | 31 |
|
28 | 32 |
|
class _OptimizationCombination(BaseModel):
    """Optimization ruleset data structure for comparing input to ruleset"""

    # None when this instance describes user input rather than a container
    # ruleset (the per-container constants always set it).
    optimization_container: _OptimizationContainer | None = None
    compilation: bool
    speculative_decoding: bool
    sharding: bool
    quantization_technique: Set[str | None]

    def validate_against(self, optimization_combination, rule_set: _OptimizationContainer):
        """Validate a requested optimization combination against this ruleset.

        Args:
            optimization_combination (_OptimizationCombination): combination
                derived from the user's .optimize() arguments.
            rule_set (_OptimizationContainer): which container's rules apply
                (TRT has a dedicated compilation + speculative-decoding rule).

        Raises:
            ValueError: short technique name of the unsupported feature;
                callers embed it into a user-facing message template.
        """
        if not optimization_combination.compilation == self.compilation:
            raise ValueError("Compilation")
        if not optimization_combination.quantization_technique.issubset(
            self.quantization_technique
        ):
            # next(iter(...)) reads an element without mutating the caller's
            # set; the original set.pop() removed it as a side effect.
            raise ValueError(
                f"Quantization:{next(iter(optimization_combination.quantization_technique))}"
            )
        if not optimization_combination.speculative_decoding == self.speculative_decoding:
            raise ValueError("Speculative Decoding")
        if not optimization_combination.sharding == self.sharding:
            raise ValueError("Sharding")

        # BUG FIX: the original condition
        # `rule_set == _OptimizationContainer == _OptimizationContainer.TRT`
        # is a chained comparison (`rule_set == _OptimizationContainer and
        # _OptimizationContainer == _OptimizationContainer.TRT`) that is always
        # False, because an Enum class never equals one of its members — so the
        # TRT-specific branch was unreachable and TRT always hit the else arm.
        if rule_set == _OptimizationContainer.TRT:
            if (
                optimization_combination.compilation
                and optimization_combination.speculative_decoding
            ):
                raise ValueError("Compilation and Speculative Decoding")
        else:
            if (
                optimization_combination.compilation
                and optimization_combination.quantization_technique
            ):
                raise ValueError(
                    f"Compilation and Quantization:"
                    f"{next(iter(optimization_combination.quantization_technique))}"
                )
52 | 72 |
|
53 | 73 |
|
# Per-container rulesets: the instance families each optimized container can
# target, plus the combination of optimization techniques it accepts.
TRT_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.TRT,
        compilation=True,
        speculative_decoding=False,
        sharding=False,
        quantization_technique={"awq", "fp8", "smooth_quant"},
    ),
}
VLLM_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.VLLM,
        compilation=False,
        speculative_decoding=True,
        sharding=True,
        quantization_technique={"awq", "fp8"},
    ),
}
NEURON_CONFIGURATION = {
    "supported_instance_families": {"inf2", "trn1", "trn1n"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.NEURON,
        compilation=True,
        speculative_decoding=False,
        sharding=False,
        # Neuron accepts no quantization techniques at all.
        quantization_technique=set(),
    ),
}
84 | 104 |
|
# User-facing template for an unsupported optimization combination;
# `optimization_container` carries the technique name raised by
# _OptimizationCombination.validate_against (e.g. "Sharding").
# BUG FIX: the placeholder was `{optimization_technique}`, but every
# `.format(...)` call site supplies `optimization_container=`, so each format
# call raised KeyError instead of producing the message.
VALIDATION_ERROR_MSG = (
    "Optimizations that use {optimization_container} "
    "are not currently supported on {instance_type} instances"
)
89 | 109 |
|
90 | 110 |
|
91 | | -def validate_optimization_configuration( |
| 111 | +def _validate_optimization_configuration( |
92 | 112 | instance_type: str, |
93 | 113 | quantization_config: Dict[str, Any], |
94 | 114 | compilation_config: Dict[str, Any], |
95 | 115 | sharding_config: Dict[str, Any], |
96 | | - speculative_decoding_config: Dict[str, Any] |
| 116 | + speculative_decoding_config: Dict[str, Any], |
97 | 117 | ): |
| 118 | + """Validate .optimize() input off of standard ruleset""" |
| 119 | + |
98 | 120 | split_instance_type = instance_type.split(".") |
99 | 121 | instance_family = None |
100 | 122 | if len(split_instance_type) == 3: # invalid instance type will be caught below |
101 | 123 | instance_family = split_instance_type[1] |
102 | 124 |
|
103 | 125 | if ( |
104 | | - not instance_family in TRT_CONFIGURATION["supported_instance_families"] and |
105 | | - not instance_family in VLLM_CONFIGURATION["supported_instance_families"] and |
106 | | - not instance_family in NEURON_CONFIGURATION["supported_instance_families"] |
| 126 | + instance_family not in TRT_CONFIGURATION["supported_instance_families"] |
| 127 | + and instance_family not in VLLM_CONFIGURATION["supported_instance_families"] |
| 128 | + and instance_family not in NEURON_CONFIGURATION["supported_instance_families"] |
107 | 129 | ): |
108 | | - invalid_instance_type_msg = f""" |
109 | | - The model cannot be optimized on {instance_type}. Please optimize on the following instance type families: |
110 | | - - For {OptimizationContainer.TRT} optimized container: {TRT_CONFIGURATION["supported_instance_families"]} |
111 | | - - For {OptimizationContainer.VLLM} optimized container: {VLLM_CONFIGURATION["supported_instance_families"]} |
112 | | - - For {OptimizationContainer.NEURON} optimized container: {NEURON_CONFIGURATION["supported_instance_families"]} |
113 | | - """ |
114 | | - raise ValueError(textwrap.dedent(invalid_instance_type_msg)) |
115 | | - |
116 | | - optimization_combination = OptimizationCombination( |
| 130 | + invalid_instance_type_msg = ( |
| 131 | + f"Optimizations that use {instance_type} are not currently supported" |
| 132 | + ) |
| 133 | + raise ValueError(invalid_instance_type_msg) |
| 134 | + |
| 135 | + optimization_combination = _OptimizationCombination( |
117 | 136 | compilation=not compilation_config, |
118 | 137 | speculative_decoding=not speculative_decoding_config, |
119 | 138 | sharding=not sharding_config, |
120 | | - quantization_technique={quantization_config.get("OPTION_QUANTIZE") if quantization_config else None} |
| 139 | + quantization_technique={ |
| 140 | + quantization_config.get("OPTION_QUANTIZE") if quantization_config else None |
| 141 | + }, |
121 | 142 | ) |
122 | 143 |
|
123 | 144 | if instance_type in NEURON_CONFIGURATION["supported_instance_families"]: |
124 | 145 | try: |
125 | 146 | ( |
126 | | - NEURON_CONFIGURATION["optimization_combination"] |
127 | | - .validate_against(optimization_combination, rule_set=OptimizationContainer.VLLM) |
| 147 | + NEURON_CONFIGURATION["optimization_combination"].validate_against( |
| 148 | + optimization_combination, rule_set=_OptimizationContainer.VLLM |
| 149 | + ) |
128 | 150 | ) |
129 | 151 | except ValueError as neuron_compare_error: |
130 | 152 | raise ValueError( |
131 | 153 | VALIDATION_ERROR_MSG.format( |
132 | | - optimization_container=OptimizationContainer.NEURON.value, |
133 | | - instance_type=instance_type, |
134 | | - validation_error=neuron_compare_error |
| 154 | + optimization_container=str(neuron_compare_error), |
| 155 | + instance_type="Neuron", |
135 | 156 | ) |
136 | 157 | ) |
137 | 158 | else: |
138 | 159 | try: |
139 | 160 | ( |
140 | | - TRT_CONFIGURATION["optimization_combination"] |
141 | | - .validate_against(optimization_combination, rule_set=OptimizationContainer.TRT) |
| 161 | + TRT_CONFIGURATION["optimization_combination"].validate_against( |
| 162 | + optimization_combination, rule_set=_OptimizationContainer.TRT |
| 163 | + ) |
142 | 164 | ) |
143 | 165 | except ValueError as trt_compare_error: |
144 | 166 | try: |
145 | 167 | ( |
146 | | - VLLM_CONFIGURATION["optimization_combination"] |
147 | | - .validate_against(optimization_combination, rule_set=OptimizationContainer.VLLM) |
| 168 | + VLLM_CONFIGURATION["optimization_combination"].validate_against( |
| 169 | + optimization_combination, rule_set=_OptimizationContainer.VLLM |
| 170 | + ) |
148 | 171 | ) |
149 | 172 | except ValueError as vllm_compare_error: |
150 | 173 | trt_error_msg = VALIDATION_ERROR_MSG.format( |
151 | | - optimization_container=OptimizationContainer.TRT.value, |
152 | | - instance_type=instance_type, |
153 | | - validation_error=trt_compare_error |
| 174 | + optimization_container=str(trt_compare_error), instance_type="GPU" |
154 | 175 | ) |
155 | 176 | vllm_error_msg = VALIDATION_ERROR_MSG.format( |
156 | | - optimization_container=OptimizationContainer.VLLM.value, |
157 | | - instance_type=instance_type, |
158 | | - validation_error=vllm_compare_error |
| 177 | + optimization_container=str(vllm_compare_error), |
| 178 | + instance_type="GPU", |
159 | 179 | ) |
160 | 180 | joint_error_msg = f""" |
161 | | - The model cannot be optimized for the following reasons: |
| 181 | + Optimization cannot be performed for the following reasons: |
162 | 182 | - {trt_error_msg} |
163 | 183 | - {vllm_error_msg} |
164 | 184 | """ |
|
0 commit comments