 
 import torch
 
-from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, swap_shift_scale_for_linear_weight, decomposite_weight_matric_with_rank
+from invokeai.backend.patches.layers.utils import (
+    any_lora_layer_from_state_dict,
+    decomposite_weight_matric_with_rank,
+    swap_shift_scale_for_linear_weight,
+)
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 
@@ -30,46 +34,47 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
 
     return all_keys_in_peft_format and all_expected_keys_present
 
+
 def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
-    '''Approximate given diffusers AdaLN loRA layer in our Flux model'''
+    """Approximate given diffusers AdaLN loRA layer in our Flux model"""
 
-    if not "lora_up.weight" in state_dict:
+    if "lora_up.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
-
-    if not "lora_down.weight" in state_dict:
+
+    if "lora_down.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
-
-    up = state_dict.pop('lora_up.weight')
-    down = state_dict.pop('lora_down.weight')
 
-    # layer-patcher upcast things to f32,
+    up = state_dict.pop("lora_up.weight")
+    down = state_dict.pop("lora_down.weight")
+
+    # layer-patcher upcast things to f32,
     # we want to maintain a better precison for this one
     dtype = torch.float32
 
     device = up.device
     up_shape = up.shape
     down_shape = down.shape
-
+
     # desired low rank
     rank = up_shape[1]
 
     # up scaling for more precise
     up = up.to(torch.float32)
     down = down.to(torch.float32)
 
-    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)
+    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)
 
     # swap to our linear format
     swapped = swap_shift_scale_for_linear_weight(weight)
 
     _up, _down = decomposite_weight_matric_with_rank(swapped, rank)
 
-    assert(_up.shape == up_shape)
-    assert(_down.shape == down_shape)
+    assert _up.shape == up_shape
+    assert _down.shape == down_shape
 
     # down scaling to original dtype, device
-    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
-    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+    state_dict["lora_up.weight"] = _up.to(dtype).to(device=device)
+    state_dict["lora_down.weight"] = _down.to(dtype).to(device=device)
 
     return LoRALayer.from_state_dict_values(state_dict)
 
@@ -131,7 +136,7 @@ def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
         src_layer_dict = grouped_state_dict.pop(src_key)
         values = get_lora_layer_values(src_layer_dict)
         layers[dst_key] = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(values)
-
+
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
         src_weight_shapes: list[tuple[int, int]],
@@ -274,8 +279,8 @@ def add_qkv_lora_layer_if_present(
     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")
     add_adaLN_lora_layer_if_present(
-        'norm_out.linear',
-        'final_layer.adaLN_modulation.1',
+        "norm_out.linear",
+        "final_layer.adaLN_modulation.1",
     )
 
     # Assert that all keys were processed.
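Note on the AdaLN approximation added above: the new helper rebuilds the full LoRA delta (`up @ down`), passes it through `swap_shift_scale_for_linear_weight` to match InvokeAI's Flux linear layout, and then re-factors the swapped matrix back into rank-`r` `up`/`down` tensors of the original shapes. The following is a minimal, self-contained sketch of that idea in plain torch, not code from this PR: `swap_halves_rows` and `refactor_with_rank` are hypothetical stand-ins for `swap_shift_scale_for_linear_weight` and `decomposite_weight_matric_with_rank`, whose real implementations live in `invokeai.backend.patches.layers.utils` and may differ (the truncated SVD here is just one way to obtain a rank-limited factorization).

import torch


def swap_halves_rows(weight: torch.Tensor) -> torch.Tensor:
    # AdaLN-style linears emit shift and scale blocks, assumed here to be stacked
    # along dim 0; swapping the two halves converts between the two orderings.
    first, second = weight.chunk(2, dim=0)
    return torch.cat([second, first], dim=0)


def refactor_with_rank(weight: torch.Tensor, rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Truncated SVD gives a best rank-`rank` factorization W ~= up @ down.
    u, s, vh = torch.linalg.svd(weight, full_matrices=False)
    sqrt_s = s[:rank].sqrt()
    up = u[:, :rank] * sqrt_s            # (out_features, rank)
    down = sqrt_s[:, None] * vh[:rank]   # (rank, in_features)
    return up, down


if __name__ == "__main__":
    out_features, in_features, rank = 12, 8, 4
    up = torch.randn(out_features, rank)
    down = torch.randn(rank, in_features)

    weight = up @ down                    # full LoRA delta
    swapped = swap_halves_rows(weight)    # reorder shift/scale rows
    new_up, new_down = refactor_with_rank(swapped, rank)

    # Same shapes as the original factors, mirroring the asserts in the patch.
    assert new_up.shape == up.shape and new_down.shape == down.shape
    # Swapping rows preserves rank, so the rank-limited refactorization is exact
    # up to floating-point error.
    assert torch.allclose(new_up @ new_down, swapped, atol=1e-4)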