
Commit e46571a

optimize
Signed-off-by: Isotr0py <[email protected]>
1 parent: 66bd237

File tree: 2 files changed, 32 insertions(+), 38 deletions(-)


src/diffusers/quantizers/gguf/utils.py

Lines changed: 31 additions & 37 deletions
@@ -17,7 +17,6 @@
 from contextlib import nullcontext

 import gguf
-from gguf import GGMLQuantizationType as WeightType
 import torch
 import torch.nn as nn

@@ -33,37 +32,37 @@
 can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel
+
     ops = get_kernel("Isotr0py/ggml")
 else:
     ops = None

-
-UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
 STANDARD_QUANT_TYPES = {
-    WeightType.Q4_0,
-    WeightType.Q4_1,
-    WeightType.Q5_0,
-    WeightType.Q5_1,
-    WeightType.Q8_0,
-    WeightType.Q8_1,
+    gguf.GGMLQuantizationType.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1,
+    gguf.GGMLQuantizationType.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1,
+    gguf.GGMLQuantizationType.Q8_0,
+    gguf.GGMLQuantizationType.Q8_1,
 }
 KQUANT_TYPES = {
-    WeightType.Q2_K,
-    WeightType.Q3_K,
-    WeightType.Q4_K,
-    WeightType.Q5_K,
-    WeightType.Q6_K,
+    gguf.GGMLQuantizationType.Q2_K,
+    gguf.GGMLQuantizationType.Q3_K,
+    gguf.GGMLQuantizationType.Q4_K,
+    gguf.GGMLQuantizationType.Q5_K,
+    gguf.GGMLQuantizationType.Q6_K,
 }
 IMATRIX_QUANT_TYPES = {
-    WeightType.IQ1_M,
-    WeightType.IQ1_S,
-    WeightType.IQ2_XXS,
-    WeightType.IQ2_XS,
-    WeightType.IQ2_S,
-    WeightType.IQ3_XXS,
-    WeightType.IQ3_S,
-    WeightType.IQ4_XS,
-    WeightType.IQ4_NL,
+    gguf.GGMLQuantizationType.IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S,
+    gguf.GGMLQuantizationType.IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ2_XS,
+    gguf.GGMLQuantizationType.IQ2_S,
+    gguf.GGMLQuantizationType.IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ3_S,
+    gguf.GGMLQuantizationType.IQ4_XS,
+    gguf.GGMLQuantizationType.IQ4_NL,
 }
 # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
 # Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
@@ -73,8 +72,7 @@
 MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


-def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
-                        qweight_type: int) -> torch.Tensor:
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
@@ -87,8 +85,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
     # elif qweight_type in MMQ_QUANT_TYPES:
     #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
-    # If there is no available MMQ kernel, fallback to dequantize

+    # If there is no available MMQ kernel, fallback to dequantize
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
@@ -98,9 +96,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
         # Raise an error if the quantization type is not supported.
         # Might be useful if llama.cpp adds a new quantization type.
         # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
-        qweight_type = WeightType(qweight_type)
-        raise NotImplementedError(
-            f"Unsupported GGUF quantization type: {qweight_type}")
+        qweight_type = gguf.GGMLQuantizationType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
     return y
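Note on the dequantize fallback above: gguf.GGML_QUANT_SIZES maps each quantization type to a (block_size, type_size) pair, i.e. how many weight elements one block holds and how many bytes that block occupies, and the shape line uses it to recover the logical matrix size from the packed byte tensor. A minimal sketch of that arithmetic (the tensor sizes here are illustrative, not taken from the commit):

import gguf
import torch

# Q4_K is one of the KQUANT_TYPES handled above.
qweight_type = gguf.GGMLQuantizationType.Q4_K
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]  # elements per block, bytes per block

# A packed GGUF weight stores each row as (in_features // block_size) blocks of type_size bytes.
out_features, in_features = 128, 512  # illustrative sizes
qweight = torch.empty(out_features, in_features // block_size * type_size, dtype=torch.uint8)

# Same computation as in _fused_mul_mat_gguf: recover the logical (dequantized) shape.
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
assert shape == (out_features, in_features)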
@@ -528,25 +525,22 @@ def __init__(
         self.compute_dtype = compute_dtype
         self.device = device

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
         if ops is not None and self.weight.is_cuda and inputs.is_cuda:
             return self.forward_cuda(inputs)
         return self.forward_native(inputs)

-    def forward_native(self, inputs):
+    def forward_native(self, inputs: torch.Tensor):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

         output = torch.nn.functional.linear(inputs, weight, bias)
         return output

-    def forward_cuda(self, inputs):
+    def forward_cuda(self, inputs: torch.Tensor):
         quant_type = self.weight.quant_type
-        orig_shape = inputs.shape
-        inputs = inputs.view(-1, orig_shape[-1])
         output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
         if self.bias is not None:
-            output = output + self.bias.to(self.compute_dtype)
-        return output.view(*orig_shape[:-1], -1)
-
+            output += self.bias.to(self.compute_dtype)
+        return output
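The forward_cuda change drops the flatten-to-2D / view-back bookkeeping around _fused_mul_mat_gguf. A small illustration of why the reshape is unnecessary for the matmul-based paths (this is a general PyTorch property, not code from the commit; whether every kernel-backed path accepts batched inputs the same way is an assumption here):

import torch

# torch.matmul and F.linear broadcast over leading (batch, sequence) dimensions,
# so a [batch, seq, in_features] activation can multiply an [out_features, in_features]
# weight directly, without flattening to 2D and reshaping the result back.
x = torch.randn(2, 16, 64)   # illustrative [batch, seq, in_features]
w = torch.randn(128, 64)     # illustrative [out_features, in_features]

y = x @ w.T                  # shape [2, 16, 128]
assert torch.allclose(y, torch.nn.functional.linear(x, w))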

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -76,9 +76,9 @@
     is_hpu_available,
     is_inflect_available,
     is_invisible_watermark_available,
-    is_kernels_available,
     is_k_diffusion_available,
     is_k_diffusion_version,
+    is_kernels_available,
     is_librosa_available,
     is_matplotlib_available,
     is_nltk_available,
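The change in src/diffusers/utils/__init__.py only re-sorts the is_kernels_available re-export into alphabetical order. For context, this is the availability check the GGUF module above relies on before fetching the CUDA kernels from the Hub; a sketch of the same guard pattern used outside the library (assuming the kernels package is installed and the Isotr0py/ggml repo is reachable):

import torch
from diffusers.utils import is_kernels_available

ops = None
# Mirrors the guard in src/diffusers/quantizers/gguf/utils.py: only load the
# Hub-hosted GGML kernels on CUDA devices with compute capability >= 7.0.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7 and is_kernels_available():
    from kernels import get_kernel

    ops = get_kernel("Isotr0py/ggml")

print("GGML CUDA kernels available:", ops is not None)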
