@@ -194,33 +194,12 @@ def get_kv_cache_torch_dtype(
194194 return torch_dtype
195195
196196
197- def get_kv_cache_quant_algo_dtype (quant_cfg : dict [str , Any ]) -> torch .dtype | None :
198- quant_method = quant_cfg .get ("quant_method" , "" )
199- if quant_method .startswith ("modelopt" ):
200- quantization_inner = quant_cfg .get ("quantization" , quant_cfg )
201- # Check if quant config is specified and use kv cache quant algo
202- kv_algo = quantization_inner .get ("kv_cache_quant_algo" ) or quant_cfg .get (
203- "kv_cache_quant_algo"
204- )
205- if isinstance (kv_algo , str ):
206- return STR_DTYPE_TO_TORCH_DTYPE [kv_algo .lower ()]
207- return None
208-
209-
210197def kv_cache_dtype_str_to_dtype (
211198 kv_cache_dtype : str , model_config : ModelConfig
212199) -> torch .dtype :
213- # Model config may not be specified for unit tests, default to float16
214- dtype = model_config .dtype if model_config else torch .half
215200 if kv_cache_dtype == "auto" :
216- hf_cfg = getattr (model_config , "hf_config" , None )
217- if hf_cfg is not None :
218- quant_cfg = getattr (hf_cfg , "quantization_config" , None )
219- if quant_cfg is not None :
220- kv_algo_dtype = get_kv_cache_quant_algo_dtype (quant_cfg )
221- return kv_algo_dtype if kv_algo_dtype is not None else dtype
222- return dtype
223-
201+ # Model config may not be specified for unit tests, default to float16
202+ return model_config .dtype if model_config else torch .half
224203 return STR_DTYPE_TO_TORCH_DTYPE [kv_cache_dtype ]
225204
226205
0 commit comments