24 changes: 10 additions & 14 deletions src/diffusers/quantizers/gguf/utils.py
@@ -34,10 +34,12 @@
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability()[0] >= 7
 )
+is_int8_tensor_core_available = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel
 
-    ops = get_kernel("Isotr0py/ggml")
+    ops = get_kernel("Isotr0py/ggml", revision="mma-standard")
 else:
     ops = None
 
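Note on the first hunk: the new is_int8_tensor_core_available flag gates the MMQ path on int8 tensor cores (compute capability 8.0+, i.e. Ampere and newer), while the existing can_use_cuda_kernels check only requires capability 7.0+, and the kernel fetch is now pinned to the mma-standard revision. Below is a minimal standalone sketch of this import-time gating, assuming torch and the kernels hub client are installed; _kernels_installed is a hypothetical stand-in for diffusers' is_kernels_available(), and the extra opt-in conditions sitting above the lines shown in the hunk are omitted.

import torch


def _kernels_installed() -> bool:
    # Hypothetical stand-in for diffusers' is_kernels_available() helper.
    try:
        import kernels  # noqa: F401
        return True
    except ImportError:
        return False


# CUDA kernels are considered at all only on compute capability 7.0+ GPUs
# (additional opt-in conditions from the full file are omitted here).
can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
# The int8 MMA (tensor core) MMQ kernels additionally need capability 8.0+ (Ampere or newer).
is_int8_tensor_core_available = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

if can_use_cuda_kernels and _kernels_installed():
    from kernels import get_kernel

    # Pinning the revision keeps the fetched kernel build reproducible.
    ops = get_kernel("Isotr0py/ggml", revision="mma-standard")
else:
    ops = None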
@@ -69,29 +71,23 @@
     gguf.GGMLQuantizationType.IQ4_NL,
 }
 # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
-# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# Consolidate DEQUANT_TYPES and MMQ_QUANT_TYPES after we add
 # MMQ kernel for I-Matrix quantization.
 DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
-MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
-MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES
 
 
 def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
 
-    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
-    # contiguous batching and inefficient with diffusers' batching,
-    # so we disabled it now.
-
-    # elif qweight_type in MMVQ_QUANT_TYPES:
-    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # elif qweight_type in MMQ_QUANT_TYPES:
-    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
-
+    # For best performance, we only use MMQ kernels with int8 MMA
+    # implementation for Ampere and newer architectures.
+    if qweight_type in MMQ_QUANT_TYPES and is_int8_tensor_core_available:
+        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
-    if qweight_type in DEQUANT_TYPES:
+    elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
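Taking the second hunk as a whole, the dispatch in _fused_mul_mat_gguf after this change reads roughly as sketched below. This is a simplified reconstruction for review purposes, not the exact file contents: ops, the *_TYPES sets, and is_int8_tensor_core_available are the module-level names defined above, and the tail of the function (the matmul against the dequantized weight and the unsupported-type error) is cut off by the hunk, so those lines are assumptions.

import gguf
import torch


def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
    # fp16/bf16 weights need no custom kernel at all.
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T
    # MMQ path: standard quant types only, and only on GPUs with int8 tensor
    # cores (compute capability 8.0+), per the gate added in this PR.
    if qweight_type in MMQ_QUANT_TYPES and is_int8_tensor_core_available:
        return ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    # Otherwise dequantize the whole weight tensor and fall back to a plain matmul.
    if qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
        return x @ weight.to(x.dtype).T  # assumed tail; the hunk ends before the actual return
    # Assumed: any remaining quantization type is unsupported by both paths.
    raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")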