Skip to content

Commit c878418

Browse files
feat: refactor conversion module, add test for svd correctness
1 parent 4f00a7b commit c878418

File tree

3 files changed

+102
-43
lines changed

3 files changed

+102
-43
lines changed

invokeai/backend/patches/layers/utils.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
4646
def decomposite_weight_matric_with_rank(
4747
delta: torch.Tensor,
4848
rank: int,
49+
epsilon: float = 1e-8,
4950
) -> Tuple[torch.Tensor, torch.Tensor]:
5051
"""Decompose given matrix with a specified rank."""
5152
U, S, V = torch.svd(delta)
@@ -55,50 +56,9 @@ def decomposite_weight_matric_with_rank(
5556
S_r = S[:rank]
5657
V_r = V[:, :rank]
5758

58-
S_sqrt = torch.sqrt(S_r)
59+
S_sqrt = torch.sqrt(S_r + epsilon) # regularization
5960

6061
up = torch.matmul(U_r, torch.diag(S_sqrt))
6162
down = torch.matmul(torch.diag(S_sqrt), V_r.T)
6263

6364
return up, down
64-
65-
66-
def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
67-
'''Approximate given diffusers AdaLN loRA layer in our Flux model'''
68-
69-
if not "lora_up.weight" in state_dict:
70-
raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
71-
72-
if not "lora_down.weight" in state_dict:
73-
raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
74-
75-
up = state_dict.pop('lora_up.weight')
76-
down = state_dict.pop('lora_down.weight')
77-
78-
dtype = up.dtype
79-
device = up.device
80-
up_shape = up.shape
81-
down_shape = down.shape
82-
83-
# desired low rank
84-
rank = up_shape[1]
85-
86-
# up scaling for more precise
87-
up.double()
88-
down.double()
89-
weight = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
90-
91-
# swap to our linear format
92-
swapped = swap_shift_scale_for_linear_weight(weight)
93-
94-
_up, _down = decomposite_weight_matric_with_rank(swapped, rank)
95-
96-
assert(_up.shape == up_shape)
97-
assert(_down.shape == down_shape)
98-
99-
# down scaling to original dtype, device
100-
state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
101-
state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
102-
103-
return LoRALayer.from_state_dict_values(state_dict)
104-

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import torch
44

5+
from invokeai.backend.patches.layers.lora_layer import LoRALayer
56
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
67
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
7-
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, approximate_flux_adaLN_lora_layer_from_diffusers_state_dict
8+
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, swap_shift_scale_for_linear_weight, decomposite_weight_matric_with_rank
89
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
910
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
1011

@@ -38,6 +39,49 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
3839

3940
return all_keys_in_peft_format and (transformer_keys_present or base_model_keys_present)
4041

42+
def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
    """Approximate a diffusers AdaLN LoRA layer in our Flux model format.

    The merged low-rank delta (up @ down) is converted with
    ``swap_shift_scale_for_linear_weight`` and then re-factorized back into an
    up/down pair of the same rank via ``decomposite_weight_matric_with_rank``.

    Args:
        state_dict: Vanilla LoRA state dict containing 'lora_up.weight' and
            'lora_down.weight'. NOTE: the dict is mutated in place — both keys
            are popped and replaced with the re-factorized tensors.

    Returns:
        A ``LoRALayer`` built from the modified state dict.

    Raises:
        ValueError: If either expected key is missing.
    """
    if "lora_up.weight" not in state_dict:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")

    if "lora_down.weight" not in state_dict:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")

    up = state_dict.pop('lora_up.weight')
    down = state_dict.pop('lora_down.weight')

    # The layer patcher upcasts tensors to float32, so keep the approximation
    # in float32 as well to maintain precision.
    dtype = torch.float32

    device = up.device
    up_shape = up.shape
    down_shape = down.shape

    # Desired low rank, inferred from the up matrix's column count.
    rank = up_shape[1]

    # Upcast before the matmul / SVD for better numerical precision.
    up = up.to(torch.float32)
    down = down.to(torch.float32)

    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)

    # Swap to our linear (scale/shift) format.
    swapped = swap_shift_scale_for_linear_weight(weight)

    _up, _down = decomposite_weight_matric_with_rank(swapped, rank)

    # Sanity check: re-factorization must preserve the original shapes.
    assert _up.shape == up_shape
    assert _down.shape == down_shape

    # Cast back to the working dtype and original device.
    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)

    return LoRALayer.from_state_dict_values(state_dict)
84+
4185

4286
def lora_model_from_flux_diffusers_state_dict(
4387
state_dict: Dict[str, torch.Tensor], alpha: float | None

tests/backend/patches/lora_conversions/test_flux_diffusers_lora_conversion_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import pytest
22
import torch
33

4+
5+
from invokeai.backend.patches.layers.utils import swap_shift_scale_for_linear_weight
46
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
57
is_state_dict_likely_in_flux_diffusers_format,
68
lora_model_from_flux_diffusers_state_dict,
9+
approximate_flux_adaLN_lora_layer_from_diffusers_state_dict,
710
)
811
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
912
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
@@ -97,3 +100,55 @@ def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
97100
# Check that an error is raised.
98101
with pytest.raises(AssertionError):
99102
lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
103+
104+
105+
@pytest.mark.parametrize(
    "layer_sd_keys",
    [
        {},  # no keys at all
        {'lora_A.weight': [1024, 8], 'lora_B.weight': [8, 512]},  # wrong key names
        {'lora_up.weight': [1024, 8]},  # down weight missing
        {'lora_down.weight': [8, 512]},  # up weight missing
    ],
)
def test_approximate_adaLN_from_state_dict_should_only_accept_vanilla_LoRA_format(layer_sd_keys: dict[str, list[int]]):
    """A state dict without the vanilla lora_up/lora_down keys must be rejected."""
    mock_sd = keys_to_mock_state_dict(layer_sd_keys)

    with pytest.raises(ValueError):
        approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(mock_sd)
117+
118+
119+
@pytest.mark.parametrize("dtype, rtol", [
    (torch.float32, 1e-4),
    (torch.half, 1e-3),
])
def test_approximate_adaLN_from_state_dict_should_work(dtype: torch.dtype, rtol: float, rate: float = 0.99):
    """Test that we approximate a good-enough adaLN layer from a diffusers state dict.

    The tolerance is chosen per input dtype: half precision inputs get a
    looser ``rtol``. At least ``rate`` of the reconstructed entries must be
    close to the swapped reference weight.
    """
    input_dim = 1024
    output_dim = 512
    rank = 8  # Low rank
    total = input_dim * output_dim

    up = torch.randn(input_dim, rank, dtype=dtype)
    down = torch.randn(rank, output_dim, dtype=dtype)

    layer_state_dict = {
        'lora_up.weight': up,
        'lora_down.weight': down,
    }

    # The layer patcher casts tensors to float32, so build the float32
    # reference weight the same way.
    original = up.float() @ down.float()
    swapped = swap_shift_scale_for_linear_weight(original)

    layer = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(layer_state_dict)
    weight = layer.get_weight(original).float()

    close_count = torch.isclose(weight, swapped, rtol=rtol).sum().item()
    close_rate = close_count / total

    assert close_rate > rate
152+
153+
154+

0 commit comments

Comments
 (0)