25
25
class _OptimizationContainer (Enum ):
26
26
"""Optimization containers"""
27
27
28
- TRT = "trt "
29
- VLLM = "vllm "
30
- NEURON = "neuron "
28
+ TRT = "TRT "
29
+ VLLM = "vLLM "
30
+ NEURON = "Neuron "
31
31
32
32
33
33
class _OptimizationCombination(BaseModel):
    """Optimization ruleset data structure for comparing input to ruleset.

    Every technique field holds a *set* of values. On a ruleset instance the
    set lists the values the container accepts; on an input instance it holds
    the requested value(s), with ``None`` standing for "not configured".
    """

    optimization_container: _OptimizationContainer = None
    compilation: Set[bool | None]
    speculative_decoding: Set[bool | None]
    sharding: Set[bool | None]
    quantization_technique: Set[str | None]

    def validate_against(self, optimization_combination, rule_set: _OptimizationContainer):
        """Validator for optimization containers.

        Args:
            optimization_combination: The requested combination. Its
                technique fields are read as sets and are never mutated.
            rule_set: The container whose combination rules apply. ``TRT``
                rejects compilation together with speculative decoding; any
                other container rejects compilation together with
                quantization.

        Raises:
            ValueError: If nothing is configured at all, if a requested
                value is not accepted by this ruleset, or if a disallowed
                pair of techniques is requested. Except for the
                "nothing configured" case, the message is a short technique
                label (for example ``"Sharding"`` or ``"Quantization:awq"``).
        """
        requested = optimization_combination

        # check the case where no optimization combination is provided
        if (
            requested.compilation == {None}
            and requested.quantization_technique == {None}
            and requested.speculative_decoding == {None}
            and requested.sharding == {None}
        ):
            raise ValueError(
                "Optimizations are not currently supported without optimization configurations."
            )

        # check the validity of each individual field
        if not requested.compilation.issubset(self.compilation):
            raise ValueError("Compilation")

        # Report a value that is actually unsupported, without touching the
        # caller's set.
        unsupported_quantization = requested.quantization_technique - self.quantization_technique
        if unsupported_quantization:
            raise ValueError(f"Quantization:{next(iter(unsupported_quantization))}")

        if not requested.speculative_decoding.issubset(self.speculative_decoding):
            raise ValueError("Speculative Decoding")
        if not requested.sharding.issubset(self.sharding):
            raise ValueError("Sharding")

        # optimization technique combinations that need to be validated
        # A technique counts as requested when its set holds a truthy value;
        # None/False entries (and empty sets) do not count.
        compilation_requested = any(requested.compilation)
        if rule_set == _OptimizationContainer.TRT:
            if compilation_requested and any(requested.speculative_decoding):
                raise ValueError("Compilation and Speculative Decoding")
        else:
            requested_quantization = [
                technique for technique in requested.quantization_technique if technique
            ]
            if compilation_requested and requested_quantization:
                raise ValueError(f"Compilation and Quantization:{requested_quantization[0]}")
72
92
73
93
94
# Accepted-value sets shared by the rulesets below. None represents a
# technique that was left unconfigured, so it is accepted in both sets.
TRUTHY_SET = {None, True}
FALSY_SET = {None, False}

# Ruleset for the TRT container: compilation and the listed quantization
# techniques are accepted; speculative decoding and sharding are not.
TRT_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.TRT,
        compilation=TRUTHY_SET,
        quantization_technique={None, "awq", "fp8", "smooth_quant"},
        speculative_decoding=FALSY_SET,
        sharding=FALSY_SET,
    ),
}

# Ruleset for the vLLM container: speculative decoding, sharding and the
# listed quantization techniques are accepted; compilation is not.
VLLM_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.VLLM,
        compilation=FALSY_SET,
        quantization_technique={None, "awq", "fp8"},
        speculative_decoding=TRUTHY_SET,
        sharding=TRUTHY_SET,
    ),
}

