@@ -78,17 +78,21 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
-    # enable MMVQ in contiguous batching with batch_size=1
-    if qweight_type in MMVQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # Use MMQ Kernel if it's available (standard + k-quants)
-    elif qweight_type in MMQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
+    # contiguous batching and inefficient with diffusers' batching,
+    # so we disabled it now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
+
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
@@ -539,5 +543,10 @@ def forward_native(self, inputs):

     def forward_cuda(self, inputs):
         quant_type = self.weight.quant_type
-        return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        orig_shape = inputs.shape
+        inputs = inputs.view(-1, orig_shape[-1])
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output = output + self.bias.to(self.compute_dtype)
+        return output.view(*orig_shape[:-1], -1)
0 commit comments