@@ -251,12 +251,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del name, new_name, bid, n_dims  # unused
-
-        return False
-
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
 
         return False
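
Note on the new hook: `tensor_force_quant` folds the old `extra_f32_tensors`/`extra_f16_tensors` pair into a single return value. As a rough sketch of how a subclass could use it (the `MyModel` class and its name checks are illustrative, not part of this diff), an override can return a concrete quantization type, `True` to request the ftype-based default quantization, or `False` to leave the decision to `prepare_tensors`:

```python
class MyModel(Model):  # hypothetical subclass, for illustration only
    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
        del bid, n_dims  # unused in this sketch
        if new_name == "token_embd.weight":
            return gguf.GGMLQuantizationType.F16  # pin this tensor to F16
        if not name.endswith(".weight"):
            return True   # quantize even though the default logic would keep it F32
        return False      # no override: fall through to the ftype-based default
```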
@@ -285,54 +280,46 @@ def prepare_tensors(self):
             for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
                 data: np.ndarray  # type hint
                 n_dims = len(data.shape)
-                data_dtype = data.dtype
-                data_qtype: gguf.GGMLQuantizationType | None = None
-
-                # when both are True, f32 should win
-                extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
-                extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
+                data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
 
                 # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
-                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
-                extra_f32 = any(cond for cond in (
-                    extra_f32,
-                    n_dims == 1,
-                    new_name.endswith("_norm.weight"),
-                ))
+                if n_dims <= 1 or new_name.endswith("_norm.weight"):
+                    data_qtype = gguf.GGMLQuantizationType.F32
 
+                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
                 # Some tensor types are always in float32
-                extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
-                    gguf.MODEL_TENSOR.FFN_GATE_INP,
-                    gguf.MODEL_TENSOR.POS_EMBD,
-                    gguf.MODEL_TENSOR.TOKEN_TYPES,
-                ))
-
-                # if f16 desired, convert any float32 2-dim weight tensors to float16
-                extra_f16 = any(cond for cond in (
-                    extra_f16,
-                    (name.endswith(".weight") and n_dims >= 2),
-                ))
-
-                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
-                    if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
-                        data = gguf.quantize_bf16(data)
-                        assert data.dtype == np.uint16
-                        data_qtype = gguf.GGMLQuantizationType.BF16
-
-                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
-                        data = gguf.quantize_q8_0(data)
-                        assert data.dtype == np.uint8
-                        data_qtype = gguf.GGMLQuantizationType.Q8_0
+                if data_qtype is False and (
+                    any(
+                        self.match_model_tensor_name(new_name, key, bid)
+                        for key in (
+                            gguf.MODEL_TENSOR.FFN_GATE_INP,
+                            gguf.MODEL_TENSOR.POS_EMBD,
+                            gguf.MODEL_TENSOR.TOKEN_TYPES,
+                        )
+                    )
+                    or not name.endswith(".weight")
+                ):
+                    data_qtype = gguf.GGMLQuantizationType.F32
 
-                    else:  # default to float16 for quantized tensors
-                        if data_dtype != np.float16:
-                            data = data.astype(np.float16)
+                # No override (data_qtype is False), or wants to be quantized (data_qtype is True)
+                if isinstance(data_qtype, bool):
+                    if self.ftype == gguf.LlamaFileType.ALL_F32:
+                        data_qtype = gguf.GGMLQuantizationType.F32
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                         data_qtype = gguf.GGMLQuantizationType.F16
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                        data_qtype = gguf.GGMLQuantizationType.BF16
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
+                        data_qtype = gguf.GGMLQuantizationType.Q8_0
+                    else:
+                        raise ValueError(f"Unknown file type: {self.ftype.name}")
 
-                if data_qtype is None:  # by default, convert to float32
-                    if data_dtype != np.float32:
-                        data = data.astype(np.float32)
-                    data_qtype = gguf.GGMLQuantizationType.F32
+                try:
+                    data = gguf.quants.quantize(data, data_qtype)
+                except gguf.QuantError as e:
+                    logger.warning("%s, %s", e, "falling back to F16")
+                    data_qtype = gguf.GGMLQuantizationType.F16
+                    data = gguf.quants.quantize(data, data_qtype)
 
                 shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
 
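With the branches above, every tensor now funnels through `gguf.quants.quantize`, and a shape the requested type cannot handle degrades to F16 instead of aborting the conversion. A rough standalone illustration of that fallback, assuming the updated gguf-py from this change (the array shape is made up):

```python
import numpy as np
import gguf

# 7 columns is not a multiple of Q8_0's block size of 32, so direct
# quantization raises gguf.QuantError
data = np.ones((4, 7), dtype=np.float32)
data_qtype = gguf.GGMLQuantizationType.Q8_0
try:
    data = gguf.quants.quantize(data, data_qtype)
except gguf.QuantError:
    # same recovery as in prepare_tensors: keep going with F16
    data_qtype = gguf.GGMLQuantizationType.F16
    data = gguf.quants.quantize(data, data_qtype)
assert data.dtype == np.float16
```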
@@ -1765,7 +1752,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(new_name, data_torch)]
 
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid  # unused
 
         return n_dims > 1
@@ -2786,18 +2773,22 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(new_name, data_torch)]
 
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del n_dims  # unused
-
-        return bid is not None and new_name in (
-            self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
+        if bid is not None and new_name in (
+            self.format_tensor_name(
+                n, bid, ".weight" if name.endswith(".weight") else ""
+            )
+            for n in [
                 gguf.MODEL_TENSOR.SSM_CONV1D,
                 gguf.MODEL_TENSOR.SSM_X,
                 gguf.MODEL_TENSOR.SSM_DT,
                 gguf.MODEL_TENSOR.SSM_A,
                 gguf.MODEL_TENSOR.SSM_D,
             ]
-        )
+        ):
+            return gguf.GGMLQuantizationType.F32
+
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
 
 
 @Model.register("CohereForCausalLM")
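
The Mamba override now pins its SSM tensors to F32 and defers everything else to the base class instead of returning a bare boolean. A quick sketch of the resulting dispatch, with hypothetical tensor names (`backbone.layers.0.mixer.A_log` is the usual HF Mamba name that maps to `blk.0.ssm_a`):

```python
# m is an already-constructed MambaModel instance (construction omitted)
qtype = m.tensor_force_quant("backbone.layers.0.mixer.A_log", "blk.0.ssm_a", 0, 2)
assert qtype == gguf.GGMLQuantizationType.F32  # SSM tensor: pinned to F32
qtype = m.tensor_force_quant("lm_head.weight", "output.weight", None, 2)
assert qtype is False  # everything else: deferred to Model.tensor_force_quant
```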