@@ -34,10 +34,12 @@
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability()[0] >= 7
 )
+is_int8_tensor_core_available = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel
 
-    ops = get_kernel("Isotr0py/ggml")
+    ops = get_kernel("Isotr0py/ggml", revision="mma-standard")
 else:
     ops = None
 
@@ -81,17 +83,12 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in |
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
 
-    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
-    # contiguous batching and inefficient with diffusers' batching,
-    # so we disabled it now.
-
-    # elif qweight_type in MMVQ_QUANT_TYPES:
-    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # elif qweight_type in MMQ_QUANT_TYPES:
-    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
-
+    # For best performance, we only use MMQ kernels with the int8 MMA
+    # implementation on Ampere and newer architectures.
+    if qweight_type in MMQ_QUANT_TYPES and is_int8_tensor_core_available:
+        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
-    if qweight_type in DEQUANT_TYPES:
+    elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
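For quick sanity-checking on a given machine, here is a minimal standalone sketch (not part of this patch) that mirrors the two capability gates touched above: the pre-existing `can_use_cuda_kernels` check requires compute capability >= 7, while the new MMQ int8 MMA path additionally requires >= 8 (Ampere or newer). The helper name and messages are made up for illustration; the real dispatch also depends on the GGUF quantization type and on the `kernels` package being available.

```python
import torch

def describe_gguf_matmul_path() -> str:
    """Hypothetical helper (illustration only): report which capability gate
    the current GPU clears. The actual code path also depends on the GGUF
    quantization type and on can_use_cuda_kernels / is_kernels_available()."""
    if not torch.cuda.is_available():
        return "no CUDA device: custom CUDA kernels are not used"
    major = torch.cuda.get_device_capability()[0]
    if major >= 8:
        # int8 tensor cores available: MMQ kernel (ops.ggml_mul_mat_a8) is eligible
        return "compute capability >= 8: MMQ int8 MMA kernel eligible"
    if major >= 7:
        # CUDA kernels allowed, but no int8 MMA MMQ path: dequantize + dense matmul
        return "compute capability 7.x: dequantize fallback"
    return "compute capability < 7: can_use_cuda_kernels is False"

print(describe_gguf_matmul_path())
```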