@@ -3,7 +3,7 @@

 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union

 import torch
 from safetensors.torch import load_file
@@ -70,6 +70,15 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)

+    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+        unknown_keys = set(values.keys()) - all_known_keys
+        if unknown_keys:
+            # TODO: how to warn log?
+            print(
+                f"[WARN] Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+            )
+

 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
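The check_keys helper added to LoRALayerBase above is a plain allow-list check: each subclass passes the set of state-dict keys it actually consumes, the base class widens that set with the keys it handles itself (alpha plus the bias_indices/bias_values/bias_size triplet), and anything left over is reported so a silently ignored weight does not go unnoticed. Below is a minimal standalone sketch of the same pattern; the example state dict and its dora_scale entry are hypothetical, purely to show a key falling through the allow-list:

import torch
from typing import Dict, Set

# Keys the base layer itself understands (per the diff above).
BASE_KEYS = {"alpha", "bias_indices", "bias_values", "bias_size"}

def check_keys(values: Dict[str, torch.Tensor], known_keys: Set[str]) -> None:
    # Anything not claimed by the layer or the base class is suspicious.
    unknown_keys = set(values.keys()) - (known_keys | BASE_KEYS)
    if unknown_keys:
        print(f"[WARN] Unexpected keys: {unknown_keys}")

# Hypothetical LoRA state dict containing one key the loader never reads:
values = {
    "lora_up.weight": torch.zeros(8, 4),
    "lora_down.weight": torch.zeros(4, 8),
    "dora_scale": torch.zeros(8),  # not in the allow-list, so it gets flagged
}
check_keys(values, {"lora_up.weight", "lora_down.weight", "lora_mid.weight"})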
@@ -89,6 +98,14 @@ def __init__(
         self.mid = values.get("lora_mid.weight", None)

         self.rank = self.down.shape[0]
+        self.check_keys(
+            values,
+            {
+                "lora_up.weight",
+                "lora_down.weight",
+                "lora_mid.weight",
+            },
+        )

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
@@ -136,6 +153,17 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
         self.t2 = values.get("hada_t2", None)

         self.rank = self.w1_b.shape[0]
+        self.check_keys(
+            values,
+            {
+                "hada_w1_a",
+                "hada_w1_b",
+                "hada_w2_a",
+                "hada_w2_b",
+                "hada_t1",
+                "hada_t2",
+            },
+        )

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
@@ -204,6 +232,21 @@ def __init__(
         else:
             self.rank = None  # unscaled

+        # Although lokr_t1 is not used in the algorithm, it is still defined in LoKR weights.
+        self.check_keys(
+            values,
+            {
+                "lokr_w1",
+                "lokr_w1_a",
+                "lokr_w1_b",
+                "lokr_w2",
+                "lokr_w2_a",
+                "lokr_w2_b",
+                "lokr_t1",
+                "lokr_t2",
+            },
+        )
+
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
@@ -275,6 +318,7 @@ def __init__(
         self.bias = values.get("diff_b", None)

         self.rank = None  # unscaled
+        self.check_keys(values, {"diff", "diff_b"})

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight
@@ -305,6 +349,7 @@ def __init__(
         self.on_input = values["on_input"]

         self.rank = None  # unscaled
+        self.check_keys(values, {"weight", "on_input"})

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
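The "TODO: how to warn log?" comment in check_keys hints that print is a placeholder. One conventional alternative, sketched here as an assumption rather than what this commit ships, is a module-level logger from Python's standard logging package, so the warning honors whatever handlers and levels the host application configures:

import logging

logger = logging.getLogger(__name__)

def warn_unknown_keys(unknown_keys: set) -> None:
    # Unlike print, logging.warning routes through configured handlers
    # and can be filtered or redirected by the application.
    logger.warning(
        "Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: %s",
        unknown_keys,
    )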