Skip to content

Commit bb2e49a

Browse files
feat: add new layer type for diffusers-ada-ln
1 parent 67cac7d commit bb2e49a

File tree

3 files changed

+27
-10
lines changed

3 files changed

+27
-10
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
3+
from invokeai.backend.patches.layers.lora_layer import LoRALayer
4+
5+
class DiffusersAdaLN_LoRALayer(LoRALayer):
    """LoRA layer converted from a Diffusers AdaLN layer (shift/scale halves swapped).

    The SD3/Flux implementation of AdaLayerNormContinuous splits the linear
    projection output into (shift, scale), while the Diffusers implementation
    splits it into (scale, shift). Swapping the two halves of the patch weight
    lets a Diffusers-format LoRA be applied to the Flux implementation.
    """

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # Compute the regular LoRA patch weight, then reorder its two halves
        # along dim 0 from the Diffusers (scale, shift) layout into the
        # (shift, scale) layout expected by the Flux/SD3 implementation.
        base_weight = super().get_weight(orig_weight)
        halves = base_weight.chunk(2, dim=0)
        return torch.cat((halves[1], halves[0]), dim=0)

invokeai/backend/patches/layers/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
1111
from invokeai.backend.patches.layers.lora_layer import LoRALayer
1212
from invokeai.backend.patches.layers.norm_layer import NormLayer
13+
from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer
1314

1415

1516
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
@@ -33,3 +34,10 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseL
3334
return NormLayer.from_state_dict_values(state_dict)
3435
else:
3536
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
37+
38+
39+
def diffusers_adaLN_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> DiffusersAdaLN_LoRALayer:
    """Build a DiffusersAdaLN_LoRALayer from a single layer's state dict values.

    Args:
        state_dict: State dict values for one layer. Must use the classic LoRA
            key layout (i.e. contain a "lora_up.weight" key).

    Returns:
        The constructed DiffusersAdaLN_LoRALayer.

    Raises:
        ValueError: If the state dict does not use the expected LoRA key layout.
    """
    # Only the classic LoRA key layout is supported for AdaLN layers.
    if "lora_up.weight" not in state_dict:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

    return DiffusersAdaLN_LoRALayer.from_state_dict_values(state_dict)

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
66
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
7-
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
7+
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, diffusers_adaLN_lora_layer_from_state_dict
88
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
99
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
1010

@@ -103,15 +103,8 @@ def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
103103
if src_key in grouped_state_dict:
104104
src_layer_dict = grouped_state_dict.pop(src_key)
105105
values = get_lora_layer_values(src_layer_dict)
106-
107-
for _key in values.keys():
108-
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
109-
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
110-
scale, shift = values[_key].chunk(2, dim=0)
111-
values[_key] = torch.cat([shift, scale], dim=0)
112-
113-
layers[dst_key] = any_lora_layer_from_state_dict(values)
114-
106+
layers[dst_key] = diffusers_adaLN_lora_layer_from_state_dict(values)
107+
115108
def add_qkv_lora_layer_if_present(
116109
src_keys: list[str],
117110
src_weight_shapes: list[tuple[int, int]],

0 commit comments

Comments
 (0)