Commit 3876c4c

feat: add missing adaLN layer in lora conversion
1 parent fa1d214

File tree

1 file changed: +17 -0 lines changed

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 17 additions & 0 deletions
@@ -99,6 +99,19 @@ def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
             values = get_lora_layer_values(src_layer_dict)
             layers[dst_key] = any_lora_layer_from_state_dict(values)
 
+    def add_lora_adaLN_layer_if_present(src_key: str, dst_key: str) -> None:
+        if src_key in grouped_state_dict:
+            src_layer_dict = grouped_state_dict.pop(src_key)
+            values = get_lora_layer_values(src_layer_dict)
+
+            for _key in values.keys():
+                # The original SD3 implementation of AdaLayerNormContinuous splits the linear projection output
+                # into (shift, scale), while diffusers splits it into (scale, shift). Swap the halves of the
+                # linear projection weights so that the diffusers implementation can be used.
+                scale, shift = values[_key].chunk(2, dim=0)
+                values[_key] = torch.cat([shift, scale], dim=0)
+
+            layers[dst_key] = any_lora_layer_from_state_dict(values)
+
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
         src_weight_shapes: list[tuple[int, int]],
@@ -240,6 +253,10 @@ def add_qkv_lora_layer_if_present(
 
     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")
+    add_lora_adaLN_layer_if_present(
+        'norm_out.linear',
+        'final_layer.adaLN_modulation.1',
+    )
 
     # Assert that all keys were processed.
     assert len(grouped_state_dict) == 0
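
For reference, a minimal sketch (not part of the commit; the toy tensor and variable names are illustrative only) of the chunk/cat swap that the new helper applies to each LoRA weight along dim 0:

import torch

weight = torch.arange(8, dtype=torch.float32)   # toy stand-in for one LoRA weight from 'norm_out.linear'
scale, shift = weight.chunk(2, dim=0)           # split the projection output into its two halves
swapped = torch.cat([shift, scale], dim=0)      # reorder the halves to match the target layout

print(swapped)  # tensor([4., 5., 6., 7., 0., 1., 2., 3.])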
