|
13 | 13 | # # limitations under the License. |
14 | 14 |
|
15 | 15 |
|
| 16 | +import inspect |
16 | 17 | from contextlib import nullcontext |
17 | 18 |
|
18 | 19 | import gguf |
|
23 | 24 |
|
24 | 25 |
|
25 | 26 | if is_accelerate_available(): |
| 27 | + import accelerate |
26 | 28 | from accelerate import init_empty_weights |
| 29 | + from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
| 30 | + |
| 31 | + |
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
    r"""
    Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of:
    https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
    some changes
    """
    # Resolve the concrete hook class by name from accelerate.hooks.
    hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    # Keep only the attributes of the old hook that the constructor accepts.
    init_params = inspect.signature(hook_cls.__init__).parameters
    ctor_kwargs = {attr: value for attr, value in old_hook.__dict__.items() if attr in init_params}
    return hook_cls(**ctor_kwargs)
27 | 48 |
|
28 | 49 |
|
29 | 50 | def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]): |
@@ -59,6 +80,42 @@ def _should_convert_to_gguf(state_dict, prefix): |
59 | 80 | return model |
60 | 81 |
|
61 | 82 |
|
def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=None):
    r"""
    Recursively replace every `GGUFLinear` child of `model` with a plain `nn.Linear` holding
    the dequantized weights.

    Args:
        model: module whose children are rewritten in place.
        modules_to_not_convert: child-module names to leave quantized. `None` (the default)
            means convert everything.

    Returns:
        The same `model` instance, mutated in place.
    """
    # Fix: avoid a mutable default argument ([]), which is shared across calls and would
    # leak state if any caller ever mutated it.
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    for name, module in model.named_children():
        if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
            device = module.weight.device
            bias = getattr(module, "bias", None)

            # Build the replacement on the meta device when accelerate is available so we
            # don't allocate throwaway storage before assigning the real tensors below.
            ctx = init_empty_weights if is_accelerate_available() else nullcontext
            with ctx():
                new_module = nn.Linear(
                    module.in_features,
                    module.out_features,
                    bias is not None,
                    device=device,
                )
            new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
            if bias is not None:
                new_module.bias = bias

            # Re-attach any accelerate hook (e.g. offloading) to the replacement module.
            if hasattr(module, "_hf_hook"):
                old_hook = module._hf_hook
                new_hook = _create_accelerate_new_hook(old_hook)

                remove_hook_from_module(module)
                add_hook_to_module(new_module, new_hook)

            new_module.to(device)
            model._modules[name] = new_module

        # Recurse into grandchildren whether or not this child was converted.
        has_children = list(module.children())
        if has_children:
            _dequantize_gguf_and_restore_linear(module, modules_to_not_convert)

    return model
| 117 | + |
| 118 | + |
62 | 119 | # dequantize operations based on torch ports of GGUF dequantize_functions |
63 | 120 | # from City96 |
64 | 121 | # more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py |
|
0 commit comments