
Commit d3eb54f

committed: update

1 parent edf3e54 · commit d3eb54f

File tree

4 files changed, +92 -1 lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 0 deletions
@@ -351,6 +351,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
 
         if hf_quantizer is not None:
             hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
 
         if torch_dtype is not None and hf_quantizer is None:
             model.to(torch_dtype)
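
This one-line change is what makes the `model.dequantize()` call in the new test below possible: the quantizer has to stay reachable from the loaded model instance. A rough, hedged sketch of the kind of entry point that consumes this attribute; this is not the actual `ModelMixin` code, just the shape such a method needs:

# Hypothetical sketch; names and error message are illustrative only.
class _ModelSketch:
    def dequantize(self):
        hf_quantizer = getattr(self, "hf_quantizer", None)  # set by from_single_file above
        if hf_quantizer is None:
            raise ValueError("The model is not quantized, so there is nothing to dequantize.")
        # Assumed: the base quantizer exposes a public dequantize() that wraps the
        # GGUF-specific _dequantize() hook added in the next file.
        return hf_quantizer.dequantize(self)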

src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 14 additions & 0 deletions
@@ -24,6 +24,7 @@
     from .utils import (
         GGML_QUANT_SIZES,
         GGUFParameter,
+        _dequantize_gguf_and_restore_linear,
         _quant_shape_from_byte_shape,
         _replace_with_gguf_linear,
     )
@@ -143,3 +144,16 @@ def is_serializable(self):
     @property
     def is_trainable(self) -> bool:
         return False
+
+    def _dequantize(self, model):
+        is_model_on_cpu = model.device.type == "cpu"
+        if is_model_on_cpu:
+            logger.info(
+                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+            )
+            model.to(torch.cuda.current_device())
+
+        model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
+        if is_model_on_cpu:
+            model.to("cpu")
+        return model
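
Together with the loader change above, this `_dequantize` hook enables end-to-end usage along these lines. A hedged sketch, not part of this commit: `FluxTransformer2DModel` as the model class and `torch.bfloat16` as the compute dtype are assumptions; the checkpoint URL is the one used in the tests below.

import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

# Load the GGUF checkpoint with quantized linear layers.
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
)

# Restore plain nn.Linear layers with dense weights. If the model currently sits on
# CPU (e.g. after `enable_model_cpu_offload()`), _dequantize moves it to the current
# CUDA device for the conversion and back to CPU afterwards, as logged above.
transformer.dequantize()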

src/diffusers/quantizers/gguf/utils.py

Lines changed: 57 additions & 0 deletions
@@ -13,6 +13,7 @@
 # # limitations under the License.
 
 
+import inspect
 from contextlib import nullcontext
 
 import gguf
@@ -23,7 +24,27 @@
 
 
 if is_accelerate_available():
+    import accelerate
     from accelerate import init_empty_weights
+    from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+
+
+# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
+def _create_accelerate_new_hook(old_hook):
+    r"""
+    Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of:
+    https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
+    some changes
+    """
+    old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
+    old_hook_attr = old_hook.__dict__
+    filtered_old_hook_attr = {}
+    old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
+    for k in old_hook_attr.keys():
+        if k in old_hook_init_signature.parameters:
+            filtered_old_hook_attr[k] = old_hook_attr[k]
+    new_hook = old_hook_cls(**filtered_old_hook_attr)
+    return new_hook
 
 
 def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]):
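
The hook copy in `_create_accelerate_new_hook` works by filtering the old hook's instance attributes down to the ones its `__init__` accepts and re-instantiating the class (the real helper looks the class up on `accelerate.hooks` by name; the toy below just reuses the class directly). A self-contained illustration with a made-up hook class, not accelerate's real hooks:

import inspect


class ToyHook:
    # Stand-in for an accelerate hook: one constructor argument plus runtime-only state.
    def __init__(self, execution_device=None):
        self.execution_device = execution_device
        self.model = None  # attached later by the runtime, not a constructor argument


old_hook = ToyHook(execution_device="cuda:0")
old_hook.model = object()  # runtime-only state that must not be passed to __init__

init_params = inspect.signature(type(old_hook).__init__).parameters
filtered = {k: v for k, v in old_hook.__dict__.items() if k in init_params}
new_hook = type(old_hook)(**filtered)  # keeps execution_device, drops `model`

assert new_hook.execution_device == "cuda:0" and new_hook.model is None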
@@ -59,6 +80,42 @@ def _should_convert_to_gguf(state_dict, prefix):
     return model
 
 
+def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]):
+    for name, module in model.named_children():
+        if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
+            device = module.weight.device
+            bias = getattr(module, "bias", None)
+
+            ctx = init_empty_weights if is_accelerate_available() else nullcontext
+            with ctx():
+                new_module = nn.Linear(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    device=device,
+                )
+            new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
+            if bias is not None:
+                new_module.bias = bias
+
+            # Create a new hook and attach it in case we use accelerate
+            if hasattr(module, "_hf_hook"):
+                old_hook = module._hf_hook
+                new_hook = _create_accelerate_new_hook(old_hook)
+
+                remove_hook_from_module(module)
+                add_hook_to_module(new_module, new_hook)
+
+            new_module.to(device)
+            model._modules[name] = new_module
+
+        has_children = list(module.children())
+        if has_children:
+            _dequantize_gguf_and_restore_linear(module, modules_to_not_convert)
+
+    return model
+
+
 # dequantize operations based on torch ports of GGUF dequantize_functions
 # from City96
 # more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py
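
`_dequantize_gguf_and_restore_linear` walks the module tree with `named_children()`, swaps matching layers in place via `model._modules[name]`, and recurses into container modules. A self-contained toy version of that traversal pattern (no GGUF tensors involved; `FakeQuantLinear` and `restore_linear` are made-up stand-ins, not library code):

import torch.nn as nn


class FakeQuantLinear(nn.Linear):
    """Stand-in for GGUFLinear; pretends its weights need to be restored."""


def restore_linear(model, modules_to_not_convert=()):
    for name, module in model.named_children():
        if isinstance(module, FakeQuantLinear) and name not in modules_to_not_convert:
            new_module = nn.Linear(module.in_features, module.out_features, module.bias is not None)
            # In the real helper, dequantize_gguf_tensor() produces the dense weight here.
            new_module.weight = nn.Parameter(module.weight.data.clone())
            if module.bias is not None:
                new_module.bias = nn.Parameter(module.bias.data.clone())
            model._modules[name] = new_module  # swap in place, same trick as the real helper
        elif list(module.children()):
            restore_linear(module, modules_to_not_convert)  # recurse into containers
    return model


net = nn.Sequential(FakeQuantLinear(8, 8), nn.Sequential(FakeQuantLinear(8, 4), nn.ReLU()))
restore_linear(net)
assert not any(isinstance(m, FakeQuantLinear) for m in net.modules())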

tests/quantization/gguf/test_gguf.py

Lines changed: 20 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from diffusers import (
     FluxPipeline,
@@ -23,7 +24,7 @@
 
 
 if is_gguf_available():
-    from diffusers.quantizers.gguf.utils import GGUFParameter
+    from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
 
 
 @nightly
@@ -112,6 +113,24 @@ def test_dtype_assignment(self):
         # This should work
         model.to("cuda")
 
+    def test_dequantize_model(self):
+        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+        model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+        model.dequantize()
+
+        def _check_for_gguf_linear(model):
+            has_children = list(model.children())
+            if not has_children:
+                return
+
+            for name, module in model.named_children():
+                if isinstance(module, nn.Linear):
+                    assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
+                    assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
+
+            for name, module in model.named_children():
+                _check_for_gguf_linear(module)
+
 
 class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
