Commit 31949ed
Refactor code a bit
1 parent 0ccb304

1 file changed: invokeai/backend/lora.py (27 additions, 57 deletions)
@@ -46,11 +46,18 @@ def __init__(
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
 
-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError()
+    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+        return self.bias
+
+    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        params = {"weight": self.get_weight(orig_module.weight)}
+        bias = self.get_bias(orig_module.bias)
+        if bias is not None:
+            params["bias"] = bias
+        return params
 
     def calc_size(self) -> int:
         model_size = 0
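
This first hunk is the core of the refactor: get_parameters moves into the base class, composed from get_weight plus a new get_bias hook, so each concrete layer type only has to provide get_weight. A minimal sketch of the resulting contract; the ScaleLayer subclass is hypothetical, and the real base class presumably initializes self.bias elsewhere in __init__ (not shown in this diff):

    from typing import Dict, Optional

    import torch

    class LoRALayerBase:
        # Assumption: the real class sets self.bias (possibly None) in code
        # outside this diff; a class-level default stands in for it here.
        bias: Optional[torch.Tensor] = None

        def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError()

        def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
            return self.bias

        def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
            # Shared implementation: every subclass gets bias handling for free.
            params = {"weight": self.get_weight(orig_module.weight)}
            bias = self.get_bias(orig_module.bias)
            if bias is not None:
                params["bias"] = bias
            return params

    class ScaleLayer(LoRALayerBase):
        # Hypothetical subclass: after this commit, overriding get_weight alone
        # is enough; get_parameters is inherited rather than re-implemented.
        def __init__(self, scale: float):
            self.scale = scale

        def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
            return orig_weight * self.scale

    params = ScaleLayer(0.5).get_parameters(torch.nn.Linear(4, 4))
    print(sorted(params))  # ['weight'] -- no "bias" entry while self.bias is None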
@@ -79,14 +86,11 @@ def __init__(
 
         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]
-        if "lora_mid.weight" in values:
-            self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-        else:
-            self.mid = None
+        self.mid = values.get("lora_mid.weight", None)
 
         self.rank = self.down.shape[0]
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -96,9 +100,6 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
 
         return weight
 
-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.up, self.mid, self.down]:
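
The if/else blocks removed above (and in the LoHA and LoKR hunks below) collapse into dict.get, which returns its default when the key is absent; for these read-only lookups the two forms are behaviorally identical:

    values = {"lora_up.weight": 1, "lora_down.weight": 2}  # toy stand-ins for tensors

    # Before: explicit membership test.
    if "lora_mid.weight" in values:
        mid = values["lora_mid.weight"]
    else:
        mid = None

    # After: one line, same result.
    assert values.get("lora_mid.weight", None) is mid is None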
@@ -131,20 +132,12 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
         self.w1_b = values["hada_w1_b"]
         self.w2_a = values["hada_w2_a"]
         self.w2_b = values["hada_w2_b"]
-
-        if "hada_t1" in values:
-            self.t1: Optional[torch.Tensor] = values["hada_t1"]
-        else:
-            self.t1 = None
-
-        if "hada_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["hada_t2"]
-        else:
-            self.t2 = None
+        self.t1 = values.get("hada_t1", None)
+        self.t2 = values.get("hada_t2", None)
 
         self.rank = self.w1_b.shape[0]
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
 
@@ -155,9 +148,6 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
 
         return weight
 
-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
@@ -195,37 +185,26 @@ def __init__(
     ):
         super().__init__(layer_key, values)
 
-        if "lokr_w1" in values:
-            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
-            self.w1_a = None
-            self.w1_b = None
-        else:
-            self.w1 = None
+        self.w1 = values.get("lokr_w1", None)
+        if self.w1 is None:
             self.w1_a = values["lokr_w1_a"]
             self.w1_b = values["lokr_w1_b"]
 
-        if "lokr_w2" in values:
-            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
-            self.w2_a = None
-            self.w2_b = None
-        else:
-            self.w2 = None
+        self.w2 = values.get("lokr_w2", None)
+        if self.w2 is None:
             self.w2_a = values["lokr_w2_a"]
             self.w2_b = values["lokr_w2_b"]
 
-        if "lokr_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
-        else:
-            self.t2 = None
+        self.t2 = values.get("lokr_t2", None)
 
-        if "lokr_w1_b" in values:
-            self.rank = values["lokr_w1_b"].shape[0]
-        elif "lokr_w2_b" in values:
-            self.rank = values["lokr_w2_b"].shape[0]
+        if self.w1_b is not None:
+            self.rank = self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            self.rank = self.w2_b.shape[0]
         else:
             self.rank = None  # unscaled
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
             assert self.w1_a is not None
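
For orientation (the body of this get_weight beyond the context lines above is not shown in the diff): LoKr stores the weight as a Kronecker product of two factors, each of which may itself be kept as a low-rank pair, which is why the constructor keeps both the w1/w2 and the *_a/*_b forms. A hedged sketch of that reconstruction, not this file's exact code:

    from typing import Optional

    import torch

    def lokr_weight(
        w1: Optional[torch.Tensor],
        w1_a: Optional[torch.Tensor],
        w1_b: Optional[torch.Tensor],
        w2: Optional[torch.Tensor],
        w2_a: Optional[torch.Tensor],
        w2_b: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Illustrative only: rebuild each Kronecker factor from its low-rank
        # pair when the full factor was not stored.
        if w1 is None:
            assert w1_a is not None and w1_b is not None
            w1 = w1_a @ w1_b
        if w2 is None:
            assert w2_a is not None and w2_b is not None
            w2 = w2_a @ w2_b
        return torch.kron(w1, w2)  # full-size weight

This also explains the new rank logic above: a rank is only meaningful when a factor is stored as a low-rank pair, hence the checks on w1_b and w2_b.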
@@ -250,9 +229,6 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
 
         return weight
 
-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
@@ -302,12 +278,9 @@ def __init__(
 
         self.rank = None  # unscaled
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight
 
-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
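
The calc_size methods retained throughout use the same accounting: nelement() times element_size() is a tensor's storage footprint in bytes. For example:

    import torch

    w = torch.zeros(320, 1280, dtype=torch.float16)
    print(w.nelement() * w.element_size())  # 320 * 1280 * 2 = 819200 bytes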
@@ -335,16 +308,13 @@ def __init__(
 
         self.rank = None  # unscaled
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
         assert orig_weight is not None
         return orig_weight * weight
 
-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
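
The reshape in this last get_weight is per-channel broadcasting (IA3-style scaling). A small worked example, with toy shapes, of what the on_input flag changes:

    import torch

    orig_weight = torch.ones(3, 2)  # (out_features, in_features)

    # not on_input: one scale per output row; (3,) -> (3, 1) broadcasts
    # across the input dimension.
    w_out = torch.tensor([1.0, 2.0, 3.0])
    print(orig_weight * w_out.reshape(-1, 1))  # rows scaled by 1, 2, 3

    # on_input: one scale per input column; (2,) broadcasts across rows
    # with no reshape needed.
    w_in = torch.tensor([10.0, 20.0])
    print(orig_weight * w_in)  # columns scaled by 10 and 20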
