 from contextlib import nullcontext
 
 import gguf
-from gguf import GGMLQuantizationType as WeightType
 import torch
 import torch.nn as nn
 
 can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel
+
     ops = get_kernel("Isotr0py/ggml")
 else:
     ops = None
 
-
-UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
 STANDARD_QUANT_TYPES = {
-    WeightType.Q4_0,
-    WeightType.Q4_1,
-    WeightType.Q5_0,
-    WeightType.Q5_1,
-    WeightType.Q8_0,
-    WeightType.Q8_1,
+    gguf.GGMLQuantizationType.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1,
+    gguf.GGMLQuantizationType.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1,
+    gguf.GGMLQuantizationType.Q8_0,
+    gguf.GGMLQuantizationType.Q8_1,
 }
 KQUANT_TYPES = {
-    WeightType.Q2_K,
-    WeightType.Q3_K,
-    WeightType.Q4_K,
-    WeightType.Q5_K,
-    WeightType.Q6_K,
+    gguf.GGMLQuantizationType.Q2_K,
+    gguf.GGMLQuantizationType.Q3_K,
+    gguf.GGMLQuantizationType.Q4_K,
+    gguf.GGMLQuantizationType.Q5_K,
+    gguf.GGMLQuantizationType.Q6_K,
 }
 IMATRIX_QUANT_TYPES = {
-    WeightType.IQ1_M,
-    WeightType.IQ1_S,
-    WeightType.IQ2_XXS,
-    WeightType.IQ2_XS,
-    WeightType.IQ2_S,
-    WeightType.IQ3_XXS,
-    WeightType.IQ3_S,
-    WeightType.IQ4_XS,
-    WeightType.IQ4_NL,
+    gguf.GGMLQuantizationType.IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S,
+    gguf.GGMLQuantizationType.IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ2_XS,
+    gguf.GGMLQuantizationType.IQ2_S,
+    gguf.GGMLQuantizationType.IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ3_S,
+    gguf.GGMLQuantizationType.IQ4_XS,
+    gguf.GGMLQuantizationType.IQ4_NL,
 }
 # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
 # Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
 MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
 
 
-def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
-                        qweight_type: int) -> torch.Tensor:
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
@@ -87,8 +85,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
     # elif qweight_type in MMQ_QUANT_TYPES:
     #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
-    # If there is no available MMQ kernel, fallback to dequantize
 
+    # If there is no available MMQ kernel, fallback to dequantize
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
@@ -98,9 +96,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
         # Raise an error if the quantization type is not supported.
         # Might be useful if llama.cpp adds a new quantization type.
         # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
-        qweight_type = WeightType(qweight_type)
-        raise NotImplementedError(
-            f"Unsupported GGUF quantization type: {qweight_type}")
+        qweight_type = gguf.GGMLQuantizationType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
     return y
 
 
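As a quick sanity check of the dequantize fallback above: gguf.GGML_QUANT_SIZES maps each quantization type to its (block_size, type_size) pair, i.e. how many weights a block holds and how many bytes that block occupies, so the dequantized width is recovered as packed_width // type_size * block_size. A minimal sketch of that arithmetic (the 4096x4096 dimensions are made up for illustration; only the gguf package is required):

import gguf

qtype = gguf.GGMLQuantizationType.Q4_K
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]  # weights per block, bytes per block

out_features, in_features = 4096, 4096                # assumed example dimensions
packed_width = in_features // block_size * type_size  # bytes per row of the packed uint8 weight
dequant_shape = (out_features, packed_width // type_size * block_size)
assert dequant_shape == (out_features, in_features)   # same shape computed in the fallback branch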
@@ -528,25 +525,22 @@ def __init__(
         self.compute_dtype = compute_dtype
         self.device = device
 
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
         if ops is not None and self.weight.is_cuda and inputs.is_cuda:
             return self.forward_cuda(inputs)
         return self.forward_native(inputs)
 
-    def forward_native(self, inputs):
+    def forward_native(self, inputs: torch.Tensor):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
 
         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
 
-    def forward_cuda(self, inputs):
+    def forward_cuda(self, inputs: torch.Tensor):
         quant_type = self.weight.quant_type
-        orig_shape = inputs.shape
-        inputs = inputs.view(-1, orig_shape[-1])
         output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
         if self.bias is not None:
-            output = output + self.bias.to(self.compute_dtype)
-        return output.view(*orig_shape[:-1], -1)
-
+            output += self.bias.to(self.compute_dtype)
+        return output
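For reference, the dequantize path used by forward_native and the DEQUANT_TYPES branch of _fused_mul_mat_gguf is conceptually the same computation that gguf-py's own quants module performs on CPU. A round-trip sketch of that reference behaviour, assuming a recent gguf release that exposes gguf.quants.quantize/dequantize (shapes and the printed error are illustrative only):

import numpy as np
import gguf
from gguf.quants import dequantize, quantize

qtype = gguf.GGMLQuantizationType.Q8_0
weight = np.random.randn(64, 256).astype(np.float32)

packed = quantize(weight, qtype)        # uint8 blocks; Q8_0 packs 32 weights per block
restored = dequantize(packed, qtype)    # back to float32 with the original (64, 256) shape

print(packed.shape, restored.shape)
print(np.abs(weight - restored).max())  # small quantization error, not exact recovery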