@@ -99,6 +99,19 @@ def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
             values = get_lora_layer_values(src_layer_dict)
             layers[dst_key] = any_lora_layer_from_state_dict(values)
 
+    def add_lora_adaLN_layer_if_present(src_key: str, dst_key: str) -> None:
+        if src_key in grouped_state_dict:
+            src_layer_dict = grouped_state_dict.pop(src_key)
+            values = get_lora_layer_values(src_layer_dict)
+
+            for _key in values.keys():
+                # The original SD3 implementation of AdaLayerNormContinuous splits the linear projection output
+                # into (shift, scale), while diffusers splits it into (scale, shift). Swap the two halves of the
+                # weights so that diffusers-format LoRA weights can be used with this implementation.
+                scale, shift = values[_key].chunk(2, dim=0)
+                values[_key] = torch.cat([shift, scale], dim=0)
+
+            layers[dst_key] = any_lora_layer_from_state_dict(values)
+
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
         src_weight_shapes: list[tuple[int, int]],
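
For intuition, here is a minimal, self-contained sketch of the half-swap the new helper applies to each LoRA tensor. The names and shapes below are toy stand-ins, not the PR's actual tensors; only the chunk/cat pattern is taken from the diff.

import torch

# A toy weight whose output dimension packs two halves, stored in diffusers
# order: the first half produces scale, the second half produces shift.
hidden = 3
w_diffusers = torch.arange(2 * hidden * 4, dtype=torch.float32).reshape(2 * hidden, 4)

# Reorder to the original (shift, scale) layout, exactly as the diff does.
scale, shift = w_diffusers.chunk(2, dim=0)     # diffusers order: (scale, shift)
w_original = torch.cat([shift, scale], dim=0)  # original order: (shift, scale)

# The swap is an involution: applying it twice restores the diffusers layout.
first, second = w_original.chunk(2, dim=0)
assert torch.equal(torch.cat([second, first], dim=0), w_diffusers)

Note that the swap only permutes rows along dim 0, so it is shape-preserving and cheap.
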
@@ -240,6 +253,10 @@ def add_qkv_lora_layer_if_present(
 
     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")
+    add_lora_adaLN_layer_if_present(
+        'norm_out.linear',
+        'final_layer.adaLN_modulation.1',
+    )
 
     # Assert that all keys were processed.
     assert len(grouped_state_dict) == 0
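
To see why swapping the weight's rows implements the convention change, a hedged sketch (toy sizes and a hypothetical linear layer, not code from this PR): reordering the output halves of a bias-free linear projection is equivalent to reordering the corresponding rows of its weight matrix.

import torch

torch.manual_seed(0)
hidden, cond = 4, 5
linear = torch.nn.Linear(cond, 2 * hidden, bias=False)
vec = torch.randn(1, cond)

# diffusers convention: the projection output splits into (scale, shift).
scale_d, shift_d = linear(vec).chunk(2, dim=1)

# Swap the weight halves as in the diff, then split in (shift, scale) order.
scale_w, shift_w = linear.weight.data.chunk(2, dim=0)
out = torch.nn.functional.linear(vec, torch.cat([shift_w, scale_w], dim=0))
shift_o, scale_o = out.chunk(2, dim=1)

# Both conventions recover the same scale and shift.
assert torch.allclose(scale_d, scale_o) and torch.allclose(shift_d, shift_o)

Because the swap is a pure row permutation, it can be baked into the stored weights once at load time instead of permuting activations on every forward pass.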