
Commit a6cc4bf

simpletrontdip authored and psychedelicious committed
chore: ruff fix
1 parent f2f95d7 commit a6cc4bf


5 files changed: +1091 −1061 lines changed

invokeai/backend/patches/layers/utils.py

Lines changed: 3 additions & 3 deletions
@@ -35,14 +35,14 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseL
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}")


-
 def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
     """Swap shift/scale for given linear layer back and forth"""
     # In SD3 and Flux implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
     # while in diffusers it split into scale, shift. This will flip them around
-    chunk1, chunk2 = weight.chunk(2, dim=0)
+    chunk1, chunk2 = weight.chunk(2, dim=0)
     return torch.cat([chunk2, chunk1], dim=0)

+
 def decomposite_weight_matric_with_rank(
     delta: torch.Tensor,
     rank: int,
@@ -56,7 +56,7 @@ def decomposite_weight_matric_with_rank(
     S_r = S[:rank]
     V_r = V[:, :rank]

-    S_sqrt = torch.sqrt(S_r + epsilon) # regularization
+    S_sqrt = torch.sqrt(S_r + epsilon)  # regularization

     up = torch.matmul(U_r, torch.diag(S_sqrt))
     down = torch.matmul(torch.diag(S_sqrt), V_r.T)
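
The two utilities touched above are easy to sanity-check in isolation. Below is a self-contained sketch, assuming only torch, that mirrors their behaviour under illustrative names (swap_shift_scale and decompose_with_rank are stand-ins, not the InvokeAI functions): the shift/scale swap is its own inverse, and the SVD-based decomposition reproduces a delta whose rank does not exceed the requested rank.

    import torch

    # Hypothetical stand-ins mirroring swap_shift_scale_for_linear_weight and
    # decomposite_weight_matric_with_rank from the diff above.

    def swap_shift_scale(weight: torch.Tensor) -> torch.Tensor:
        # SD3/Flux order the AdaLayerNormContinuous projection as (shift, scale);
        # diffusers orders it as (scale, shift), so flip the two halves.
        chunk1, chunk2 = weight.chunk(2, dim=0)
        return torch.cat([chunk2, chunk1], dim=0)

    def decompose_with_rank(delta: torch.Tensor, rank: int, epsilon: float = 1e-8):
        # Truncated SVD, with sqrt(S) folded into both factors.
        U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
        S_sqrt = torch.sqrt(S[:rank] + epsilon)  # regularization
        up = U[:, :rank] @ torch.diag(S_sqrt)
        down = torch.diag(S_sqrt) @ Vh[:rank, :]
        return up, down

    W = torch.randn(8, 4)
    assert torch.allclose(swap_shift_scale(swap_shift_scale(W)), W)  # swap is an involution

    delta = (torch.randn(16, 4) @ torch.randn(4, 16)).double()  # rank-4 delta
    up, down = decompose_with_rank(delta, rank=4)
    assert up.shape == (16, 4) and down.shape == (4, 16)
    assert torch.allclose(up @ down, delta)  # near-exact when the requested rank matches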

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 24 additions & 19 deletions
@@ -2,10 +2,14 @@

 import torch

-from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, swap_shift_scale_for_linear_weight, decomposite_weight_matric_with_rank
+from invokeai.backend.patches.layers.utils import (
+    any_lora_layer_from_state_dict,
+    decomposite_weight_matric_with_rank,
+    swap_shift_scale_for_linear_weight,
+)
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

@@ -30,46 +34,47 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te

     return all_keys_in_peft_format and all_expected_keys_present

+
 def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
-    '''Approximate given diffusers AdaLN loRA layer in our Flux model'''
+    """Approximate given diffusers AdaLN loRA layer in our Flux model"""

-    if not "lora_up.weight" in state_dict:
+    if "lora_up.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
-
-    if not "lora_down.weight" in state_dict:
+
+    if "lora_down.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
-
-    up = state_dict.pop('lora_up.weight')
-    down = state_dict.pop('lora_down.weight')

-    # layer-patcher upcast things to f32,
+    up = state_dict.pop("lora_up.weight")
+    down = state_dict.pop("lora_down.weight")
+
+    # layer-patcher upcast things to f32,
     # we want to maintain a better precison for this one
     dtype = torch.float32

     device = up.device
     up_shape = up.shape
     down_shape = down.shape
-
+
     # desired low rank
     rank = up_shape[1]

     # up scaling for more precise
     up = up.to(torch.float32)
     down = down.to(torch.float32)

-    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)
+    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)

     # swap to our linear format
     swapped = swap_shift_scale_for_linear_weight(weight)

     _up, _down = decomposite_weight_matric_with_rank(swapped, rank)

-    assert(_up.shape == up_shape)
-    assert(_down.shape == down_shape)
+    assert _up.shape == up_shape
+    assert _down.shape == down_shape

     # down scaling to original dtype, device
-    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
-    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+    state_dict["lora_up.weight"] = _up.to(dtype).to(device=device)
+    state_dict["lora_down.weight"] = _down.to(dtype).to(device=device)

     return LoRALayer.from_state_dict_values(state_dict)

@@ -131,7 +136,7 @@ def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
         src_layer_dict = grouped_state_dict.pop(src_key)
         values = get_lora_layer_values(src_layer_dict)
         layers[dst_key] = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(values)
-
+
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
         src_weight_shapes: list[tuple[int, int]],
@@ -274,8 +279,8 @@ def add_qkv_lora_layer_if_present(
     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")
     add_adaLN_lora_layer_if_present(
-        'norm_out.linear',
-        'final_layer.adaLN_modulation.1',
+        "norm_out.linear",
+        "final_layer.adaLN_modulation.1",
     )

     # Assert that all keys were processed.
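
For orientation, here is a minimal sketch of the adaLN conversion flow exercised above, assuming the invokeai package from this repo is importable; the up/down tensors are synthetic stand-ins for a diffusers norm_out.linear LoRA pair, not real checkpoint weights.

    import torch

    from invokeai.backend.patches.layers.utils import (
        decomposite_weight_matric_with_rank,
        swap_shift_scale_for_linear_weight,
    )

    rank, hidden = 4, 16
    up = torch.randn(2 * hidden, rank)  # stands in for lora_up.weight
    down = torch.randn(rank, hidden)    # stands in for lora_down.weight

    # Compose the full delta, flip (scale, shift) -> (shift, scale), then
    # re-factor at the original rank, mirroring
    # approximate_flux_adaLN_lora_layer_from_diffusers_state_dict above.
    weight = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
    swapped = swap_shift_scale_for_linear_weight(weight)
    new_up, new_down = decomposite_weight_matric_with_rank(swapped, rank)

    assert new_up.shape == up.shape and new_down.shape == down.shape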

tests/backend/patches/layers/test_layer_utils.py

Lines changed: 11 additions & 9 deletions
@@ -1,6 +1,9 @@
 import torch

-from invokeai.backend.patches.layers.utils import decomposite_weight_matric_with_rank, swap_shift_scale_for_linear_weight
+from invokeai.backend.patches.layers.utils import (
+    decomposite_weight_matric_with_rank,
+    swap_shift_scale_for_linear_weight,
+)


 def test_swap_shift_scale_for_linear_weight():
@@ -9,38 +12,37 @@ def test_swap_shift_scale_for_linear_weight():
     expected = torch.Tensor([2, 1])

     swapped = swap_shift_scale_for_linear_weight(original)
-    assert(torch.allclose(expected, swapped))
+    assert torch.allclose(expected, swapped)

-    size= (3, 4)
+    size = (3, 4)
     first = torch.randn(size)
     second = torch.randn(size)

     original = torch.concat([first, second])
     expected = torch.concat([second, first])

     swapped = swap_shift_scale_for_linear_weight(original)
-    assert(torch.allclose(expected, swapped))
+    assert torch.allclose(expected, swapped)

     # call this twice will reconstruct the original
     reconstructed = swap_shift_scale_for_linear_weight(swapped)
-    assert(torch.allclose(reconstructed, original))
+    assert torch.allclose(reconstructed, original)
+

 def test_decomposite_weight_matric_with_rank():
     """Test that decompsition of given matrix into 2 low rank matrices work"""
     input_dim = 1024
     output_dim = 1024
     rank = 8  # Low rank

-
     A = torch.randn(input_dim, rank).double()
     B = torch.randn(rank, output_dim).double()
     W0 = A @ B

     C, D = decomposite_weight_matric_with_rank(W0, rank)
     R = C @ D

-    assert(C.shape == A.shape)
-    assert(D.shape == B.shape)
+    assert C.shape == A.shape
+    assert D.shape == B.shape

     assert torch.allclose(W0, R)
-