2525class _OptimizationContainer (Enum ):
2626 """Optimization containers"""
2727
28- TRT = "trt "
29- VLLM = "vllm "
30- NEURON = "neuron "
28+ TRT = "TRT "
29+ VLLM = "vLLM "
30+ NEURON = "Neuron "
3131
3232
class _OptimizationCombination(BaseModel):
    """Optimization ruleset data structure for comparing input to ruleset.

    Every technique field holds a set of values. In a ruleset the set lists
    the values the container accepts; in a user-supplied combination it
    holds the requested value. ``None`` inside a set means the technique
    was not configured.
    """

    optimization_container: _OptimizationContainer = None
    compilation: Set[bool | None]
    speculative_decoding: Set[bool | None]
    sharding: Set[bool | None]
    quantization_technique: Set[str | None]

    def validate_against(self, optimization_combination, rule_set: _OptimizationContainer):
        """Validate a requested combination against this ruleset.

        The requested combination is never mutated.

        Args:
            optimization_combination: The requested ``_OptimizationCombination``.
            rule_set: The container whose combination rules apply; TRT has
                its own rule, every other container shares the fallback rule.

        Raises:
            ValueError: If no technique is configured at all, if a field
                requests a value this ruleset does not allow, or if two
                techniques are combined in a way the container forbids.
                Except for the "nothing configured" case, the message is
                the name of the offending technique(s), intended to be
                embedded in a larger error message by the caller.
        """
        requested = optimization_combination

        # check the case where no optimization combination is provided
        if (
            requested.compilation == {None}
            and requested.quantization_technique == {None}
            and requested.speculative_decoding == {None}
            and requested.sharding == {None}
        ):
            raise ValueError(
                "Optimizations are not currently supported without optimization configurations."
            )

        # check the validity of each individual field
        if not requested.compilation.issubset(self.compilation):
            raise ValueError("Compilation")

        unsupported_quantization = requested.quantization_technique - self.quantization_technique
        if unsupported_quantization:
            # Report one offending technique, chosen deterministically and
            # without popping from (and thereby mutating) the caller's set.
            offending = sorted(unsupported_quantization, key=str)[0]
            raise ValueError(f"Quantization:{offending}")

        if not requested.speculative_decoding.issubset(self.speculative_decoding):
            raise ValueError("Speculative Decoding")
        if not requested.sharding.issubset(self.sharding):
            raise ValueError("Sharding")

        # optimization technique combinations that need to be validated;
        # any() is False for None/False entries and for an empty set, so only
        # techniques that are actually enabled count as "in use".
        uses_compilation = any(requested.compilation)
        if rule_set == _OptimizationContainer.TRT:
            if uses_compilation and any(requested.speculative_decoding):
                raise ValueError("Compilation and Speculative Decoding")
        else:
            enabled_quantization = sorted(
                (technique for technique in requested.quantization_technique if technique),
                key=str,
            )
            if uses_compilation and enabled_quantization:
                raise ValueError(f"Compilation and Quantization:{enabled_quantization[0]}")
7292
7393
# Allowed-value sets for boolean techniques in a ruleset. ``None`` is always
# allowed because it means the technique was not configured at all.
TRUTHY_SET = {None, True}  # technique may be omitted or enabled
FALSY_SET = {None, False}  # technique may be omitted or explicitly disabled