# Ruleset for the Neuron container: only compilation is accepted; no
# quantization technique may be requested.
NEURON_CONFIGURATION = {
    "supported_instance_families": {"inf2", "trn1", "trn1n"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.NEURON,
        compilation=TRUTHY_SET,
        quantization_technique={None},
        speculative_decoding=FALSY_SET,
        sharding=FALSY_SET,
    ),
}
104
126
105
127
# Template for validation failures. Placeholders: the container's display
# name, the rejected technique label, and the instance category.
VALIDATION_ERROR_MSG = (
    "Optimizations for {optimization_container} that use {optimization_technique} "
    "are not currently supported on {instance_type} instances"
)
109
131
@@ -117,28 +139,41 @@ def _validate_optimization_configuration(
117
139
):
118
140
"""Validate .optimize() input off of standard ruleset"""
119
141
120
- split_instance_type = instance_type .split ("." )
121
142
instance_family = None
122
- if len (split_instance_type ) == 3 : # invalid instance type will be caught below
123
- instance_family = split_instance_type [1 ]
143
+ if instance_type :
144
+ split_instance_type = instance_type .split ("." )
145
+ if len (split_instance_type ) == 3 :
146
+ instance_family = split_instance_type [1 ]
124
147
125
148
if (
126
149
instance_family not in TRT_CONFIGURATION ["supported_instance_families" ]
127
150
and instance_family not in VLLM_CONFIGURATION ["supported_instance_families" ]
128
151
and instance_family not in NEURON_CONFIGURATION ["supported_instance_families" ]
129
152
):
130
153
invalid_instance_type_msg = (
131
- f"Optimizations that use { instance_type } are not currently supported"
154
+ f"Optimizations that uses { instance_type } instance type are not currently supported"
132
155
)
133
156
raise ValueError (invalid_instance_type_msg )
134
157
158
+ quantization_technique = None
159
+ if (
160
+ quantization_config
161
+ and quantization_config .get ("OverrideEnvironment" )
162
+ and quantization_config .get ("OverrideEnvironment" ).get ("OPTION_QUANTIZE" )
163
+ ):
164
+ quantization_technique = quantization_config .get ("OverrideEnvironment" ).get ("OPTION_QUANTIZE" )
165
+
135
166
optimization_combination = _OptimizationCombination (
136
- compilation = not compilation_config ,
137
- speculative_decoding = not speculative_decoding_config ,
138
- sharding = not sharding_config ,
139
- quantization_technique = {
140
- quantization_config .get ("OPTION_QUANTIZE" ) if quantization_config else None
167
+ compilation = {
168
+ None if compilation_config is None else bool (compilation_config )
169
+ },
170
+ speculative_decoding = {
171
+ None if speculative_decoding_config is None else bool (speculative_decoding_config )
172
+ },
173
+ sharding = {
174
+ None if sharding_config is None else bool (sharding_config )
141
175
},
176
+ quantization_technique = {quantization_technique },
142
177
)
143
178
144
179
if instance_type in NEURON_CONFIGURATION ["supported_instance_families" ]:
@@ -151,7 +186,8 @@ def _validate_optimization_configuration(
151
186
except ValueError as neuron_compare_error :
152
187
raise ValueError (
153
188
VALIDATION_ERROR_MSG .format (
154
- optimization_container = str (neuron_compare_error ),
189
+ optimization_container = _OptimizationContainer .NEURON .value ,
190
+ optimization_technique = str (neuron_compare_error ),
155
191
instance_type = "Neuron" ,
156
192
)
157
193
)
@@ -171,10 +207,13 @@ def _validate_optimization_configuration(
171
207
)
172
208
except ValueError as vllm_compare_error :
173
209
trt_error_msg = VALIDATION_ERROR_MSG .format (
174
- optimization_container = str (trt_compare_error ), instance_type = "GPU"
210
+ optimization_container = _OptimizationContainer .TRT .value ,
211
+ optimization_technique = str (trt_compare_error ),
212
+ instance_type = "GPU"
175
213
)
176
214
vllm_error_msg = VALIDATION_ERROR_MSG .format (
177
- optimization_container = str (vllm_compare_error ),
215
+ optimization_container = _OptimizationContainer .VLLM .value ,
216
+ optimization_technique = str (vllm_compare_error ),
178
217
instance_type = "GPU" ,
179
218
)
180
219
joint_error_msg = f"""
0 commit comments