Skip to content

Commit 8abfa55

Browse files
committed
update
1 parent 67f1700 commit 8abfa55

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919

2020

2121
if is_torch_available() and is_gguf_available():
22-
import gguf
2322
import torch
2423

2524
from .utils import (
25+
GGML_QUANT_SIZES,
2626
GGUFParameter,
2727
_quant_shape_from_byte_shape,
2828
_replace_with_gguf_linear,
@@ -33,11 +33,17 @@
3333

3434

3535
class GGUFQuantizer(DiffusersQuantizer):
36+
use_keep_in_fp32_modules = True
37+
3638
def __init__(self, quantization_config, **kwargs):
    """Set up the quantizer from a GGUF quantization config.

    Reads ``compute_dtype``, ``pre_quantized`` and ``modules_to_not_convert``
    from ``quantization_config`` and stores them on the quantizer.

    ``modules_to_not_convert`` is always stored as a fresh ``list``:

    * ``None`` becomes an empty list (no ``[None]`` placeholder to strip later),
    * a list / tuple / set is copied element-wise,
    * any other single value (e.g. one module name) is wrapped in a list.

    The copy matters because the quantizer later extends this list in place;
    without it, the list object held by the caller's config would be mutated
    and would keep growing if the same config is reused.
    """
    super().__init__(quantization_config, **kwargs)

    self.compute_dtype = quantization_config.compute_dtype
    self.pre_quantized = quantization_config.pre_quantized

    modules_to_not_convert = quantization_config.modules_to_not_convert
    if modules_to_not_convert is None:
        self.modules_to_not_convert = []
    elif isinstance(modules_to_not_convert, (list, tuple, set)):
        # Copy so that in-place extension on the quantizer never leaks back
        # into the (possibly shared) quantization config.
        self.modules_to_not_convert = list(modules_to_not_convert)
    else:
        self.modules_to_not_convert = [modules_to_not_convert]
4147

4248
def validate_environment(self, *args, **kwargs):
4349
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
@@ -70,7 +76,7 @@ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
7076
current_param_shape = current_param.shape
7177
quant_type = loaded_param.quant_type
7278

73-
block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
79+
block_size, type_size = GGML_QUANT_SIZES[quant_type]
7480

7581
inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
7682
if inferred_shape != current_param_shape:
@@ -96,7 +102,7 @@ def check_if_quantized_param(
96102
def create_quantized_param(
97103
self,
98104
model: "ModelMixin",
99-
param_value: "torch.Tensor",
105+
param_value: Union["GGUFParameter", "torch.Tensor"],
100106
param_name: str,
101107
target_device: "torch.device",
102108
state_dict: Dict[str, Any],
@@ -119,7 +125,13 @@ def _process_model_before_weight_loading(
119125
**kwargs,
120126
):
121127
state_dict = kwargs.get("state_dict", None)
122-
_replace_with_gguf_linear(model, self.compute_dtype, state_dict)
128+
129+
self.modules_to_not_convert.extend(keep_in_fp32_modules)
130+
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
131+
132+
_replace_with_gguf_linear(
133+
model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
134+
)
123135

124136
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
125137
return model

src/diffusers/quantizers/gguf/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from accelerate import init_empty_weights
2727

2828

29-
def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix=""):
29+
def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]):
3030
def _should_convert_to_gguf(module, state_dict, prefix):
3131
weight_key = prefix + "weight"
3232
return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)
@@ -37,9 +37,13 @@ def _should_convert_to_gguf(module, state_dict, prefix):
3737

3838
for name, module in model.named_children():
3939
module_prefix = prefix + name + "."
40-
_replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix)
40+
_replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)
4141

42-
if isinstance(module, nn.Linear) and _should_convert_to_gguf(module, state_dict, module_prefix):
42+
if (
43+
isinstance(module, nn.Linear)
44+
and _should_convert_to_gguf(module, state_dict, module_prefix)
45+
and name not in modules_to_not_convert
46+
):
4347
ctx = init_empty_weights if is_accelerate_available() else nullcontext
4448
with ctx():
4549
model._modules[name] = GGUFLinear(

src/diffusers/quantizers/quantization_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,12 @@ def to_diff_dict(self) -> Dict[str, Any]:
393393

394394

395395
class GGUFQuantizationConfig(QuantizationConfigMixin):
396-
def __init__(self, compute_dtype=None, quant_storage=None):
396+
def __init__(self, compute_dtype=None, quant_storage=None, modules_to_not_convert=None):
397397
self.quant_method = QuantizationMethod.GGUF
398398
self.compute_dtype = compute_dtype
399399
self.quant_storage = quant_storage
400400
self.pre_quantized = True
401+
self.modules_to_not_convert = modules_to_not_convert
401402

402403
if self.compute_dtype is None:
403404
self.compute_dtype = torch.float32

0 commit comments

Comments
 (0)