Commit 6c4d01d

add gguf kernel support

Signed-off-by: Isotr0py <[email protected]>

1 parent 425a715 · commit 6c4d01d

File tree: 3 files changed, +89 -1 lines changed

src/diffusers/quantizers/gguf/utils.py
Lines changed: 83 additions & 1 deletion

@@ -17,10 +17,11 @@
 from contextlib import nullcontext

 import gguf
+from gguf import GGMLQuantizationType as WeightType
 import torch
 import torch.nn as nn

-from ...utils import is_accelerate_available
+from ...utils import is_accelerate_available, is_kernels_available


 if is_accelerate_available():
@@ -29,6 +30,76 @@
     from accelerate.hooks import add_hook_to_module, remove_hook_from_module


+can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
+if can_use_cuda_kernels and is_kernels_available():
+    from kernels import get_kernel
+    ops = get_kernel("Isotr0py/ggml")
+else:
+    ops = None
+
+
+UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+STANDARD_QUANT_TYPES = {
+    WeightType.Q4_0,
+    WeightType.Q4_1,
+    WeightType.Q5_0,
+    WeightType.Q5_1,
+    WeightType.Q8_0,
+    WeightType.Q8_1,
+}
+KQUANT_TYPES = {
+    WeightType.Q2_K,
+    WeightType.Q3_K,
+    WeightType.Q4_K,
+    WeightType.Q5_K,
+    WeightType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    WeightType.IQ1_M,
+    WeightType.IQ1_S,
+    WeightType.IQ2_XXS,
+    WeightType.IQ2_XS,
+    WeightType.IQ2_S,
+    WeightType.IQ3_XXS,
+    WeightType.IQ3_S,
+    WeightType.IQ4_XS,
+    WeightType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
+                        qweight_type: int) -> torch.Tensor:
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+    # enable MMVQ in contiguous batching with batch_size=1
+    if qweight_type in MMVQ_QUANT_TYPES:
+        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # Use MMQ Kernel if it's available (standard + k-quants)
+    elif qweight_type in MMQ_QUANT_TYPES:
+        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+    # If there is no available MMQ kernel, fallback to dequantize
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        y = x @ weight.T
+    else:
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(
+            f"Unsupported GGUF quantization type: {qweight_type}")
+    return y
+
+
 # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
 def _create_accelerate_new_hook(old_hook):
     r"""
@@ -451,11 +522,22 @@ def __init__(
     ) -> None:
         super().__init__(in_features, out_features, bias, device)
         self.compute_dtype = compute_dtype
+        self.device = device

     def forward(self, inputs):
+        if ops is not None and self.weight.is_cuda and inputs.is_cuda:
+            return self.forward_cuda(inputs)
+        return self.forward_native(inputs)
+
+    def forward_native(self, inputs):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
+
+    def forward_cuda(self, inputs):
+        quant_type = self.weight.quant_type
+        return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+
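
For context, a minimal usage sketch (not part of the diff) of how the new CUDA path gets exercised end to end. It assumes the kernels package is installed and a CUDA GPU with compute capability >= 7.0; the checkpoint path is a placeholder, and the loading API shown (from_single_file with GGUFQuantizationConfig) is the existing diffusers GGUF entry point rather than anything added by this commit.

import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Placeholder path: substitute any GGUF-quantized transformer checkpoint.
ckpt_path = "flux1-dev-Q4_K_S.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
).to("cuda")

# With weights and inputs on CUDA and `ops` resolved via get_kernel("Isotr0py/ggml"),
# each GGUFLinear.forward dispatches to forward_cuda -> _fused_mul_mat_gguf;
# otherwise it falls back to forward_native (dequantize + torch.nn.functional.linear).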

src/diffusers/utils/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -76,6 +76,7 @@
     is_hpu_available,
     is_inflect_available,
     is_invisible_watermark_available,
+    is_kernels_available,
     is_k_diffusion_available,
     is_k_diffusion_version,
     is_librosa_available,

src/diffusers/utils/import_utils.py
Lines changed: 5 additions & 0 deletions

@@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
 _transformers_available, _transformers_version = _is_package_available("transformers")
 _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
+_kernels_available, _kernels_version = _is_package_available("kernels")
 _inflect_available, _inflect_version = _is_package_available("inflect")
 _unidecode_available, _unidecode_version = _is_package_available("unidecode")
 _k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")

@@ -274,6 +275,10 @@ def is_accelerate_available():
     return _accelerate_available


+def is_kernels_available():
+    return _kernels_available
+
+
 def is_k_diffusion_available():
     return _k_diffusion_available
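
As a usage note (not part of the diff), the new helper gates optional use of the Hugging Face kernels package, mirroring the guard added in src/diffusers/quantizers/gguf/utils.py. The sketch below only illustrates that pattern; the fallback branch is illustrative.

from diffusers.utils import is_kernels_available

if is_kernels_available():
    from kernels import get_kernel

    # Same kernel repo the GGUF quantizer resolves.
    ops = get_kernel("Isotr0py/ggml")
else:
    # Callers fall back to the dequantize + torch.nn.functional.linear path.
    ops = None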
