Commit 653f63a
Add layer keys check
1 parent 8a9e2f5 commit 653f63a
1 file changed

invokeai/backend/lora.py

Lines changed: 46 additions & 1 deletion
@@ -3,7 +3,7 @@
 
 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
 from safetensors.torch import load_file
@@ -70,6 +70,15 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)
 
+    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+        unknown_keys = set(values.keys()) - all_known_keys
+        if unknown_keys:
+            # TODO: how should this be logged as a warning?
+            print(
+                f"[WARN] Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+            )
+
 
 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
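
For reference, here is a minimal standalone sketch of what check_keys() does, with a throwaway driver at the bottom; the helper name warn and the sample key "dora_scale" are purely illustrative and not from this commit:

from typing import Dict, Set

import torch

# Keys accepted by every layer type; mirrors the set hard-coded in check_keys().
BASE_KEYS = {"alpha", "bias_indices", "bias_values", "bias_size"}

def check_keys(values: Dict[str, torch.Tensor], known_keys: Set[str]) -> None:
    # Flag any key that is neither layer-specific nor in the common base set.
    unknown_keys = set(values.keys()) - (known_keys | BASE_KEYS)
    if unknown_keys:
        print(
            f"[WARN] Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
        )

# One unrecognized key in a LoRA state dict triggers the warning:
values = {"lora_up.weight": torch.zeros(4, 4), "dora_scale": torch.zeros(1)}
check_keys(values, {"lora_up.weight", "lora_down.weight", "lora_mid.weight"})
# -> [WARN] Unexpected keys found in LoRA/LyCORIS layer, ... Keys: {'dora_scale'}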
@@ -89,6 +98,14 @@ def __init__(
         self.mid = values.get("lora_mid.weight", None)
 
         self.rank = self.down.shape[0]
+        self.check_keys(
+            values,
+            {
+                "lora_up.weight",
+                "lora_down.weight",
+                "lora_mid.weight",
+            },
+        )
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
@@ -136,6 +153,17 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
         self.t2 = values.get("hada_t2", None)
 
         self.rank = self.w1_b.shape[0]
+        self.check_keys(
+            values,
+            {
+                "hada_w1_a",
+                "hada_w1_b",
+                "hada_w2_a",
+                "hada_w2_b",
+                "hada_t1",
+                "hada_t2",
+            },
+        )
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
@@ -204,6 +232,21 @@ def __init__(
         else:
             self.rank = None  # unscaled
 
+        # Although lokr_t1 is not used in the algorithm, it is still defined in LoKr weights
+        self.check_keys(
+            values,
+            {
+                "lokr_w1",
+                "lokr_w1_a",
+                "lokr_w1_b",
+                "lokr_w2",
+                "lokr_w2_a",
+                "lokr_w2_b",
+                "lokr_t1",
+                "lokr_t2",
+            },
+        )
+
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
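
LoKr checkpoints store each Kronecker factor either whole (lokr_w1) or as a low-rank pair (lokr_w1_a / lokr_w1_b), which is why get_weight() above falls back when self.w1 is None and why the known set has to cover both layouts. A hedged illustration with placeholder tensors (shapes are arbitrary here, not taken from a real checkpoint):

import torch

t = torch.zeros(1)  # stand-in tensor; real shapes depend on the wrapped module

# Both layouts pass the same check, since known_keys is a superset of either form:
factored = {k: t for k in ("lokr_w1_a", "lokr_w1_b", "lokr_w2_a", "lokr_w2_b")}
full = {k: t for k in ("lokr_w1", "lokr_w2")}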
@@ -275,6 +318,7 @@ def __init__(
         self.bias = values.get("diff_b", None)
 
         self.rank = None  # unscaled
+        self.check_keys(values, {"diff", "diff_b"})
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight
@@ -305,6 +349,7 @@ def __init__(
         self.on_input = values["on_input"]
 
         self.rank = None  # unscaled
+        self.check_keys(values, {"weight", "on_input"})
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
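
The TODO about warning output could presumably be resolved with Python's standard logging module instead of print(); a minimal sketch (the logger name and helper are assumptions, not from this commit, and InvokeAI may have its own logging facility):

import logging

logger = logging.getLogger(__name__)  # assumed; not the logger InvokeAI actually uses

def warn_unknown_keys(unknown_keys: set) -> None:
    # Same message as the print() in the diff, emitted at WARNING level.
    logger.warning(
        "Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: %s",
        unknown_keys,
    )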
