
Commit d7f09f2

update
1 parent 428e44b commit d7f09f2

File tree: 4 files changed (+26, -40 lines)


src/diffusers/models/model_loading_utils.py

Lines changed: 4 additions & 9 deletions
@@ -182,8 +182,7 @@ def load_model_dict_into_meta(
     hf_quantizer=None,
     keep_in_fp32_modules=None,
 ) -> List[str]:
-    if hf_quantizer is None:
-        device = device or torch.device("cpu")
+    device = device or torch.device("cpu")
     dtype = dtype or torch.float32
     is_quantized = hf_quantizer is not None

@@ -223,7 +222,7 @@ def load_model_dict_into_meta(
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
-                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
             else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
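
Note on the hunk above: check_quantized_param_shape now receives the parameters themselves rather than their .shape tuples, so a quantizer can read metadata off the loaded tensor (for GGUF, its quant_type). A minimal sketch of the new contract, using a hypothetical toy quantizer rather than the real diffusers classes:

import torch

class ToyQuantizer:
    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
        # Shapes are still available via .shape ...
        print(param_name, tuple(current_param.shape), tuple(loaded_param.shape))
        # ... and so is any metadata carried on a tensor subclass,
        # which a bare shape tuple could not provide.
        print(getattr(loaded_param, "quant_type", None))

ToyQuantizer().check_quantized_param_shape(
    "weight", torch.empty(2, 4, device="meta"), torch.empty(2, 2)
)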
@@ -469,12 +468,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
 
         # if the tensor is a torch supported dtype do not use GGUFParameter
         is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
-        weights = torch.from_numpy(tensor.data)
-        parsed_parameters[name] = (
-            GGUFParameter(weights, quant_type=quant_type)
-            if is_gguf_quant
-            else weights.permute(*torch.arange(weights.ndim - 1, -1, -1))
-        )
+        weights = torch.from_numpy(tensor.data.copy())
+        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
 
     if len(reader_keys) > 0:
         logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py

Lines changed: 4 additions & 1 deletion
@@ -204,7 +204,10 @@ def create_quantized_param(
 
         module._parameters[tensor_name] = new_value
 
-    def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
+    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
+        current_param_shape = current_param.shape
+        loaded_param_shape = loaded_param.shape
+
         n = current_param_shape.numel()
         inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
         if loaded_param_shape != inferred_shape:
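
The inferred_shape logic reflects how bitsandbytes stores pre-quantized 4-bit weights, as I read it: the tensor is flattened and packed two 4-bit values per byte, giving a ((n + 1) // 2, 1) uint8 buffer, while biases stay unpacked as (n,). A quick sanity check of that arithmetic (sizes illustrative):

# Two 4-bit values per byte: n elements fit in (n + 1) // 2 bytes.
n = 1024 * 1024                      # elements in a (1024, 1024) weight
packed = ((n + 1) // 2, 1)           # shape of the packed buffer
assert packed == (524288, 1)

# An odd element count still rounds up to a whole byte.
assert ((7 + 1) // 2, 1) == (4, 1)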

src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 18 additions & 9 deletions
@@ -9,18 +9,17 @@
 from ...models.modeling_utils import ModelMixin
 
 from ...utils import (
-    is_accelerate_available,
+    is_gguf_available,
     is_torch_available,
     logging,
 )
 
 
-if is_accelerate_available():
-    pass
-
 if is_torch_available():
     import torch
 
+if is_gguf_available():
+    import gguf
 
 logger = logging.get_logger(__name__)
 
@@ -32,9 +31,20 @@ def __init__(self, quantization_config, **kwargs):
         self.compute_dtype = quantization_config.compute_dtype
         self.pre_quantized = True
 
-    def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
-        if _quant_shape_from_byte_shape(loaded_param_shape) == current_param_shape:
-            return True
+    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
+        loaded_param_shape = loaded_param.shape
+        current_param_shape = current_param.shape
+        quant_type = loaded_param.quant_type
+
+        block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
+
+        inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
+        if inferred_shape != current_param_shape:
+            raise ValueError(
+                f"{param_name} has an expected quantized shape of: {inferred_shape}, but received shape: {loaded_param_shape}"
+            )
+
+        return True
 
     def check_if_quantized_param(
         self,
@@ -44,8 +54,7 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ) -> bool:
-        module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module._parameters.get(tensor_name, None), GGUFParameter):
+        if isinstance(param_value, GGUFParameter):
             return True
 
         return False
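
The new shape check inverts GGUF's byte packing: gguf.GGML_QUANT_SIZES maps each quant type to (block_size, type_size), i.e. elements per block and bytes per block, and _quant_shape_from_byte_shape recovers the logical shape from the stored byte shape. A standalone sketch of that arithmetic, assuming the helper mirrors this logic (the function below is my reconstruction, not the diffusers source):

# For Q4_0, gguf reports block_size=32 (elements per block) and
# type_size=18 (bytes per block: a float16 scale plus 16 bytes of nibbles).
def quant_shape_from_byte_shape(shape, type_size, block_size):
    *rest, last = shape
    return (*rest, last // type_size * block_size)

# A (3072, 3072) Q4_0 weight is stored as 3072 // 32 * 18 = 1728 bytes per row.
assert quant_shape_from_byte_shape((3072, 1728), 18, 32) == (3072, 3072)

The last hunk also simplifies check_if_quantized_param: quantized GGUF weights arrive in the state dict already wrapped as GGUFParameter, so an isinstance check on the incoming param_value replaces probing module._parameters.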

src/diffusers/quantizers/gguf/utils.py

Lines changed: 0 additions & 21 deletions
@@ -17,26 +17,6 @@
 import torch.nn as nn
 
 
-_GGUF_FILE_TYPE_MAPPING = {
-    0: "ALL_F32",
-    1: "MOSTLY_F16",
-    2: "MOSTLY_Q4_0",
-    3: "MOSTLY_Q4_1",
-    4: "MOSTLY_Q4_1_SOME_F16",
-    8: "MOSTLY_Q5_0",
-    9: "MOSTLY_Q5_1",
-    10: "MOSTLY_Q2_K",
-    11: "MOSTLY_Q3_K_S",
-    12: "MOSTLY_Q3_K_M",
-    13: "MOSTLY_Q3_K_L",
-    14: "MOSTLY_Q4_K_S",
-    15: "MOSTLY_Q4_K_M",
-    16: "MOSTLY_Q5_K_S",
-    17: "MOSTLY_Q5_K_M",
-    18: "MOSTLY_Q6_K",
-}
-
-
 def _replace_with_gguf_linear(model, compute_dtype):
     for name, module in model.named_children():
         if isinstance(module, nn.Linear):
@@ -321,7 +301,6 @@ def dequantize_gguf_tensor(tensor, compute_dtype):
 
     block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
 
-    tensor = torch.tensor(tensor)
     tensor = tensor.view(torch.uint8)
     shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
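
Dropping torch.tensor(tensor) works because, after the loader change above, dequantize_gguf_tensor already receives a torch tensor, and Tensor.view(torch.uint8) reinterprets the existing bytes without copying. A small check of that no-copy behaviour:

import torch

t = torch.tensor([1.0, 2.0], dtype=torch.float16)
raw = t.view(torch.uint8)                # same storage, 4 uint8 bytes
assert raw.data_ptr() == t.data_ptr()    # no copy was made
assert raw.shape == (4,)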