# Per-container rulesets: the instance families the container runs on and the
# optimization values it accepts.
TRT_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.TRT,
        compilation=TRUTHY_SET,
        quantization_technique={None, "awq", "fp8", "smooth_quant"},
        speculative_decoding=FALSY_SET,
        sharding=FALSY_SET,
    ),
}
VLLM_CONFIGURATION = {
    "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.VLLM,
        compilation=FALSY_SET,
        quantization_technique={None, "awq", "fp8"},
        speculative_decoding=TRUTHY_SET,
        sharding=TRUTHY_SET,
    ),
}
NEURON_CONFIGURATION = {
    "supported_instance_families": {"inf2", "trn1", "trn1n"},
    "optimization_combination": _OptimizationCombination(
        optimization_container=_OptimizationContainer.NEURON,
        compilation=TRUTHY_SET,
        # Only "not configured" is accepted: no quantization technique is listed.
        quantization_technique={None},
        speculative_decoding=FALSY_SET,
        sharding=FALSY_SET,
    ),
}
104126
# Template for validation failures. Placeholders: the container's display
# name, the offending technique(s), and the instance category.
VALIDATION_ERROR_MSG = (
    "Optimizations for {optimization_container} that use {optimization_technique} "
    "are not currently supported on {instance_type} instances"
)
109131
@@ -117,28 +139,41 @@ def _validate_optimization_configuration(
117139):
118140 """Validate .optimize() input off of standard ruleset"""
119141
120- split_instance_type = instance_type .split ("." )
121142 instance_family = None
122- if len (split_instance_type ) == 3 : # invalid instance type will be caught below
123- instance_family = split_instance_type [1 ]
143+ if instance_type :
144+ split_instance_type = instance_type .split ("." )
145+ if len (split_instance_type ) == 3 :
146+ instance_family = split_instance_type [1 ]
124147
125148 if (
126149 instance_family not in TRT_CONFIGURATION ["supported_instance_families" ]
127150 and instance_family not in VLLM_CONFIGURATION ["supported_instance_families" ]
128151 and instance_family not in NEURON_CONFIGURATION ["supported_instance_families" ]
129152 ):
130153 invalid_instance_type_msg = (
131- f"Optimizations that use { instance_type } are not currently supported"
154+ f"Optimizations that uses { instance_type } instance type are not currently supported"
132155 )
133156 raise ValueError (invalid_instance_type_msg )
134157
158+ quantization_technique = None
159+ if (
160+ quantization_config
161+ and quantization_config .get ("OverrideEnvironment" )
162+ and quantization_config .get ("OverrideEnvironment" ).get ("OPTION_QUANTIZE" )
163+ ):
164+ quantization_technique = quantization_config .get ("OverrideEnvironment" ).get ("OPTION_QUANTIZE" )
165+
135166 optimization_combination = _OptimizationCombination (
136- compilation = not compilation_config ,
137- speculative_decoding = not speculative_decoding_config ,
138- sharding = not sharding_config ,
139- quantization_technique = {
140- quantization_config .get ("OPTION_QUANTIZE" ) if quantization_config else None
167+ compilation = {
168+ None if compilation_config is None else bool (compilation_config )
169+ },
170+ speculative_decoding = {
171+ None if speculative_decoding_config is None else bool (speculative_decoding_config )
172+ },
173+ sharding = {
174+ None if sharding_config is None else bool (sharding_config )
141175 },
176+ quantization_technique = {quantization_technique },
142177 )
143178
144179 if instance_type in NEURON_CONFIGURATION ["supported_instance_families" ]:
@@ -151,7 +186,8 @@ def _validate_optimization_configuration(
151186 except ValueError as neuron_compare_error :
152187 raise ValueError (
153188 VALIDATION_ERROR_MSG .format (
154- optimization_container = str (neuron_compare_error ),
189+ optimization_container = _OptimizationContainer .NEURON .value ,
190+ optimization_technique = str (neuron_compare_error ),
155191 instance_type = "Neuron" ,
156192 )
157193 )
@@ -171,10 +207,13 @@ def _validate_optimization_configuration(
171207 )
172208 except ValueError as vllm_compare_error :
173209 trt_error_msg = VALIDATION_ERROR_MSG .format (
174- optimization_container = str (trt_compare_error ), instance_type = "GPU"
210+ optimization_container = _OptimizationContainer .TRT .value ,
211+ optimization_technique = str (trt_compare_error ),
212+ instance_type = "GPU"
175213 )
176214 vllm_error_msg = VALIDATION_ERROR_MSG .format (
177- optimization_container = str (vllm_compare_error ),
215+ optimization_container = _OptimizationContainer .VLLM .value ,
216+ optimization_technique = str (vllm_compare_error ),
178217 instance_type = "GPU" ,
179218 )
180219 joint_error_msg = f"""
0 commit comments