| 17 | 17 | import importlib | 
| 18 | 18 | import inspect | 
| 19 | 19 | import os | 
|  | 20 | +from array import array | 
| 20 | 21 | from collections import OrderedDict | 
| 21 | 22 | from pathlib import Path | 
| 22 | 23 | from typing import List, Optional, Union | 
| 26 | 27 | from huggingface_hub.utils import EntryNotFoundError | 
| 27 | 28 |  | 
| 28 | 29 | from ..utils import ( | 
|  | 30 | +    GGUF_FILE_EXTENSION, | 
| 29 | 31 |     SAFE_WEIGHTS_INDEX_NAME, | 
| 30 | 32 |     SAFETENSORS_FILE_EXTENSION, | 
| 31 | 33 |     WEIGHTS_INDEX_NAME, | 
| 32 | 34 |     _add_variant, | 
| 33 | 35 |     _get_model_file, | 
| 34 | 36 |     deprecate, | 
| 35 | 37 |     is_accelerate_available, | 
|  | 38 | +    is_gguf_available, | 
|  | 39 | +    is_torch_available, | 
| 36 | 40 |     is_torch_version, | 
| 37 | 41 |     logging, | 
| 38 | 42 | ) | 
| @@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ | 
| 139 | 143 |         file_extension = os.path.basename(checkpoint_file).split(".")[-1] | 
| 140 | 144 |         if file_extension == SAFETENSORS_FILE_EXTENSION: | 
| 141 | 145 |             return safetensors.torch.load_file(checkpoint_file, device="cpu") | 
|  | 146 | +        elif file_extension == GGUF_FILE_EXTENSION: | 
|  | 147 | +            return load_gguf_checkpoint(checkpoint_file) | 
| 142 | 148 |         else: | 
| 143 | 149 |             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} | 
| 144 | 150 |             return torch.load( | 
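With this hunk, `load_state_dict` dispatches purely on the file extension, so `.gguf` checkpoints enter through the same path as safetensors and pickle files. A minimal sketch of that dispatch in isolation (the helper name `pick_loader` and the sample filename are hypothetical):

```python
import os

# Extension constants mirroring the ones imported in this file.
SAFETENSORS_FILE_EXTENSION = "safetensors"
GGUF_FILE_EXTENSION = "gguf"

def pick_loader(checkpoint_file: str) -> str:
    # Same rule as load_state_dict: the last extension component decides the loader.
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension == SAFETENSORS_FILE_EXTENSION:
        return "safetensors.torch.load_file"
    elif file_extension == GGUF_FILE_EXTENSION:
        return "load_gguf_checkpoint"
    return "torch.load"

assert pick_loader("flux1-dev-Q4_0.gguf") == "load_gguf_checkpoint"
```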
| @@ -211,13 +217,14 @@ def load_model_dict_into_meta( | 
| 211 | 217 |                     set_module_kwargs["dtype"] = dtype | 
| 212 | 218 |  | 
| 213 | 219 |         # bnb params are flattened. | 
|  | 220 | +        # gguf quants have a different shape based on the type of quantization applied | 
| 214 | 221 |         if empty_state_dict[param_name].shape != param.shape: | 
| 215 | 222 |             if ( | 
| 216 | 223 |                 is_quantized | 
| 217 | 224 |                 and hf_quantizer.pre_quantized | 
| 218 | 225 |                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) | 
| 219 | 226 |             ): | 
| 220 |  | -                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape) | 
|  | 227 | +                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param) | 
| 221 | 228 |             else: | 
| 222 | 229 |                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" | 
| 223 | 230 |                 raise ValueError( | 
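The signature change above passes the tensors themselves instead of their shapes: for GGUF checkpoints the on-disk parameter is a flat, block-packed byte tensor, so the expected shape can only be derived from the quantization type carried on the `GGUFParameter`. A hedged sketch of such a check, assuming Q4_0 packing (32 weights per 18-byte block); the real logic lives on the `hf_quantizer`:

```python
def _sketch_check_quantized_param_shape(param_name, current_param, loaded_param):
    # Hypothetical stand-in for hf_quantizer.check_quantized_param_shape.
    quant_type = getattr(loaded_param, "quant_type", None)
    if quant_type is None:
        # Unquantized tensor: plain shape equality suffices.
        return current_param.shape == loaded_param.shape
    # Q4_0 packs 32 weights into 18 bytes (16 bytes of nibbles + fp16 scale).
    block_size, type_size = 32, 18
    n_blocks = current_param.numel() // block_size
    # The quantized payload is a flat byte tensor of n_blocks * type_size bytes.
    return loaded_param.numel() == n_blocks * type_size
```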
| @@ -396,3 +403,78 @@ def _fetch_index_file_legacy( | 
| 396 | 403 |                 index_file = None | 
| 397 | 404 |  | 
| 398 | 405 |     return index_file | 
|  | 406 | + | 
|  | 407 | + | 
|  | 408 | +def _gguf_parse_value(_value, data_type): | 
|  | 409 | +    if not isinstance(data_type, list): | 
|  | 410 | +        data_type = [data_type] | 
|  | 411 | +    if len(data_type) == 1: | 
|  | 412 | +        data_type = data_type[0] | 
|  | 413 | +        array_data_type = None | 
|  | 414 | +    else: | 
|  | 415 | +        if data_type[0] != 9: | 
|  | 416 | +            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.") | 
|  | 417 | +        data_type, array_data_type = data_type | 
|  | 418 | + | 
|  | 419 | +    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:  # integer types (UINT8 through INT64) | 
|  | 420 | +        _value = int(_value[0]) | 
|  | 421 | +    elif data_type in [6, 12]:  # FLOAT32, FLOAT64 | 
|  | 422 | +        _value = float(_value[0]) | 
|  | 423 | +    elif data_type in [7]:  # BOOL | 
|  | 424 | +        _value = bool(_value[0]) | 
|  | 425 | +    elif data_type in [8]:  # STRING, stored as raw bytes | 
|  | 426 | +        _value = array("B", list(_value)).tobytes().decode() | 
|  | 427 | +    elif data_type in [9]:  # ARRAY: parse elements with the array's element type | 
|  | 428 | +        _value = _gguf_parse_value(_value, array_data_type) | 
|  | 429 | +    return _value | 
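The integer codes above follow the GGUF metadata value-type enum (`gguf.GGUFValueType`): 0-5 and 10-11 are the integer widths, 6 and 12 the two float widths, 7 is `BOOL`, 8 is `STRING` (raw bytes), and 9 is `ARRAY`, which carries its element type as the second entry. A small usage sketch, with `_gguf_parse_value` in scope and plain Python sequences standing in for the reader's numpy views:

```python
print(_gguf_parse_value([4096], 4))         # UINT32  -> 4096
print(_gguf_parse_value([2.5], 6))          # FLOAT32 -> 2.5
print(_gguf_parse_value([1], 7))            # BOOL    -> True
print(_gguf_parse_value(b"llama", [9, 8]))  # ARRAY of STRING bytes -> "llama"
```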
|  | 430 | + | 
|  | 431 | + | 
|  | 432 | +def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): | 
|  | 433 | +    """ | 
|  | 434 | +    Load a GGUF file and return a dictionary of parsed parameters containing the checkpoint's | 
|  | 435 | +    tensors. | 
|  | 436 | + | 
|  | 437 | +    Args: | 
|  | 438 | +        gguf_checkpoint_path (`str`): | 
|  | 439 | +            The path to the GGUF file to load. | 
|  | 440 | +        return_tensors (`bool`, defaults to `False`): | 
|  | 441 | +            Whether to read the tensors from the file and return them. Not doing so is faster and only loads the | 
|  | 442 | +            metadata in memory. | 
|  | 443 | +    """ | 
|  | 444 | + | 
|  | 445 | +    if is_gguf_available() and is_torch_available(): | 
|  | 446 | +        import gguf | 
|  | 447 | +        from gguf import GGUFReader | 
|  | 448 | + | 
|  | 449 | +        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter | 
|  | 450 | +    else: | 
|  | 451 | +        logger.error( | 
|  | 452 | +            "Loading a GGUF checkpoint in PyTorch requires both PyTorch and gguf>=0.10.0 to be installed. Please see " | 
|  | 453 | +            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions." | 
|  | 454 | +        ) | 
|  | 455 | +        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.") | 
|  | 456 | + | 
|  | 457 | +    reader = GGUFReader(gguf_checkpoint_path) | 
|  | 458 | + | 
|  | 459 | +    parsed_parameters = {} | 
|  | 460 | +    for tensor in reader.tensors: | 
|  | 461 | +        name = tensor.name | 
|  | 462 | +        quant_type = tensor.tensor_type | 
|  | 463 | + | 
|  | 464 | +        # if the tensor is stored in a torch-supported dtype, do not wrap it in GGUFParameter | 
|  | 465 | +        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16] | 
|  | 466 | +        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES: | 
|  | 467 | +            _supported_quants_str = "\n".join([str(q) for q in SUPPORTED_GGUF_QUANT_TYPES]) | 
|  | 468 | +            raise ValueError( | 
|  | 469 | +                ( | 
|  | 470 | +                    f"{name} has an unsupported quantization type: {str(quant_type)}." | 
|  | 471 | +                    "\n\nCurrently the following quantization types are supported: \n\n" | 
|  | 472 | +                    f"{_supported_quants_str}" | 
|  | 473 | +                    "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers" | 
|  | 474 | +                ) | 
|  | 475 | +            ) | 
|  | 476 | + | 
|  | 477 | +        weights = torch.from_numpy(tensor.data.copy()) | 
|  | 478 | +        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights | 
|  | 479 | + | 
|  | 480 | +    return parsed_parameters | 
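A hedged usage sketch for the new loader (the file path is hypothetical): tensors stored as F16/F32 come back as plain `torch.Tensor`s, while quantized tensors come back as `GGUFParameter`s carrying their `quant_type`.

```python
state_dict = load_gguf_checkpoint("flux1-dev-Q4_0.gguf")
for name, tensor in list(state_dict.items())[:3]:
    quant = getattr(tensor, "quant_type", "native")  # GGUFParameter exposes quant_type
    print(name, tuple(tensor.shape), quant)
```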