|
17 | 17 | from contextlib import nullcontext |
18 | 18 |
|
19 | 19 | import gguf |
| 20 | +from gguf import GGMLQuantizationType as WeightType |
20 | 21 | import torch |
21 | 22 | import torch.nn as nn |
22 | 23 |
|
23 | | -from ...utils import is_accelerate_available |
| 24 | +from ...utils import is_accelerate_available, is_kernels_available |
24 | 25 |
|
25 | 26 |
|
26 | 27 | if is_accelerate_available(): |
|
29 | 30 |     from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
30 | 31 |
|
31 | 32 |
|
| 33 | +can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7 |
| 34 | +if can_use_cuda_kernels and is_kernels_available(): |
| 35 | +    from kernels import get_kernel |
| 36 | +    ops = get_kernel("Isotr0py/ggml") |
| 37 | +else: |
| 38 | +    ops = None |
| 39 | + |
| 40 | + |
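Note: `get_kernel` comes from the Hugging Face `kernels` package and loads the prebuilt `Isotr0py/ggml` extension from the Hub; the capability check restricts the fused path to compute capability 7.0 (Volta) and newer GPUs. A minimal standalone sketch of the same guarded lookup, assuming `kernels` and a CUDA build of PyTorch are installed:

```python
import torch

# Sketch of the guarded kernel lookup above; `get_kernel` loads the prebuilt
# "Isotr0py/ggml" extension from the Hugging Face Hub when it is available.
ops = None
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7:
    try:
        from kernels import get_kernel

        ops = get_kernel("Isotr0py/ggml")
    except Exception:
        ops = None  # fall back to the pure-PyTorch dequantize path

print("fused GGUF kernels available:", ops is not None)
```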
| 41 | +UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} |
| 42 | +STANDARD_QUANT_TYPES = { |
| 43 | +    WeightType.Q4_0, |
| 44 | +    WeightType.Q4_1, |
| 45 | +    WeightType.Q5_0, |
| 46 | +    WeightType.Q5_1, |
| 47 | +    WeightType.Q8_0, |
| 48 | +    WeightType.Q8_1, |
| 49 | +} |
| 50 | +KQUANT_TYPES = { |
| 51 | +    WeightType.Q2_K, |
| 52 | +    WeightType.Q3_K, |
| 53 | +    WeightType.Q4_K, |
| 54 | +    WeightType.Q5_K, |
| 55 | +    WeightType.Q6_K, |
| 56 | +} |
| 57 | +IMATRIX_QUANT_TYPES = { |
| 58 | +    WeightType.IQ1_M, |
| 59 | +    WeightType.IQ1_S, |
| 60 | +    WeightType.IQ2_XXS, |
| 61 | +    WeightType.IQ2_XS, |
| 62 | +    WeightType.IQ2_S, |
| 63 | +    WeightType.IQ3_XXS, |
| 64 | +    WeightType.IQ3_S, |
| 65 | +    WeightType.IQ4_XS, |
| 66 | +    WeightType.IQ4_NL, |
| 67 | +} |
| 68 | +# TODO(Isotr0py): Currently, we don't have an MMQ kernel for I-Matrix quantization. |
| 69 | +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add an |
| 70 | +# MMQ kernel for I-Matrix quantization. |
| 71 | +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES |
| 72 | +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES |
| 73 | +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES |
| 74 | + |
| 75 | + |
| 76 | +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, |
| 77 | +                        qweight_type: int) -> torch.Tensor: |
| 78 | +    # there is no need to call any kernel for fp16/bf16 |
| 79 | +    if qweight_type in UNQUANTIZED_TYPES: |
| 80 | +        return x @ qweight.T |
| 81 | +    # MMVQ kernel: fused matrix-vector multiply (designed around batch_size=1 in contiguous batching) |
| 82 | +    if qweight_type in MMVQ_QUANT_TYPES: |
| 83 | +        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) |
| 84 | +    # Use MMQ Kernel if it's available (standard + k-quants) |
| 85 | +    elif qweight_type in MMQ_QUANT_TYPES: |
| 86 | +        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) |
| 87 | +    # If there is no available MMQ kernel, fallback to dequantize |
| 88 | +    elif qweight_type in DEQUANT_TYPES: |
| 89 | +        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] |
| 90 | +        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) |
| 91 | +        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) |
| 92 | +        y = x @ weight.T |
| 93 | +    else: |
| 94 | +        # Raise an error if the quantization type is not supported. |
| 95 | +        # Might be useful if llama.cpp adds a new quantization type. |
| 96 | +        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. |
| 97 | +        qweight_type = WeightType(qweight_type) |
| 98 | +        raise NotImplementedError( |
| 99 | +            f"Unsupported GGUF quantization type: {qweight_type}") |
| 100 | +    return y |
| 101 | + |
| 102 | + |
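For reference, `gguf.GGML_QUANT_SIZES` maps each quantization type to `(block_size, type_size)`, i.e. the number of weights per block and the number of bytes that block occupies, which is exactly the ratio the dequantize fallback uses to recover the weight shape. A small illustrative sketch (the tensor sizes below are made up, not taken from this PR):

```python
import gguf
from gguf import GGMLQuantizationType as WeightType

# (block_size, type_size) = (weights per block, bytes per block) for each quant type.
block_size, type_size = gguf.GGML_QUANT_SIZES[WeightType.Q4_K]

# Same shape math as the dequantize fallback above: each packed row of
# `packed_cols` bytes expands to packed_cols // type_size * block_size weights.
packed_rows, packed_cols = 4096, 2304  # hypothetical packed tensor shape
dequant_shape = (packed_rows, packed_cols // type_size * block_size)
print(block_size, type_size, dequant_shape)
```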
32 | 103 | # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook |
33 | 104 | def _create_accelerate_new_hook(old_hook): |
34 | 105 |     r""" |
@@ -451,11 +522,25 @@ def __init__(
451 | 522 |     ) -> None: |
452 | 523 |         super().__init__(in_features, out_features, bias, device) |
453 | 524 |         self.compute_dtype = compute_dtype |
| 525 | +        self.device = device |
454 | 526 |
|
455 | 527 |     def forward(self, inputs): |
| 528 | +        if ops is not None and self.weight.is_cuda and inputs.is_cuda: |
| 529 | +            return self.forward_cuda(inputs) |
| 530 | +        return self.forward_native(inputs) |
| 531 | + |
| 532 | +    def forward_native(self, inputs): |
456 | 533 |         weight = dequantize_gguf_tensor(self.weight) |
457 | 534 |         weight = weight.to(self.compute_dtype) |
458 | 535 |         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None |
459 | 536 |
|
460 | 537 |         output = torch.nn.functional.linear(inputs, weight, bias) |
461 | 538 |         return output |
| 539 | + |
| 540 | +    def forward_cuda(self, inputs): |
| 541 | +        quant_type = self.weight.quant_type |
| 542 | +        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) |
| 543 | +        if self.bias is not None: |
| 544 | +            output = output + self.bias.to(self.compute_dtype) |
| 545 | +        return output |
| 546 | + |
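A minimal usage sketch of the helper added in this diff. The import path is assumed to be `diffusers.quantizers.gguf.utils` (the file being patched); F16/BF16 weights take the unquantized fast path, which never touches `ops`, so this also runs when the CUDA kernels are unavailable:

```python
import torch
from gguf import GGMLQuantizationType as WeightType

# Assumed import path for the helper defined in this diff.
from diffusers.quantizers.gguf.utils import _fused_mul_mat_gguf

x = torch.randn(2, 64, dtype=torch.float16)
w = torch.randn(32, 64, dtype=torch.float16)

# F16 weights hit the unquantized branch (plain x @ w.T); quantized types are
# routed to the MMVQ/MMQ kernels or to the dequantize fallback when `ops` is loaded.
y = _fused_mul_mat_gguf(x, w, WeightType.F16)
print(y.shape)  # torch.Size([2, 32])
```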