```diff
@@ -23,6 +23,7 @@
     deprecate,
     get_submodule_by_name,
     is_bitsandbytes_available,
+    is_gguf_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -49,9 +50,6 @@
 )
 
 
-if is_bitsandbytes_available():
-    from ..quantizers.bitsandbytes import dequantize_bnb_weight
-
 _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
 if is_torch_version(">=", "1.9.0"):
     if (
@@ -72,6 +70,49 @@
 _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
 
 
+def _dequantize_weight_for_expanded_lora(model, module):
+    if is_bitsandbytes_available():
+        from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+    if is_gguf_available():
+        from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
+    if is_gguf_quantized and not is_gguf_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+        )
+
+    weight_on_cpu = False
+    if not module.weight.is_cuda:
+        weight_on_cpu = True
+
+    if is_bnb_4bit_quantized:
+        module_weight = dequantize_bnb_weight(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+            state=module.weight.quant_state,
+            dtype=model.dtype,
+        ).data
+    elif is_gguf_quantized:
+        module_weight = dequantize_gguf_tensor(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+        )
+        module_weight = module_weight.to(model.dtype)
+    else:
+        module_weight = module.weight.data
+
+    if weight_on_cpu:
+        module_weight = module_weight.cpu()
+
+    return module_weight
+
+
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -1970,26 +2011,10 @@ def _maybe_expand_transformer_param_shape_or_error_(
         overwritten_params = {}
 
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        is_quantized = hasattr(transformer, "hf_quantizer")
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
-                is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
-                if is_bnb_4bit_quantized and not is_bitsandbytes_available():
-                    raise ValueError(
-                        "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
-                    )
-                elif is_bnb_4bit_quantized:
-                    weight_on_cpu = False
-                    if not module.weight.is_cuda:
-                        weight_on_cpu = True
-                    module_weight = dequantize_bnb_weight(
-                        module.weight.cuda() if weight_on_cpu else module.weight,
-                        state=module.weight.quant_state,
-                        dtype=transformer.dtype,
-                    ).data
-                    if weight_on_cpu:
-                        module_weight = module_weight.cpu()
-                else:
-                    module_weight = module.weight.data
+                module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None
 
@@ -2034,6 +2059,9 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     parent_module_name, _, current_module_name = name.rpartition(".")
                     parent_module = transformer.get_submodule(parent_module_name)
 
+                    if is_quantized:
+                        module_weight = _dequantize_weight_for_expanded_lora(transformer, module)
+
                     with torch.device("meta"):
                         expanded_module = torch.nn.Linear(
                             in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2134,7 +2162,12 @@ def _calculate_module_shape(
         base_weight_param_name: str = None,
     ) -> "torch.Size":
         def _get_weight_shape(weight: torch.Tensor):
-            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+            if weight.__class__.__name__ == "Params4bit":
+                return weight.quant_state.shape
+            elif weight.__class__.__name__ == "GGUFParameter":
+                return weight.quant_shape
+            else:
+                return weight.shape
 
         if base_module is not None:
             return _get_weight_shape(base_module.weight)
```
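Worth noting why `_get_weight_shape` needs the per-backend dispatch in the last hunk: quantized parameters store packed bytes, so the tensor's own `.shape` describes the packed buffer rather than the logical `(out_features, in_features)` matrix, and the logical shape has to come from a side attribute (`quant_state.shape` for bitsandbytes 4-bit, `quant_shape` for GGUF). A toy sketch with an illustrative stand-in class, not the real gguf type:

```python
import torch

class GGUFParameter:
    """Stand-in for a GGUF-quantized parameter: packed uint8 storage plus a
    `quant_shape` attribute carrying the logical weight-matrix shape."""
    def __init__(self, packed: torch.Tensor, quant_shape: torch.Size):
        self.data = packed
        self.shape = packed.shape        # shape of the packed buffer -- misleading
        self.quant_shape = quant_shape   # logical (out_features, in_features)

def _get_weight_shape(weight):
    # Same dispatch as in `_calculate_module_shape` above.
    if weight.__class__.__name__ == "Params4bit":
        return weight.quant_state.shape
    elif weight.__class__.__name__ == "GGUFParameter":
        return weight.quant_shape
    return weight.shape

w = GGUFParameter(torch.empty(160, dtype=torch.uint8), torch.Size([32, 16]))
print(_get_weight_shape(w))  # torch.Size([32, 16]) -- the logical matrix shape
print(w.shape)               # torch.Size([160])    -- just the packed buffer
```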
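For context, a rough end-to-end sketch of what this change enables: loading a GGUF-quantized Flux transformer and then applying a Control LoRA, whose wider `x_embedder` input projection (64 -> 128 input channels) is exactly what triggers the shape-expansion path patched above. The repo IDs and checkpoint filename below are assumptions for illustration; substitute whatever GGUF export you actually use:

```python
import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# Assumed community GGUF export of FLUX.1-dev; any quantized variant should do.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_0.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Expanding x_embedder on GGUF weights previously had no dequantization path;
# with _dequantize_weight_for_expanded_lora the weight is dequantized, the
# module expanded, and the LoRA loaded on top.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
```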