from contextlib import nullcontext

import gguf
-from gguf import GGMLQuantizationType as WeightType
import torch
import torch.nn as nn

can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
if can_use_cuda_kernels and is_kernels_available():
    from kernels import get_kernel
+
    ops = get_kernel("Isotr0py/ggml")
else:
    ops = None

-
-UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
-    WeightType.Q4_0,
-    WeightType.Q4_1,
-    WeightType.Q5_0,
-    WeightType.Q5_1,
-    WeightType.Q8_0,
-    WeightType.Q8_1,
+    gguf.GGMLQuantizationType.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1,
+    gguf.GGMLQuantizationType.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1,
+    gguf.GGMLQuantizationType.Q8_0,
+    gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
-    WeightType.Q2_K,
-    WeightType.Q3_K,
-    WeightType.Q4_K,
-    WeightType.Q5_K,
-    WeightType.Q6_K,
+    gguf.GGMLQuantizationType.Q2_K,
+    gguf.GGMLQuantizationType.Q3_K,
+    gguf.GGMLQuantizationType.Q4_K,
+    gguf.GGMLQuantizationType.Q5_K,
+    gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
-    WeightType.IQ1_M,
-    WeightType.IQ1_S,
-    WeightType.IQ2_XXS,
-    WeightType.IQ2_XS,
-    WeightType.IQ2_S,
-    WeightType.IQ3_XXS,
-    WeightType.IQ3_S,
-    WeightType.IQ4_XS,
-    WeightType.IQ4_NL,
+    gguf.GGMLQuantizationType.IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S,
+    gguf.GGMLQuantizationType.IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ2_XS,
+    gguf.GGMLQuantizationType.IQ2_S,
+    gguf.GGMLQuantizationType.IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ3_S,
+    gguf.GGMLQuantizationType.IQ4_XS,
+    gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


-def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
-                        qweight_type: int) -> torch.Tensor:
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
    # there is no need to call any kernel for fp16/bf16
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T
@@ -87,8 +85,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
    # elif qweight_type in MMQ_QUANT_TYPES:
    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
-    # If there is no available MMQ kernel, fallback to dequantize

+    # If there is no available MMQ kernel, fallback to dequantize
    elif qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
@@ -98,9 +96,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
        # Raise an error if the quantization type is not supported.
        # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
-        qweight_type = WeightType(qweight_type)
-        raise NotImplementedError(
-            f"Unsupported GGUF quantization type: {qweight_type}")
+        qweight_type = gguf.GGMLQuantizationType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
    return y

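For readers unfamiliar with the GGUF block layout, a minimal sketch (not part of the commit) of the shape arithmetic used by the dequantize fallback above; the Q4_K choice and the 4096x4096 weight are illustrative assumptions:

import gguf
import torch

# each GGUF block packs `block_size` weights into `type_size` bytes
qtype = gguf.GGMLQuantizationType.Q4_K
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]  # (256, 144) for Q4_K

# a quantized weight stored as raw bytes: (out_features, in_features // block_size * type_size)
qweight = torch.empty(4096, 4096 // block_size * type_size, dtype=torch.uint8)

# recover the logical (out_features, in_features) shape, mirroring the
# `shape = (...)` computation in `_fused_mul_mat_gguf` before dequantization
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
assert shape == (4096, 4096)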
@@ -528,25 +525,22 @@ def __init__(
        self.compute_dtype = compute_dtype
        self.device = device

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
        if ops is not None and self.weight.is_cuda and inputs.is_cuda:
            return self.forward_cuda(inputs)
        return self.forward_native(inputs)

-    def forward_native(self, inputs):
+    def forward_native(self, inputs: torch.Tensor):
        weight = dequantize_gguf_tensor(self.weight)
        weight = weight.to(self.compute_dtype)
        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

        output = torch.nn.functional.linear(inputs, weight, bias)
        return output

-    def forward_cuda(self, inputs):
+    def forward_cuda(self, inputs: torch.Tensor):
        quant_type = self.weight.quant_type
-        orig_shape = inputs.shape
-        inputs = inputs.view(-1, orig_shape[-1])
        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
        if self.bias is not None:
-            output = output + self.bias.to(self.compute_dtype)
-        return output.view(*orig_shape[:-1], -1)
-
+            output += self.bias.to(self.compute_dtype)
+        return output
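As a side note, a small sketch (my own check, not part of the commit) of why the explicit 2D reshape removed from `forward_cuda` is unnecessary for a plain matmul: flattening the leading dimensions, multiplying, and restoring the shape gives the same result as multiplying the batched input directly. Whether the fused CUDA kernels themselves accept batched inputs is a separate question; this only checks the dequantize-and-matmul equivalence.

import torch

x = torch.randn(2, 3, 8)   # (batch, seq, in_features)
w = torch.randn(16, 8)     # (out_features, in_features)

# old pattern: flatten leading dims, matmul, then restore them
orig_shape = x.shape
y_flat = (x.view(-1, orig_shape[-1]) @ w.T).view(*orig_shape[:-1], -1)

# new pattern: matmul directly on the batched input
y = x @ w.T

torch.testing.assert_close(y, y_flat)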