@@ -78,17 +78,21 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
-    # enable MMVQ in contiguous batching with batch_size=1
-    if qweight_type in MMVQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # Use MMQ Kernel if it's available (standard + k-quants)
-    elif qweight_type in MMQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ kernels are designed for
+    # contiguous batching and are inefficient with diffusers' batching,
+    # so they are disabled for now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
+
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
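A quick worked example of the shape math in the dequantize fallback: the packed GGUF tensor stores bytes per row, so the logical column count is recovered from gguf.GGML_QUANT_SIZES. The Q4_K block size / byte size and the dimensions below are illustrative assumptions, not values from this change.

# Sketch of the fallback's shape computation, assuming a Q4_K weight where
# gguf.GGML_QUANT_SIZES[Q4_K] == (256, 144): 256 weights per block,
# 144 packed bytes per block.
block_size, type_size = 256, 144

out_features = 4096   # qweight.shape[0]
packed_cols = 2304    # qweight.shape[1]: packed bytes per output row (16 blocks * 144)
in_features = packed_cols // type_size * block_size
print(in_features)    # 4096 -> ops.ggml_dequantize returns a dense (4096, 4096) weight,
                      # after which the layer falls back to a plain matmul: y = x @ weight.T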
@@ -539,5 +543,10 @@ def forward_native(self, inputs):
 
     def forward_cuda(self, inputs):
         quant_type = self.weight.quant_type
-        return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        orig_shape = inputs.shape
+        inputs = inputs.view(-1, orig_shape[-1])
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output = output + self.bias.to(self.compute_dtype)
+        return output.view(*orig_shape[:-1], -1)
 
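The new forward_cuda flattens diffusers-style (batch, seq_len, hidden) activations to 2D before the GGUF matmul, adds the bias in the compute dtype, and restores the leading dimensions. A minimal sketch of that flatten-matmul-restore flow, using a plain dense weight as a stand-in for the quantized matmul (forward_like and its parameters are illustrative, not part of the change):

import torch
from typing import Optional

def forward_like(inputs: torch.Tensor, weight: torch.Tensor,
                 bias: Optional[torch.Tensor], compute_dtype=torch.float32):
    # diffusers passes (batch, seq_len, hidden); flatten to 2D for the matmul
    orig_shape = inputs.shape
    x = inputs.view(-1, orig_shape[-1]).to(compute_dtype)
    out = x @ weight.to(compute_dtype).T   # stand-in for _fused_mul_mat_gguf
    if bias is not None:
        out = out + bias.to(compute_dtype)
    # restore the leading (batch, seq_len) dimensions
    return out.view(*orig_shape[:-1], -1)

x = torch.randn(2, 77, 64)      # (batch, seq_len, hidden)
w = torch.randn(128, 64)        # (out_features, in_features)
b = torch.randn(128)
print(forward_like(x, w, b).shape)  # torch.Size([2, 77, 128])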