|
12 | 12 | # See the License for the specific language governing permissions and
13 | 13 | # limitations under the License.
14 | 14 |
|
15 | | - |
16 | 15 | import inspect |
| 16 | +import os |
17 | 17 | from contextlib import nullcontext |
18 | 18 |
|
19 | 19 | import gguf |
20 | 20 | import torch |
21 | 21 | import torch.nn as nn |
22 | 22 |
|
23 | | -from ...utils import is_accelerate_available |
| 23 | +from ...utils import is_accelerate_available, is_kernels_available |
24 | 24 |
|
25 | 25 |
|
26 | 26 | if is_accelerate_available(): |
|
29 | 29 | from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
30 | 30 |
|
31 | 31 |
|
| 32 | +can_use_cuda_kernels = ( |
| 33 | + os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"] |
| 34 | + and torch.cuda.is_available() |
| 35 | + and torch.cuda.get_device_capability()[0] >= 7 |
| 36 | +) |
| 37 | +if can_use_cuda_kernels and is_kernels_available(): |
| 38 | + from kernels import get_kernel |
| 39 | + |
| 40 | + ops = get_kernel("Isotr0py/ggml") |
| 41 | +else: |
| 42 | + ops = None |
| 43 | + |
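
The fused path stays off unless explicitly requested. A minimal opt-in sketch, assuming a CUDA GPU with compute capability >= 7.0 and the `kernels` package installed:

```python
import os

# Must be set before diffusers imports this module; "1", "true", or "yes" all work
# (case-insensitive), matching the DIFFUSERS_GGUF_CUDA_KERNELS check above.
os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "1"
```
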
| 44 | +UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16} |
| 45 | +STANDARD_QUANT_TYPES = { |
| 46 | + gguf.GGMLQuantizationType.Q4_0, |
| 47 | + gguf.GGMLQuantizationType.Q4_1, |
| 48 | + gguf.GGMLQuantizationType.Q5_0, |
| 49 | + gguf.GGMLQuantizationType.Q5_1, |
| 50 | + gguf.GGMLQuantizationType.Q8_0, |
| 51 | + gguf.GGMLQuantizationType.Q8_1, |
| 52 | +} |
| 53 | +KQUANT_TYPES = { |
| 54 | + gguf.GGMLQuantizationType.Q2_K, |
| 55 | + gguf.GGMLQuantizationType.Q3_K, |
| 56 | + gguf.GGMLQuantizationType.Q4_K, |
| 57 | + gguf.GGMLQuantizationType.Q5_K, |
| 58 | + gguf.GGMLQuantizationType.Q6_K, |
| 59 | +} |
| 60 | +IMATRIX_QUANT_TYPES = { |
| 61 | + gguf.GGMLQuantizationType.IQ1_M, |
| 62 | + gguf.GGMLQuantizationType.IQ1_S, |
| 63 | + gguf.GGMLQuantizationType.IQ2_XXS, |
| 64 | + gguf.GGMLQuantizationType.IQ2_XS, |
| 65 | + gguf.GGMLQuantizationType.IQ2_S, |
| 66 | + gguf.GGMLQuantizationType.IQ3_XXS, |
| 67 | + gguf.GGMLQuantizationType.IQ3_S, |
| 68 | + gguf.GGMLQuantizationType.IQ4_XS, |
| 69 | + gguf.GGMLQuantizationType.IQ4_NL, |
| 70 | +} |
| 71 | +# TODO(Isotr0py): There is currently no MMQ kernel for I-Matrix quantization.
| 72 | +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES once an MMQ
| 73 | +# kernel for I-Matrix quantization is added.
| 74 | +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES |
| 75 | +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES |
| 76 | +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES |
| 77 | + |
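
These sets differ only in their I-Matrix membership, which is what the TODO above refers to. A quick illustrative check of how a quant type is routed:

```python
import gguf

# Assumes the constants above are in scope (e.g. imported from this module).
# I-Matrix types can be dequantized but have no MMQ kernel yet.
qt = gguf.GGMLQuantizationType.IQ2_XS
print(qt in DEQUANT_TYPES, qt in MMVQ_QUANT_TYPES, qt in MMQ_QUANT_TYPES)  # True True False
```
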
| 78 | + |
| 79 | +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: |
| 80 | + # there is no need to call any kernel for fp16/bf16 |
| 81 | + if qweight_type in UNQUANTIZED_TYPES: |
| 82 | + return x @ qweight.T |
| 83 | + |
| 84 | + # TODO(Isotr0py): GGUF's MMQ and MMVQ kernels are designed for
| 85 | + # contiguous batching and are inefficient with diffusers' batching,
| 86 | + # so they are disabled for now.
| 87 | + |
| 88 | + # elif qweight_type in MMVQ_QUANT_TYPES: |
| 89 | + # y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) |
| 90 | + # elif qweight_type in MMQ_QUANT_TYPES: |
| 91 | + # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) |
| 92 | + |
| 93 | + # If there is no available MMQ kernel, fall back to dequantization
| 94 | + if qweight_type in DEQUANT_TYPES: |
| 95 | + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] |
| 96 | + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) |
| 97 | + weight = ops.ggml_dequantize(qweight, qweight_type, *shape) |
| 98 | + y = x @ weight.to(x.dtype).T |
| 99 | + else: |
| 100 | + # Raise an error if the quantization type is not supported. |
| 101 | + # Might be useful if llama.cpp adds a new quantization type. |
| 102 | + # Wrap in the GGMLQuantizationType IntEnum to make sure it's a valid type.
| 103 | + qweight_type = gguf.GGMLQuantizationType(qweight_type) |
| 104 | + raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") |
| 105 | + return y.as_tensor() |
| 106 | + |
| 107 | + |
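
A worked example of the shape arithmetic in the dequantize fallback above, assuming `gguf.GGML_QUANT_SIZES[Q4_K]` is `(256, 144)`, i.e. 256 logical weights packed into 144 bytes:

```python
import gguf

block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_K]

out_features, in_features = 3072, 4096                 # hypothetical layer shape
packed_cols = in_features // block_size * type_size    # 4096 // 256 * 144 = 2304 bytes per row
restored_cols = packed_cols // type_size * block_size  # the `shape` computed above: back to 4096
assert restored_cols == in_features
```
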
32 | 108 | # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook |
33 | 109 | def _create_accelerate_new_hook(old_hook): |
34 | 110 | r""" |
@@ -451,11 +527,24 @@ def __init__( |
451 | 527 | ) -> None: |
452 | 528 | super().__init__(in_features, out_features, bias, device) |
453 | 529 | self.compute_dtype = compute_dtype |
| 530 | + self.device = device |
| 531 | + |
| 532 | + def forward(self, inputs: torch.Tensor): |
| 533 | + if ops is not None and self.weight.is_cuda and inputs.is_cuda: |
| 534 | + return self.forward_cuda(inputs) |
| 535 | + return self.forward_native(inputs) |
454 | 536 |
|
455 | | - def forward(self, inputs): |
| 537 | + def forward_native(self, inputs: torch.Tensor): |
456 | 538 | weight = dequantize_gguf_tensor(self.weight) |
457 | 539 | weight = weight.to(self.compute_dtype) |
458 | 540 | bias = self.bias.to(self.compute_dtype) if self.bias is not None else None |
459 | 541 |
|
460 | 542 | output = torch.nn.functional.linear(inputs, weight, bias) |
461 | 543 | return output |
| 544 | + |
| 545 | + def forward_cuda(self, inputs: torch.Tensor): |
| 546 | + quant_type = self.weight.quant_type |
| 547 | + output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) |
| 548 | + if self.bias is not None: |
| 549 | + output += self.bias.to(self.compute_dtype) |
| 550 | + return output |
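
For context, a minimal end-to-end sketch of exercising the new `forward_cuda` path. The repo and checkpoint name are illustrative, and the loading call relies on diffusers' existing `GGUFQuantizationConfig` / `from_single_file` flow rather than anything added in this diff:

```python
import os

os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "1"  # opt in before importing diffusers

import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Any GGUF checkpoint whose quant type is in DEQUANT_TYPES should dispatch to
# forward_cuda once the weights and inputs are on a CUDA device.
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_K_S.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
).to("cuda")
```
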