| 17 | 17 | from contextlib import nullcontext | 
| 18 | 18 |  | 
| 19 | 19 | import gguf | 
|  | 20 | +from gguf import GGMLQuantizationType as WeightType | 
| 20 | 21 | import torch | 
| 21 | 22 | import torch.nn as nn | 
| 22 | 23 |  | 
| 23 |  | -from ...utils import is_accelerate_available | 
|  | 24 | +from ...utils import is_accelerate_available, is_kernels_available | 
| 24 | 25 |  | 
| 25 | 26 |  | 
| 26 | 27 | if is_accelerate_available(): | 
| 29 | 30 |     from accelerate.hooks import add_hook_to_module, remove_hook_from_module | 
| 30 | 31 |  | 
| 31 | 32 |  | 
|  | 33 | +can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7 | 
|  | 34 | +if can_use_cuda_kernels and is_kernels_available(): | 
|  | 35 | +    from kernels import get_kernel | 
|  | 36 | +    ops = get_kernel("Isotr0py/ggml") | 
|  | 37 | +else: | 
|  | 38 | +    ops = None | 
|  | 39 | + | 
|  | 40 | + | 
|  | 41 | +UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} | 
|  | 42 | +STANDARD_QUANT_TYPES = { | 
|  | 43 | +    WeightType.Q4_0, | 
|  | 44 | +    WeightType.Q4_1, | 
|  | 45 | +    WeightType.Q5_0, | 
|  | 46 | +    WeightType.Q5_1, | 
|  | 47 | +    WeightType.Q8_0, | 
|  | 48 | +    WeightType.Q8_1, | 
|  | 49 | +} | 
|  | 50 | +KQUANT_TYPES = { | 
|  | 51 | +    WeightType.Q2_K, | 
|  | 52 | +    WeightType.Q3_K, | 
|  | 53 | +    WeightType.Q4_K, | 
|  | 54 | +    WeightType.Q5_K, | 
|  | 55 | +    WeightType.Q6_K, | 
|  | 56 | +} | 
|  | 57 | +IMATRIX_QUANT_TYPES = { | 
|  | 58 | +    WeightType.IQ1_M, | 
|  | 59 | +    WeightType.IQ1_S, | 
|  | 60 | +    WeightType.IQ2_XXS, | 
|  | 61 | +    WeightType.IQ2_XS, | 
|  | 62 | +    WeightType.IQ2_S, | 
|  | 63 | +    WeightType.IQ3_XXS, | 
|  | 64 | +    WeightType.IQ3_S, | 
|  | 65 | +    WeightType.IQ4_XS, | 
|  | 66 | +    WeightType.IQ4_NL, | 
|  | 67 | +} | 
|  | 68 | +# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization. | 
|  | 69 | +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add | 
|  | 70 | +# MMQ kernel for I-Matrix quantization. | 
|  | 71 | +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES | 
|  | 72 | +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES | 
|  | 73 | +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | 
|  | 74 | + | 
|  | 75 | + | 
|  | 76 | +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, | 
|  | 77 | +                        qweight_type: int) -> torch.Tensor: | 
|  | 78 | +    # there is no need to call any kernel for fp16/bf16 | 
|  | 79 | +    if qweight_type in UNQUANTIZED_TYPES: | 
|  | 80 | +        return x @ qweight.T | 
|  | 81 | +    # enable MMVQ in contiguous batching with batch_size=1 | 
|  | 82 | +    if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES: | 
|  | 83 | +        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) | 
|  | 84 | +    # Use MMQ Kernel if it's available (standard + k-quants) | 
|  | 85 | +    elif qweight_type in MMQ_QUANT_TYPES: | 
|  | 86 | +        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) | 
|  | 87 | +    # If there is no available MMQ kernel, fall back to dequantization | 
|  | 88 | +    elif qweight_type in DEQUANT_TYPES: | 
|  | 89 | +        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] | 
|  | 90 | +        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) | 
|  | 91 | +        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) | 
|  | 92 | +        y = x @ weight.T | 
|  | 93 | +    else: | 
|  | 94 | +        # Raise an error if the quantization type is not supported. | 
|  | 95 | +        # Might be useful if llama.cpp adds a new quantization type. | 
|  | 96 | +        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. | 
|  | 97 | +        qweight_type = WeightType(qweight_type) | 
|  | 98 | +        raise NotImplementedError( | 
|  | 99 | +            f"Unsupported GGUF quantization type: {qweight_type}") | 
|  | 100 | +    return y | 
|  | 101 | + | 
|  | 102 | + | 
| 32 | 103 | # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook | 
| 33 | 104 | def _create_accelerate_new_hook(old_hook): | 
| 34 | 105 |     r""" | 
| @@ -451,11 +522,25 @@ def __init__( | 
| 451 | 522 |     ) -> None: | 
| 452 | 523 |         super().__init__(in_features, out_features, bias, device) | 
| 453 | 524 |         self.compute_dtype = compute_dtype | 
|  | 525 | +        self.device = device | 
| 454 | 526 |  | 
| 455 | 527 |     def forward(self, inputs): | 
|  | 528 | +        if ops is not None and self.weight.is_cuda and inputs.is_cuda: | 
|  | 529 | +            return self.forward_cuda(inputs) | 
|  | 530 | +        return self.forward_native(inputs) | 
|  | 531 | + | 
|  | 532 | +    def forward_native(self, inputs): | 
| 456 | 533 |         weight = dequantize_gguf_tensor(self.weight) | 
| 457 | 534 |         weight = weight.to(self.compute_dtype) | 
| 458 | 535 |         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None | 
| 459 | 536 |  | 
| 460 | 537 |         output = torch.nn.functional.linear(inputs, weight, bias) | 
| 461 | 538 |         return output | 
|  | 539 | + | 
|  | 540 | +    def forward_cuda(self, inputs): | 
|  | 541 | +        quant_type = self.weight.quant_type | 
|  | 542 | +        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) | 
|  | 543 | +        if self.bias is not None: | 
|  | 544 | +            output = output + self.bias.to(self.compute_dtype) | 
|  | 545 | +        return output | 
|  | 546 | + | 
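For reference, a minimal sketch of how the new CUDA path gets exercised end to end. It assumes the `kernels` package is installed and a CUDA GPU with compute capability >= 7.0 is available; the GGUF checkpoint URL, prompt, and step count below are illustrative placeholders, not part of this change.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# Illustrative GGUF checkpoint; any supported quantization type (Q*_0, Q*_K, IQ*) works the same way.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_K_S.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# With `kernels` installed and a capable GPU, GGUFLinear.forward dispatches to
# forward_cuda (fused GGUF matmul); otherwise it falls back to forward_native
# (dequantize + torch.nn.functional.linear).
image = pipe("a photo of a cat", num_inference_steps=20).images[0]
image.save("out.png")
```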