|  | 
| 13 | 13 | # # limitations under the License. | 
| 14 | 14 | 
 | 
| 15 | 15 | 
 | 
|  | 16 | +import inspect | 
| 16 | 17 | from contextlib import nullcontext | 
| 17 | 18 | 
 | 
| 18 | 19 | import gguf | 
|  | 
| 23 | 24 | 
 | 
| 24 | 25 | 
 | 
| 25 | 26 | if is_accelerate_available(): | 
|  | 27 | +    import accelerate | 
| 26 | 28 |     from accelerate import init_empty_weights | 
|  | 29 | +    from accelerate.hooks import add_hook_to_module, remove_hook_from_module | 
|  | 30 | + | 
|  | 31 | + | 
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
    r"""
    Build a fresh accelerate hook of the same class as `old_hook`, carrying over only the
    attributes that the hook class' constructor accepts. Use it only if you know what you
    are doing! Adapted from:
    https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
    """
    # Resolve the concrete hook class by name from accelerate.hooks.
    hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    init_params = inspect.signature(hook_cls.__init__).parameters
    # Keep only the stored attributes that map onto constructor parameters.
    ctor_kwargs = {k: v for k, v in old_hook.__dict__.items() if k in init_params}
    return hook_cls(**ctor_kwargs)
| 27 | 48 | 
 | 
| 28 | 49 | 
 | 
| 29 | 50 | def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]): | 
| @@ -59,6 +80,42 @@ def _should_convert_to_gguf(state_dict, prefix): | 
| 59 | 80 |     return model | 
| 60 | 81 | 
 | 
| 61 | 82 | 
 | 
def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=None):
    r"""
    Recursively replace every `GGUFLinear` submodule of `model` with a plain `nn.Linear`
    holding the dequantized weights.

    Args:
        model: Root `nn.Module` to transform in place.
        modules_to_not_convert: Optional list of submodule names to leave untouched.
            NOTE(review): matching is on the direct child name, not the dotted path —
            confirm this is the intended granularity.

    Returns:
        The same `model` instance, mutated in place.
    """
    # Fix: the original used a mutable default (`modules_to_not_convert=[]`), which is
    # shared across every call (and threaded through the recursion). Use a None sentinel.
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    for name, module in model.named_children():
        if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
            device = module.weight.device
            bias = getattr(module, "bias", None)

            # Build the replacement under init_empty_weights when accelerate is
            # available so no real memory is allocated before weights are assigned.
            ctx = init_empty_weights if is_accelerate_available() else nullcontext
            with ctx():
                new_module = nn.Linear(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    device=device,
                )
            new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
            if bias is not None:
                new_module.bias = bias

            # Create a new hook and attach it in case we use accelerate
            if hasattr(module, "_hf_hook"):
                old_hook = module._hf_hook
                new_hook = _create_accelerate_new_hook(old_hook)

                remove_hook_from_module(module)
                add_hook_to_module(new_module, new_hook)

            new_module.to(device)
            model._modules[name] = new_module

        # Recurse regardless of whether this child was replaced: GGUFLinear layers
        # can appear at any depth of the module tree.
        has_children = list(module.children())
        if has_children:
            _dequantize_gguf_and_restore_linear(module, modules_to_not_convert)

    return model
|  | 117 | + | 
|  | 118 | + | 
| 62 | 119 | # dequantize operations based on torch ports of GGUF dequantize_functions | 
| 63 | 120 | # from City96 | 
| 64 | 121 | # more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py | 
|  | 
0 commit comments