From 5b6c9883fcef69ab24c4ce6af8b17644c9b57fd7 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sun, 19 Oct 2025 17:37:09 +0800
Subject: [PATCH 1/4] update mma branch

Signed-off-by: Isotr0py
---
 src/diffusers/quantizers/gguf/utils.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 2fba9986e825..398f5dfa29d0 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -34,10 +34,12 @@
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability()[0] >= 7
 )

+is_int8_tensor_core_available = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel

-    ops = get_kernel("Isotr0py/ggml")
+    ops = get_kernel("Isotr0py/ggml", revision="mma-standard")
 else:
     ops = None
@@ -81,17 +83,12 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T

-    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
-    # contiguous batching and inefficient with diffusers' batching,
-    # so we disabled it now.
-
-    # elif qweight_type in MMVQ_QUANT_TYPES:
-    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # elif qweight_type in MMQ_QUANT_TYPES:
-    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
-
+    # For best performance, we only use MMQ kernels with int8 MMA
+    # implementation for Ampere and newer architectures.
+    if qweight_type in MMQ_QUANT_TYPES and is_int8_tensor_core_available:
+        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
-    if qweight_type in DEQUANT_TYPES:
+    elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)

From 942252ed1aa9b848b1ac873625510bed4d76c894 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sun, 19 Oct 2025 17:38:46 +0800
Subject: [PATCH 2/4] remove mmvq

Signed-off-by: Isotr0py
---
 src/diffusers/quantizers/gguf/utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 398f5dfa29d0..a9b473cc43a8 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -71,10 +71,9 @@
     gguf.GGMLQuantizationType.IQ4_NL,
 }
 # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
-# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# Consolidate DEQUANT_TYPES and MMQ_QUANT_TYPES after we add
 # MMQ kernel for I-Matrix quantization.
 DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
-MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
 MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


From 665aacdec2991235312dbc3d2f0f76b7cd758d2b Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sun, 19 Oct 2025 17:56:52 +0800
Subject: [PATCH 3/4] remove k-quant temporarily

Signed-off-by: Isotr0py
---
 run_gguf.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)
 create mode 100644 run_gguf.py

diff --git a/run_gguf.py b/run_gguf.py
new file mode 100644
index 000000000000..fe10aa9f4c90
--- /dev/null
+++ b/run_gguf.py
@@ -0,0 +1,30 @@
+import torch
+
+from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
+
+ckpt_path = (
+    "/home/mozf/LLM/flux1-dev-Q4_0.gguf"
+)
+transformer = FluxTransformer2DModel.from_single_file(
+    ckpt_path,
+    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    torch_dtype=torch.bfloat16,
+)
+pipe = FluxPipeline.from_pretrained(
+    "/home/mozf/LLM/FLUX.1-dev",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+
+# pipe.enable_model_cpu_offload()
+pipe.transformer.to(memory_format=torch.channels_last)
+pipe.transformer.compile(mode="reduce-overhead", fullgraph=True)
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
+# image.save("flux-gguf.png")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
+image.save("flux-gguf.png")

From 6897c60a15937c8703fec95b88e41f7cd993efa0 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Mon, 20 Oct 2025 00:30:20 +0800
Subject: [PATCH 4/4] ooops

Signed-off-by: Isotr0py
---
 run_gguf.py                            | 30 ------------------------------
 src/diffusers/quantizers/gguf/utils.py |  2 +-
 2 files changed, 1 insertion(+), 31 deletions(-)
 delete mode 100644 run_gguf.py

diff --git a/run_gguf.py b/run_gguf.py
deleted file mode 100644
index fe10aa9f4c90..000000000000
--- a/run_gguf.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import torch
-
-from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
-
-ckpt_path = (
-    "/home/mozf/LLM/flux1-dev-Q4_0.gguf"
-)
-transformer = FluxTransformer2DModel.from_single_file(
-    ckpt_path,
-    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
-    torch_dtype=torch.bfloat16,
-)
-pipe = FluxPipeline.from_pretrained(
-    "/home/mozf/LLM/FLUX.1-dev",
-    transformer=transformer,
-    torch_dtype=torch.bfloat16,
-).to("cuda")
-
-
-# pipe.enable_model_cpu_offload()
-pipe.transformer.to(memory_format=torch.channels_last)
-pipe.transformer.compile(mode="reduce-overhead", fullgraph=True)
-
-prompt = "A cat holding a sign that says hello world"
-image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
-# image.save("flux-gguf.png")
-
-prompt = "A cat holding a sign that says hello world"
-image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
-image.save("flux-gguf.png")
diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index a9b473cc43a8..0e657c66ee9a 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -74,7 +74,7 @@
 # Consolidate DEQUANT_TYPES and MMQ_QUANT_TYPES after we add
 # MMQ kernel for I-Matrix quantization.
 DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
-MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES


 def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
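
Taken together, the four patches leave _fused_mul_mat_gguf with a two-way dispatch: standard GGML quants go through the int8 MMA MMQ kernel on Ampere (compute capability 8.x) or newer GPUs, and every other supported type, including K-quants after PATCH 4/4, takes the dequantize-then-matmul fallback. The snippet below is a minimal sketch of that selection logic only; it does not call the CUDA kernels, and the set literal is an assumed illustrative subset rather than the full STANDARD_QUANT_TYPES definition from utils.py.

import gguf
import torch

# Mirrors the module-level capability flag added in PATCH 1/4.
is_int8_tensor_core_available = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

# Assumed illustrative subset of the "standard" quant types that MMQ_QUANT_TYPES
# covers after PATCH 4/4 (K-quants and I-matrix quants are no longer in this set).
MMQ_QUANT_TYPES = {
    gguf.GGMLQuantizationType.Q4_0,
    gguf.GGMLQuantizationType.Q8_0,
}


def chosen_path(qweight_type: int) -> str:
    # Same ordering as the patched _fused_mul_mat_gguf: MMQ first, dequant fallback second.
    if qweight_type in MMQ_QUANT_TYPES and is_int8_tensor_core_available:
        return "MMQ int8 MMA kernel (ops.ggml_mul_mat_a8)"
    return "dequantize + matmul fallback (ops.ggml_dequantize)"


print(chosen_path(gguf.GGMLQuantizationType.Q4_0))  # MMQ on Ampere+, fallback elsewhere
print(chosen_path(gguf.GGMLQuantizationType.Q4_K))  # always fallback after this series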
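
The smoke-test script added in PATCH 3/4 and deleted again in PATCH 4/4 is still a reasonable way to exercise the kernel path end to end. A hedged variant is sketched below: the checkpoint URL and base-model ID are placeholders to substitute with your own copies, and the kernel route is only taken when the can_use_cuda_kernels guard at the top of utils.py evaluates to True (in recent diffusers releases that is an opt-in controlled by the DIFFUSERS_GGUF_CUDA_KERNELS environment variable; verify against your checkout).

import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# Placeholder locations: point these at your own GGUF checkpoint and FLUX.1-dev weights.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_0.gguf"
base_model = "black-forest-labs/FLUX.1-dev"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Optional, as in the original script: compile the transformer for repeated runs.
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer.compile(mode="reduce-overhead", fullgraph=True)

prompt = "A cat holding a sign that says hello world"
# Run twice so the second call reflects steady-state (post-compilation) latency.
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")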