
Commit 66bd237 (1 parent: 6c4d01d)

fix

Signed-off-by: Isotr0py <[email protected]>

File tree: 1 file changed (+17, -8 lines)


src/diffusers/quantizers/gguf/utils.py

Lines changed: 17 additions & 8 deletions
@@ -78,17 +78,21 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
-    # enable MMVQ in contiguous batching with batch_size=1
-    if qweight_type in MMVQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # Use MMQ Kernel if it's available (standard + k-quants)
-    elif qweight_type in MMQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementations are designed for
+    # contiguous batching and are inefficient with diffusers' batching,
+    # so they are disabled for now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
+
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
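For context on the fallback branch above: gguf's GGML_QUANT_SIZES table maps each quantization type to (block_size, type_size), i.e. how many logical weight elements one block encodes and how many bytes that block occupies, which is exactly the arithmetic behind the `shape` line. The sketch below is illustrative only and not part of this commit: the Q8_0 constants come from the gguf package, the per-row block count is a made-up example, and a random tensor stands in for the real `ops.ggml_dequantize` output.

```python
import gguf
import torch

# Q8_0 packs 32 weights into a 34-byte block (2-byte fp16 scale + 32 int8 values).
block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0]

out_features = 4
packed_bytes_per_row = 3 * type_size                           # hypothetical: 3 blocks per row
in_features = packed_bytes_per_row // type_size * block_size   # 3 * 32 = 96 logical columns

# The fallback dequantizes to (out_features, in_features) and then does a plain
# matmul; `weight` below is a stand-in for the ops.ggml_dequantize(...) result.
x = torch.randn(2, in_features)
weight = torch.randn(out_features, in_features)
y = x @ weight.T
print(block_size, type_size, y.shape)  # 32 34 torch.Size([2, 4])
```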
@@ -539,5 +543,10 @@ def forward_native(self, inputs):
 
     def forward_cuda(self, inputs):
         quant_type = self.weight.quant_type
-        return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        orig_shape = inputs.shape
+        inputs = inputs.view(-1, orig_shape[-1])
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output = output + self.bias.to(self.compute_dtype)
+        return output.view(*orig_shape[:-1], -1)
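The forward_cuda change above is a standard flatten-to-2D wrapper: the fused GGUF matmul path works on 2D activations, so all leading dimensions are folded into one before the call and restored on the output, with the bias added in compute dtype. Below is a minimal, self-contained sketch of that reshape pattern; `fused_mul_mat` and `forward_like` are hypothetical stand-ins for illustration, not the functions in diffusers.

```python
import torch

def fused_mul_mat(x2d: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for _fused_mul_mat_gguf: the real path dispatches on the quant type.
    return x2d @ weight.T

def forward_like(inputs: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    orig_shape = inputs.shape               # e.g. (batch, seq, hidden)
    x2d = inputs.view(-1, orig_shape[-1])   # fold leading dims: (batch*seq, hidden)
    out = fused_mul_mat(x2d, weight)        # (batch*seq, out_features)
    if bias is not None:
        out = out + bias
    return out.view(*orig_shape[:-1], -1)   # restore: (batch, seq, out_features)

x = torch.randn(2, 5, 16)
w = torch.randn(8, 16)
b = torch.randn(8)
print(forward_like(x, w, b).shape)  # torch.Size([2, 5, 8])
```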
