Commit afd5d7d

update
1 parent 8f604b3 commit afd5d7d

7 files changed: +76 −22 lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -338,8 +338,8 @@
         "StableDiffusion3ControlNetPipeline",
         "StableDiffusion3Img2ImgPipeline",
         "StableDiffusion3InpaintPipeline",
-        "StableDiffusion3PAGPipeline",
         "StableDiffusion3PAGImg2ImgPipeline",
+        "StableDiffusion3PAGPipeline",
         "StableDiffusion3Pipeline",
         "StableDiffusionAdapterPipeline",
         "StableDiffusionAttendAndExcitePipeline",

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
        if hf_quantizer is not None:
            hf_quantizer.postprocess_model(model)

-        if torch_dtype is not None:
+        if torch_dtype is not None and hf_quantizer is None:
            model.to(torch_dtype)

        model.eval()
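
The `hf_quantizer is None` guard matters because, after `postprocess_model`, the model holds packed quantized tensors, and a blanket `model.to(torch_dtype)` would cast them. A minimal usage sketch of the path this fix protects, assuming the `GGUFQuantizationConfig` class that accompanies this quantizer and a placeholder checkpoint path:

    import torch
    from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

    # Hypothetical checkpoint path; GGUFQuantizationConfig is assumed to be
    # the config class paired with GGUFQuantizer in this PR series.
    transformer = FluxTransformer2DModel.from_single_file(
        "flux1-dev-Q4_K_S.gguf",  # placeholder
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    # With this change, torch_dtype is skipped because hf_quantizer is set,
    # leaving the packed GGUF tensors untouched.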

src/diffusers/models/model_loading_utils.py

Lines changed: 10 additions & 6 deletions
@@ -449,7 +449,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
        import gguf
        from gguf import GGUFReader

-        from ..quantizers.gguf.utils import GGUFParameter
+        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
    else:
        logger.error(
            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
@@ -458,8 +458,6 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")

    reader = GGUFReader(gguf_checkpoint_path)
-    fields = reader.fields
-    reader_keys = list(fields.keys())

    parsed_parameters = {}
    for tensor in tqdm(reader.tensors, desc="Loading GGUF Parameters: "):
@@ -468,10 +466,16 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):

        # if the tensor is a torch supported dtype do not use GGUFParameter
        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
+        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
+            raise ValueError(
+                (
+                    f"{name} has a quantization type: {quant_type} which is unsupported."
+                    f" Currently the following quantization types are supported: {SUPPORTED_GGUF_QUANT_TYPES}."
+                    " To request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
+                )
+            )
+
        weights = torch.from_numpy(tensor.data.copy())
        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

-    if len(reader_keys) > 0:
-        logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
-
    return parsed_parameters
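
Failing fast here beats deferring the error to dequantization time, where it would surface as an opaque lookup failure in the dispatch table. A standalone sketch of the same guard, assuming `gguf>=0.10.0` is installed and importing the supported-type list added below in `quantizers/gguf/utils.py` (the checkpoint path is a placeholder):

    import gguf
    from gguf import GGUFReader

    from diffusers.quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES

    reader = GGUFReader("model-Q4_K_S.gguf")  # placeholder path
    torch_native = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    for tensor in reader.tensors:
        quant_type = tensor.tensor_type
        # F32/F16 load as plain torch tensors and need no dequantizer.
        if quant_type not in torch_native and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
            raise ValueError(f"{tensor.name}: unsupported quantization type {quant_type}")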

src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 30 additions & 7 deletions
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from ..base import DiffusersQuantizer

@@ -12,6 +12,7 @@
     is_accelerate_available,
     is_accelerate_version,
     is_gguf_available,
+    is_gguf_version,
     is_torch_available,
     logging,
 )
@@ -21,7 +22,11 @@
     import gguf
     import torch

-    from .utils import GGUFParameter, _quant_shape_from_byte_shape, _replace_with_gguf_linear
+    from .utils import (
+        GGUFParameter,
+        _quant_shape_from_byte_shape,
+        _replace_with_gguf_linear,
+    )


 logger = logging.get_logger(__name__)
@@ -39,11 +44,26 @@ def validate_environment(self, *args, **kwargs):
            raise ImportError(
                "Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`"
            )
-        if not is_gguf_available():
+        if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
            raise ImportError(
-                "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf`"
+                "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
            )

+    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+        # need more space for buffers that are created during quantization
+        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+        return max_memory
+
+    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        if target_dtype != torch.uint8:
+            logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
+        return torch.uint8
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            torch_dtype = self.compute_dtype
+        return torch_dtype
+
    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
        loaded_param_shape = loaded_param.shape
        current_param_shape = current_param.shape
@@ -62,7 +82,7 @@ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
    def check_if_quantized_param(
        self,
        model: "ModelMixin",
-        param_value: "torch.Tensor",
+        param_value: Union["GGUFParameter", "torch.Tensor"],
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
@@ -82,10 +102,13 @@ def create_quantized_param(
        unexpected_keys: Optional[List[str]] = None,
    ):
        module, tensor_name = get_module_from_name(model, param_name)
-        if tensor_name not in module._parameters:
+        if tensor_name not in module._parameters and tensor_name not in module._buffers:
            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

-        module._parameters[tensor_name] = param_value
+        if tensor_name in module._parameters:
+            module._parameters[tensor_name] = param_value.to(target_device)
+        if tensor_name in module._buffers:
+            module._buffers[tensor_name] = param_value.to(target_device)

    def _process_model_before_weight_loading(
        self,
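
Of the three new hooks, `adjust_max_memory` is the subtle one: it hands accelerate's device-map planner 90% of each declared budget so the temporary buffers created during dequantization don't overflow a device that was filled to capacity. A worked example (device keys and sizes are illustrative only):

    # Budgets as passed via the max_memory kwarg in from_pretrained-style APIs.
    max_memory = {0: 24 * 1024**3, "cpu": 64 * 1024**3}

    # What adjust_max_memory returns: each budget scaled by 0.90.
    adjusted = {key: val * 0.90 for key, val in max_memory.items()}
    # {0: ~21.6 GiB, "cpu": ~57.6 GiB} -- the remaining 10% is headroom
    # for dequantization buffers.

`adjust_target_dtype` plays a similar planning role: it reports `torch.uint8` so memory estimates reflect the packed byte storage rather than the compute dtype.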

src/diffusers/quantizers/gguf/utils.py

Lines changed: 18 additions & 7 deletions
@@ -13,10 +13,18 @@
 # # limitations under the License.


+from contextlib import nullcontext
+
 import gguf
 import torch
 import torch.nn as nn

+from ...utils import is_accelerate_available
+
+
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+

 def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix=""):
     def _should_convert_to_gguf(module, state_dict, prefix):
@@ -32,12 +40,14 @@ def _should_convert_to_gguf(module, state_dict, prefix):
        _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix)

        if isinstance(module, nn.Linear) and _should_convert_to_gguf(module, state_dict, module_prefix):
-            model._modules[name] = GGUFLinear(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                compute_dtype=compute_dtype,
-            )
+            ctx = init_empty_weights if is_accelerate_available() else nullcontext
+            with ctx():
+                model._modules[name] = GGUFLinear(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    compute_dtype=compute_dtype,
+                )
            model._modules[name].source_cls = type(module)
            # Force requires grad to False to avoid unexpected errors
            model._modules[name].requires_grad_(False)
@@ -296,6 +306,7 @@ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
 }
+SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys())


 def _quant_shape_from_byte_shape(shape, type_size, block_size):
@@ -323,7 +334,7 @@ def dequantize_gguf_tensor(tensor):
    return dequant.as_tensor()


-class GGUFParameter(torch.Tensor):
+class GGUFParameter(torch.nn.Parameter):
    def __new__(cls, data, requires_grad=False, quant_type=None):
        data = data if data is not None else torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
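
Two details worth noting. `SUPPORTED_GGUF_QUANT_TYPES` is derived from the keys of `dequantize_functions`, so the advertised list can never drift from what the dequantizers actually handle. And wrapping `GGUFLinear` construction in `init_empty_weights` allocates the placeholder module on the meta device, so no real memory is spent on weights that are immediately replaced by checkpoint tensors. A self-contained sketch of that second pattern, using a plain `nn.Linear` stand-in:

    from contextlib import nullcontext

    import torch.nn as nn

    try:
        from accelerate import init_empty_weights
        ctx = init_empty_weights
    except ImportError:
        ctx = nullcontext  # fall back to normal allocation

    with ctx():
        layer = nn.Linear(4096, 4096, bias=False)

    print(layer.weight.device)  # "meta" when accelerate is installed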

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@
     is_flax_available,
     is_ftfy_available,
     is_gguf_available,
+    is_gguf_version,
     is_google_colab,
     is_inflect_available,
     is_invisible_watermark_available,

src/diffusers/utils/import_utils.py

Lines changed: 15 additions & 0 deletions
@@ -777,6 +777,21 @@ def is_bitsandbytes_version(operation: str, version: str):
     return compare_versions(parse(_bitsandbytes_version), operation, version)


+def is_gguf_version(operation: str, version: str):
+    """
+    Compares the current gguf version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _is_gguf_available:
+        return False
+    return compare_versions(parse(_gguf_version), operation, version)
+
+
 def is_k_diffusion_version(operation: str, version: str):
     """
     Compares the current k-diffusion version to a given reference with an operation.
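
A short usage sketch of the new helper, mirroring the guard added in `GGUFQuantizer.validate_environment`. Note that `is_gguf_version` returns `False` when `gguf` is absent, so it is combined with `is_gguf_available()` to distinguish "not installed" from "too old":

    from diffusers.utils import is_gguf_available, is_gguf_version

    if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
        raise ImportError(
            "To load GGUF format files you must have `gguf` installed in your "
            "environment: `pip install gguf>=0.10.0`"
        )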
