[LoRA] feat: support loading loras into 4bit quantized Flux models. #10578
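For context, a minimal sketch of the user-facing flow this PR targets: quantize the Flux transformer to 4-bit NF4 with bitsandbytes through diffusers' `BitsAndBytesConfig`, assemble the pipeline around it, then load a LoRA on top. The LoRA repository id below is a placeholder, and the generation arguments are illustrative only.

```python
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

# Load the Flux transformer in 4-bit NF4 via bitsandbytes.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

# With this PR, loading a LoRA into the 4-bit quantized transformer is supported;
# "your-org/your-flux-lora" is a placeholder repo id.
pipe.load_lora_weights("your-org/your-flux-lora")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
```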
```diff
@@ -21,6 +21,7 @@
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
```
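The new `get_submodule_by_name` import is a small utility for resolving a dotted module path back to the module object that owns a parameter. Roughly equivalent behavior, for intuition only (this sketch is not the actual helper in `..utils`):

```python
import torch

def get_submodule_by_name_sketch(root: torch.nn.Module, path: str) -> torch.nn.Module:
    """Walk a dotted path such as 'transformer_blocks.0.attn.to_q' one segment at a time."""
    module = root
    for part in path.split("."):
        # Numeric segments index into containers like nn.ModuleList / nn.Sequential.
        module = module[int(part)] if part.isdigit() else getattr(module, part)
    return module
```

The loader uses it further down to get from a state-dict key (with the trailing `.weight` stripped) to the layer that owns that parameter.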
```diff
@@ -1981,10 +1982,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
             in_features = state_dict[lora_A_weight_name].shape[1]
             out_features = state_dict[lora_B_weight_name].shape[0]

+            # Model maybe loaded with different quantization schemes which may flatten the params.
+            # `bitsandbytes`, for example, flatten the weights when using 4bit.
+            module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
             # This means there's no need for an expansion in the params, so we simply skip.
-            if tuple(module_weight.shape) == (out_features, in_features):
+            if tuple(module_weight_shape) == (out_features, in_features):
                 continue

+            # TODO (sayakpaul): We still need to consider if the module we're expanding is
+            # quantized and handle it accordingly if that is the case.
             module_out_features, module_in_features = module_weight.shape
             debug_message = ""
             if in_features > module_in_features:
```
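The shape check now goes through `cls._calculate_module_shape` instead of `module_weight.shape` because bitsandbytes packs 4-bit weights into a flattened storage buffer, while the pre-quantization shape is kept on `weight.quant_state`. A small sketch of that behavior (assumes bitsandbytes is installed and a CUDA device is available; the packed shape in the comment is an implementation detail, not a contract):

```python
import torch
import bitsandbytes as bnb

# Logical weight of a Linear layer with in_features=64, out_features=3072.
weight = torch.randn(3072, 64)

# Quantizing to 4-bit packs the values into a flattened storage tensor,
# so `.shape` no longer equals (out_features, in_features).
qweight = bnb.nn.Params4bit(weight, requires_grad=False, quant_type="nf4").cuda()

print(weight.shape)               # torch.Size([3072, 64])
print(qweight.shape)              # flattened packed storage, e.g. torch.Size([98304, 1])
print(qweight.quant_state.shape)  # torch.Size([3072, 64]) -- the logical shape read above
```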
```diff
@@ -2080,13 +2087,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+            # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+            base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+            if base_module_shape[1] > lora_A_param.shape[1]:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_module_shape[1] < lora_A_param.shape[1]:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
```
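When the base layer accepts more input features than the LoRA `A` matrix provides, the branch above pads `lora_A` with zero columns up to the base width. A toy check that the padding leaves the LoRA update unchanged on the original input range (shapes are made up for the example):

```python
import torch

rank, lora_in, base_in = 16, 64, 128
lora_A = torch.randn(rank, lora_in)

# Same expansion as above: allocate zeros, then copy the trained columns in.
expanded = torch.zeros(rank, base_in)
expanded[:, :lora_in].copy_(lora_A)

# The padded columns are zero, so the extra input features contribute nothing.
x = torch.randn(2, base_in)
assert torch.allclose(x @ expanded.T, x[:, :lora_in] @ lora_A.T)
```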
```diff
@@ -2098,6 +2108,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

         return lora_state_dict

+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            module_path = (
+                base_weight_param_name.rsplit(".weight", 1)[0]
+                if base_weight_param_name.endswith(".weight")
+                else base_weight_param_name
+            )
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+

 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
```
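A CPU-only way to sanity-check the helper's contract, using a stand-in class named `Params4bit` rather than the real bitsandbytes parameter (the dispatch is purely on `__class__.__name__`):

```python
import torch

class _QuantState:
    def __init__(self, shape):
        self.shape = shape

class Params4bit:  # mock stand-in for bitsandbytes' Params4bit; only the name matters here
    def __init__(self, packed, logical_shape):
        self.shape = packed.shape                      # flattened storage shape
        self.quant_state = _QuantState(logical_shape)  # logical (out_features, in_features)

def _get_weight_shape(weight):
    # Same dispatch as `_calculate_module_shape` above.
    return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape

linear = torch.nn.Linear(64, 128, bias=False)
packed = Params4bit(torch.empty(128 * 64 // 2, 1), torch.Size([128, 64]))

print(_get_weight_shape(linear.weight))  # torch.Size([128, 64])
print(_get_weight_shape(packed))         # torch.Size([128, 64]) despite flattened storage
```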
|  | ||
> 8bit params preserve the original shape?
They do.