diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 2fba9986e825..0e657c66ee9a 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -34,10 +34,12 @@
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability()[0] >= 7
 )
+is_int8_tensor_core_available = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel
 
-    ops = get_kernel("Isotr0py/ggml")
+    ops = get_kernel("Isotr0py/ggml", revision="mma-standard")
 else:
     ops = None
 
@@ -69,11 +71,10 @@
     gguf.GGMLQuantizationType.IQ4_NL,
 }
 # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
-# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# Consolidate DEQUANT_TYPES and MMQ_QUANT_TYPES after we add
 # MMQ kernel for I-Matrix quantization.
 DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
-MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
-MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES
 
 
 def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
@@ -81,17 +82,12 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
 
-    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
-    # contiguous batching and inefficient with diffusers' batching,
-    # so we disabled it now.
-
-    # elif qweight_type in MMVQ_QUANT_TYPES:
-    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # elif qweight_type in MMQ_QUANT_TYPES:
-    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
-
+    # For best performance, we only use MMQ kernels with int8 MMA
+    # implementation for Ampere and newer architectures.
+    if qweight_type in MMQ_QUANT_TYPES and is_int8_tensor_core_available:
+        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
-    if qweight_type in DEQUANT_TYPES:
+    elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
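
For context, here is a minimal sketch (not part of the patch) of the dispatch order that the patched `_fused_mul_mat_gguf` follows: standard-quant types in `MMQ_QUANT_TYPES` take the fused int8 MMQ kernel only when an int8-tensor-core GPU is available (compute capability >= 8, i.e. Ampere or newer), and everything else falls back to dequantize-then-matmul. The `simulate_dispatch` helper is purely illustrative; it mirrors the branch structure without calling the real `ops.ggml_mul_mat_a8` / `ops.ggml_dequantize` kernels.

```python
# Illustrative sketch only -- mirrors the branch order introduced by the patch.
import torch

# Same capability gate the patch adds as `is_int8_tensor_core_available`:
# int8 MMA (tensor core) kernels need compute capability >= 8 (Ampere+).
is_int8_tensor_core_available = (
    torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
)


def simulate_dispatch(qweight_type, unquantized_types, mmq_types, dequant_types):
    """Hypothetical helper: report which path a GGUF weight type would take."""
    if qweight_type in unquantized_types:
        return "plain matmul (x @ qweight.T)"
    if qweight_type in mmq_types and is_int8_tensor_core_available:
        return "fused MMQ kernel (ops.ggml_mul_mat_a8)"
    if qweight_type in dequant_types:
        return "dequantize, then matmul (ops.ggml_dequantize)"
    return "unsupported"
```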