@@ -41,18 +41,6 @@ class _OptimizationCombination(BaseModel):
4141
4242 def validate_against (self , optimization_combination , rule_set : _OptimizationContainer ):
4343 """Validator for optimization containers"""
44- print (optimization_combination )
45- print (rule_set )
46- print (optimization_combination .speculative_decoding .issubset (self .speculative_decoding ))
47-
48- # check the case where no optimization combination is provided
49- if (
50- optimization_combination .compilation == {None }
51- and optimization_combination .quantization_technique == {None }
52- and optimization_combination .speculative_decoding == {None }
53- and optimization_combination .sharding == {None }
54- ):
55- raise ValueError ("no optimization configurations" )
5644
5745 # check the validity of each individual field
5846 if not optimization_combination .compilation .issubset (self .compilation ):
@@ -68,17 +56,22 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
6856 raise ValueError ("Sharding" )
6957
7058 # optimization technique combinations that need to be validated
59+ if optimization_combination .compilation and optimization_combination .speculative_decoding :
60+ copy_compilation = optimization_combination .compilation .copy ()
61+ copy_speculative_decoding = optimization_combination .speculative_decoding .copy ()
62+ if (
63+ copy_compilation .pop () and copy_speculative_decoding .pop ()
64+ ): # Check that the 2 techniques are not None
65+ raise ValueError ("Compilation and Speculative Decoding together" )
66+
7167 if rule_set == _OptimizationContainer .TRT :
7268 if (
7369 optimization_combination .compilation
74- and optimization_combination .speculative_decoding
70+ and not optimization_combination .quantization_technique
71+ or not optimization_combination .compilation
72+ and optimization_combination .quantization_technique
7573 ):
76- copy_compilation = optimization_combination .compilation .copy ()
77- copy_speculative_decoding = optimization_combination .speculative_decoding .copy ()
78- if (
79- copy_compilation .pop () and copy_speculative_decoding .pop ()
80- ): # Check that the 2 techniques are not None
81- raise ValueError ("Compilation and Speculative Decoding" )
74+ raise ValueError ("Compilation must be provided with Quantization" )
8275 else :
8376 copy_compilation = optimization_combination .compilation .copy ()
8477 copy_quantization_technique = optimization_combination .quantization_technique .copy ()
@@ -106,7 +99,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
10699 "supported_instance_families" : {"p4d" , "p4de" , "p5" , "g5" , "g6" },
107100 "optimization_combination" : _OptimizationCombination (
108101 optimization_container = _OptimizationContainer .VLLM ,
109- compilation = FALSY_SET ,
102+ compilation = TRUTHY_SET ,
110103 quantization_technique = {None , "awq" , "fp8" },
111104 speculative_decoding = TRUTHY_SET ,
112105 sharding = TRUTHY_SET ,
@@ -123,11 +116,6 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
123116 ),
124117}
125118
126- VALIDATION_ERROR_MSG = (
127- "Optimizations for {optimization_container} that use {optimization_technique} "
128- "are not currently supported on {instance_type} instances"
129- )
130-
131119
132120def _validate_optimization_configuration (
133121 instance_type : str ,
@@ -150,7 +138,8 @@ def _validate_optimization_configuration(
150138 and instance_family not in NEURON_CONFIGURATION ["supported_instance_families" ]
151139 ):
152140 invalid_instance_type_msg = (
153- f"Optimizations that uses { instance_type } instance type are not currently supported"
141+ f"Optimizations that uses { instance_type } instance type are "
142+ "not currently supported both on GPU and Neuron instances"
154143 )
155144 raise ValueError (invalid_instance_type_msg )
156145
@@ -166,13 +155,26 @@ def _validate_optimization_configuration(
166155
167156 optimization_combination = _OptimizationCombination (
168157 compilation = {None if compilation_config is None else True },
169- speculative_decoding = {
170- None if speculative_decoding_config is None else True
171- },
158+ speculative_decoding = {None if speculative_decoding_config is None else True },
172159 sharding = {None if sharding_config is None else True },
173160 quantization_technique = {quantization_technique },
174161 )
175162
163+ # Check the case where no optimization combination is provided
164+ if (
165+ optimization_combination .compilation == {None }
166+ and optimization_combination .quantization_technique == {None }
167+ and optimization_combination .speculative_decoding == {None }
168+ and optimization_combination .sharding == {None }
169+ ):
170+ raise ValueError (
171+ (
172+ "Optimizations that provide no optimization configs "
173+ "are currently not support on both GPU and Neuron instances."
174+ )
175+ )
176+
177+ # Validate based off of instance type
176178 if instance_family in NEURON_CONFIGURATION ["supported_instance_families" ]:
177179 try :
178180 (
@@ -182,11 +184,7 @@ def _validate_optimization_configuration(
182184 )
183185 except ValueError as neuron_compare_error :
184186 raise ValueError (
185- VALIDATION_ERROR_MSG .format (
186- optimization_container = _OptimizationContainer .NEURON .value ,
187- optimization_technique = str (neuron_compare_error ),
188- instance_type = "Neuron" ,
189- )
187+ f"Optimizations that use { neuron_compare_error } are not supported on Neuron instances."
190188 )
191189 else :
192190 try :
@@ -203,19 +201,16 @@ def _validate_optimization_configuration(
203201 )
204202 )
205203 except ValueError as vllm_compare_error :
206- trt_error_msg = VALIDATION_ERROR_MSG .format (
207- optimization_container = _OptimizationContainer .TRT .value ,
208- optimization_technique = str (trt_compare_error ),
209- instance_type = "GPU" ,
210- )
211- vllm_error_msg = VALIDATION_ERROR_MSG .format (
212- optimization_container = _OptimizationContainer .VLLM .value ,
213- optimization_technique = str (vllm_compare_error ),
214- instance_type = "GPU" ,
215- )
216- joint_error_msg = f"""
217- Optimization cannot be performed for the following reasons:
218- - { trt_error_msg }
219- - { vllm_error_msg }
220- """
204+ if trt_compare_error == "Compilation must be provided with Quantization" :
205+ joint_error_msg = f"""
206+ Optimization cannot be performed for the following reasons:
207+ - Optimizations that use { trt_compare_error } and vice-versa for GPU instances.
208+ - Optimizations that use { vllm_compare_error } are not supported for GPU instances.
209+ """
210+ else :
211+ joint_error_msg = f"""
212+ Optimization cannot be performed for the following reasons:
213+ - Optimizations that use { trt_compare_error } are not supported for GPU instances.
214+ - Optimizations that use { vllm_compare_error } are not supported for GPU instances.
215+ """
221216 raise ValueError (textwrap .dedent (joint_error_msg ))
0 commit comments