@@ -41,18 +41,6 @@ class _OptimizationCombination(BaseModel):
41
41
42
42
def validate_against (self , optimization_combination , rule_set : _OptimizationContainer ):
43
43
"""Validator for optimization containers"""
44
- print (optimization_combination )
45
- print (rule_set )
46
- print (optimization_combination .speculative_decoding .issubset (self .speculative_decoding ))
47
-
48
- # check the case where no optimization combination is provided
49
- if (
50
- optimization_combination .compilation == {None }
51
- and optimization_combination .quantization_technique == {None }
52
- and optimization_combination .speculative_decoding == {None }
53
- and optimization_combination .sharding == {None }
54
- ):
55
- raise ValueError ("no optimization configurations" )
56
44
57
45
# check the validity of each individual field
58
46
if not optimization_combination .compilation .issubset (self .compilation ):
@@ -68,17 +56,22 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
68
56
raise ValueError ("Sharding" )
69
57
70
58
# optimization technique combinations that need to be validated
59
+ if optimization_combination .compilation and optimization_combination .speculative_decoding :
60
+ copy_compilation = optimization_combination .compilation .copy ()
61
+ copy_speculative_decoding = optimization_combination .speculative_decoding .copy ()
62
+ if (
63
+ copy_compilation .pop () and copy_speculative_decoding .pop ()
64
+ ): # Check that the 2 techniques are not None
65
+ raise ValueError ("Compilation and Speculative Decoding together" )
66
+
71
67
if rule_set == _OptimizationContainer .TRT :
72
68
if (
73
69
optimization_combination .compilation
74
- and optimization_combination .speculative_decoding
70
+ and not optimization_combination .quantization_technique
71
+ or not optimization_combination .compilation
72
+ and optimization_combination .quantization_technique
75
73
):
76
- copy_compilation = optimization_combination .compilation .copy ()
77
- copy_speculative_decoding = optimization_combination .speculative_decoding .copy ()
78
- if (
79
- copy_compilation .pop () and copy_speculative_decoding .pop ()
80
- ): # Check that the 2 techniques are not None
81
- raise ValueError ("Compilation and Speculative Decoding" )
74
+ raise ValueError ("Compilation must be provided with Quantization" )
82
75
else :
83
76
copy_compilation = optimization_combination .compilation .copy ()
84
77
copy_quantization_technique = optimization_combination .quantization_technique .copy ()
@@ -106,7 +99,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
106
99
"supported_instance_families" : {"p4d" , "p4de" , "p5" , "g5" , "g6" },
107
100
"optimization_combination" : _OptimizationCombination (
108
101
optimization_container = _OptimizationContainer .VLLM ,
109
- compilation = FALSY_SET ,
102
+ compilation = TRUTHY_SET ,
110
103
quantization_technique = {None , "awq" , "fp8" },
111
104
speculative_decoding = TRUTHY_SET ,
112
105
sharding = TRUTHY_SET ,
@@ -123,11 +116,6 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
123
116
),
124
117
}
125
118
126
- VALIDATION_ERROR_MSG = (
127
- "Optimizations for {optimization_container} that use {optimization_technique} "
128
- "are not currently supported on {instance_type} instances"
129
- )
130
-
131
119
132
120
def _validate_optimization_configuration (
133
121
instance_type : str ,
@@ -150,7 +138,8 @@ def _validate_optimization_configuration(
150
138
and instance_family not in NEURON_CONFIGURATION ["supported_instance_families" ]
151
139
):
152
140
invalid_instance_type_msg = (
153
- f"Optimizations that uses { instance_type } instance type are not currently supported"
141
+ f"Optimizations that uses { instance_type } instance type are "
142
+ "not currently supported both on GPU and Neuron instances"
154
143
)
155
144
raise ValueError (invalid_instance_type_msg )
156
145
@@ -166,13 +155,26 @@ def _validate_optimization_configuration(
166
155
167
156
optimization_combination = _OptimizationCombination (
168
157
compilation = {None if compilation_config is None else True },
169
- speculative_decoding = {
170
- None if speculative_decoding_config is None else True
171
- },
158
+ speculative_decoding = {None if speculative_decoding_config is None else True },
172
159
sharding = {None if sharding_config is None else True },
173
160
quantization_technique = {quantization_technique },
174
161
)
175
162
163
+ # Check the case where no optimization combination is provided
164
+ if (
165
+ optimization_combination .compilation == {None }
166
+ and optimization_combination .quantization_technique == {None }
167
+ and optimization_combination .speculative_decoding == {None }
168
+ and optimization_combination .sharding == {None }
169
+ ):
170
+ raise ValueError (
171
+ (
172
+ "Optimizations that provide no optimization configs "
173
+ "are currently not support on both GPU and Neuron instances."
174
+ )
175
+ )
176
+
177
+ # Validate based off of instance type
176
178
if instance_family in NEURON_CONFIGURATION ["supported_instance_families" ]:
177
179
try :
178
180
(
@@ -182,11 +184,7 @@ def _validate_optimization_configuration(
182
184
)
183
185
except ValueError as neuron_compare_error :
184
186
raise ValueError (
185
- VALIDATION_ERROR_MSG .format (
186
- optimization_container = _OptimizationContainer .NEURON .value ,
187
- optimization_technique = str (neuron_compare_error ),
188
- instance_type = "Neuron" ,
189
- )
187
+ f"Optimizations that use { neuron_compare_error } are not supported on Neuron instances."
190
188
)
191
189
else :
192
190
try :
@@ -203,19 +201,16 @@ def _validate_optimization_configuration(
203
201
)
204
202
)
205
203
except ValueError as vllm_compare_error :
206
- trt_error_msg = VALIDATION_ERROR_MSG .format (
207
- optimization_container = _OptimizationContainer .TRT .value ,
208
- optimization_technique = str (trt_compare_error ),
209
- instance_type = "GPU" ,
210
- )
211
- vllm_error_msg = VALIDATION_ERROR_MSG .format (
212
- optimization_container = _OptimizationContainer .VLLM .value ,
213
- optimization_technique = str (vllm_compare_error ),
214
- instance_type = "GPU" ,
215
- )
216
- joint_error_msg = f"""
217
- Optimization cannot be performed for the following reasons:
218
- - { trt_error_msg }
219
- - { vllm_error_msg }
220
- """
204
+ if trt_compare_error == "Compilation must be provided with Quantization" :
205
+ joint_error_msg = f"""
206
+ Optimization cannot be performed for the following reasons:
207
+ - Optimizations that use { trt_compare_error } and vice-versa for GPU instances.
208
+ - Optimizations that use { vllm_compare_error } are not supported for GPU instances.
209
+ """
210
+ else :
211
+ joint_error_msg = f"""
212
+ Optimization cannot be performed for the following reasons:
213
+ - Optimizations that use { trt_compare_error } are not supported for GPU instances.
214
+ - Optimizations that use { vllm_compare_error } are not supported for GPU instances.
215
+ """
221
216
raise ValueError (textwrap .dedent (joint_error_msg ))
0 commit comments