@@ -46,11 +46,18 @@ def __init__(
46
46
self .rank = None # set in layer implementation
47
47
self .layer_key = layer_key
48
48
49
- def get_weight (self , orig_weight : Optional [ torch .Tensor ] ) -> torch .Tensor :
49
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
    """Placeholder that always raises ``NotImplementedError``.

    Args:
        orig_weight: Weight tensor of the module being patched; not used here.

    Raises:
        NotImplementedError: Unconditionally.
    """
    # NOTE(review): presumably concrete layer classes override this -- confirm.
    raise NotImplementedError
51
51
52
- def get_parameters (self , orig_module : Optional [torch .nn .Module ]) -> Dict [str , torch .Tensor ]:
53
- raise NotImplementedError ()
52
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
    """Return the bias stored on this layer.

    Args:
        orig_bias: Bias of the module being patched; not consulted here.

    Returns:
        Whatever ``self.bias`` currently holds (may be ``None``).
    """
    # NOTE(review): ``self.bias`` is assigned outside the visible code -- confirm
    # it is always set before this is called.
    layer_bias = self.bias
    return layer_bias
54
+
55
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """Collect the parameter tensors this layer produces for ``orig_module``.

    Args:
        orig_module: Module whose ``weight`` and ``bias`` attributes are passed
            to ``get_weight`` and ``get_bias`` respectively.

    Returns:
        A dict that always has a ``"weight"`` entry and has a ``"bias"`` entry
        only when ``get_bias`` returns something other than ``None``.
    """
    # Weight is resolved before bias, matching the original call order.
    weight = self.get_weight(orig_module.weight)
    bias = self.get_bias(orig_module.bias)
    if bias is None:
        return {"weight": weight}
    return {"weight": weight, "bias": bias}
54
61
55
62
def calc_size (self ) -> int :
56
63
model_size = 0
@@ -79,14 +86,11 @@ def __init__(
79
86
80
87
self .up = values ["lora_up.weight" ]
81
88
self .down = values ["lora_down.weight" ]
82
- if "lora_mid.weight" in values :
83
- self .mid : Optional [torch .Tensor ] = values ["lora_mid.weight" ]
84
- else :
85
- self .mid = None
89
+ self .mid = values .get ("lora_mid.weight" , None )
86
90
87
91
self .rank = self .down .shape [0 ]
88
92
89
- def get_weight (self , orig_weight : Optional [ torch .Tensor ] ) -> torch .Tensor :
93
+ def get_weight (self , orig_weight : torch .Tensor ) -> torch .Tensor :
90
94
if self .mid is not None :
91
95
up = self .up .reshape (self .up .shape [0 ], self .up .shape [1 ])
92
96
down = self .down .reshape (self .down .shape [0 ], self .down .shape [1 ])
@@ -96,9 +100,6 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
96
100
97
101
return weight
98
102
99
- def get_parameters (self , orig_module : Optional [torch .nn .Module ]) -> Dict [str , torch .Tensor ]:
100
- return {"weight" : self .get_weight (orig_module .weight )}
101
-
102
103
def calc_size (self ) -> int :
103
104
model_size = super ().calc_size ()
104
105
for val in [self .up , self .mid , self .down ]:
@@ -131,20 +132,12 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
131
132
self .w1_b = values ["hada_w1_b" ]
132
133
self .w2_a = values ["hada_w2_a" ]
133
134
self .w2_b = values ["hada_w2_b" ]
134
-
135
- if "hada_t1" in values :
136
- self .t1 : Optional [torch .Tensor ] = values ["hada_t1" ]
137
- else :
138
- self .t1 = None
139
-
140
- if "hada_t2" in values :
141
- self .t2 : Optional [torch .Tensor ] = values ["hada_t2" ]
142
- else :
143
- self .t2 = None
135
+ self .t1 = values .get ("hada_t1" , None )
136
+ self .t2 = values .get ("hada_t2" , None )
144
137
145
138
self .rank = self .w1_b .shape [0 ]
146
139
147
- def get_weight (self , orig_weight : Optional [ torch .Tensor ] ) -> torch .Tensor :
140
+ def get_weight (self , orig_weight : torch .Tensor ) -> torch .Tensor :
148
141
if self .t1 is None :
149
142
weight : torch .Tensor = (self .w1_a @ self .w1_b ) * (self .w2_a @ self .w2_b )
150
143
@@ -155,9 +148,6 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
155
148
156
149
return weight
157
150
158
- def get_parameters (self , orig_module : Optional [torch .nn .Module ]) -> Dict [str , torch .Tensor ]:
159
- return {"weight" : self .get_weight (orig_module .weight )}
160
-
161
151
def calc_size (self ) -> int :
162
152
model_size = super ().calc_size ()
163
153
for val in [self .w1_a , self .w1_b , self .w2_a , self .w2_b , self .t1 , self .t2 ]:
@@ -195,37 +185,26 @@ def __init__(
195
185
):
196
186
super ().__init__ (layer_key , values )
197
187
198
- if "lokr_w1" in values :
199
- self .w1 : Optional [torch .Tensor ] = values ["lokr_w1" ]
200
- self .w1_a = None
201
- self .w1_b = None
202
- else :
203
- self .w1 = None
188
+ self .w1 = values .get ("lokr_w1" , None )
189
+ if self .w1 is None :
204
190
self .w1_a = values ["lokr_w1_a" ]
205
191
self .w1_b = values ["lokr_w1_b" ]
206
192
207
- if "lokr_w2" in values :
208
- self .w2 : Optional [torch .Tensor ] = values ["lokr_w2" ]
209
- self .w2_a = None
210
- self .w2_b = None
211
- else :
212
- self .w2 = None
193
+ self .w2 = values .get ("lokr_w2" , None )
194
+ if self .w2 is None :
213
195
self .w2_a = values ["lokr_w2_a" ]
214
196
self .w2_b = values ["lokr_w2_b" ]
215
197
216
- if "lokr_t2" in values :
217
- self .t2 : Optional [torch .Tensor ] = values ["lokr_t2" ]
218
- else :
219
- self .t2 = None
198
+ self .t2 = values .get ("lokr_t2" , None )
220
199
221
- if "lokr_w1_b" in values :
222
- self .rank = values [ "lokr_w1_b" ] .shape [0 ]
223
- elif "lokr_w2_b" in values :
224
- self .rank = values [ "lokr_w2_b" ] .shape [0 ]
200
+ if self . w1_b is not None :
201
+ self .rank = self . w1_b .shape [0 ]
202
+ elif self . w2_b is not None :
203
+ self .rank = self . w2_b .shape [0 ]
225
204
else :
226
205
self .rank = None # unscaled
227
206
228
- def get_weight (self , orig_weight : Optional [ torch .Tensor ] ) -> torch .Tensor :
207
+ def get_weight (self , orig_weight : torch .Tensor ) -> torch .Tensor :
229
208
w1 : Optional [torch .Tensor ] = self .w1
230
209
if w1 is None :
231
210
assert self .w1_a is not None
@@ -250,9 +229,6 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
250
229
251
230
return weight
252
231
253
- def get_parameters (self , orig_module : Optional [torch .nn .Module ]) -> Dict [str , torch .Tensor ]:
254
- return {"weight" : self .get_weight (orig_module .weight )}
255
-
256
232
def calc_size (self ) -> int :
257
233
model_size = super ().calc_size ()
258
234
for val in [self .w1 , self .w1_a , self .w1_b , self .w2 , self .w2_a , self .w2_b , self .t2 ]:
@@ -302,12 +278,9 @@ def __init__(
302
278
303
279
self .rank = None # unscaled
304
280
305
- def get_weight (self , orig_weight : Optional [ torch .Tensor ] ) -> torch .Tensor :
281
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
    """Return the weight tensor stored on this layer.

    Args:
        orig_weight: Weight of the module being patched; ignored here.

    Returns:
        ``self.weight``, unmodified.
    """
    stored_weight = self.weight
    return stored_weight
307
283
308
- def get_parameters (self , orig_module : Optional [torch .nn .Module ]) -> Dict [str , torch .Tensor ]:
309
- return {"weight" : self .get_weight (orig_module .weight )}
310
-
311
284
def calc_size (self ) -> int :
312
285
model_size = super ().calc_size ()
313
286
model_size += self .weight .nelement () * self .weight .element_size ()
@@ -335,16 +308,13 @@ def __init__(
335
308
336
309
self .rank = None # unscaled
337
310
338
- def get_weight (self , orig_weight : Optional [ torch .Tensor ] ) -> torch .Tensor :
311
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
    """Multiply ``orig_weight`` elementwise by this layer's stored weight.

    Args:
        orig_weight: Weight of the module being patched; must not be ``None``.

    Returns:
        ``orig_weight * scale``, where ``scale`` is ``self.weight`` as-is when
        ``self.on_input`` is truthy, otherwise ``self.weight`` reshaped to
        ``(-1, 1)``.

    Raises:
        AssertionError: If ``orig_weight`` is ``None``.
    """
    scale = self.weight
    # The column shape makes the multiply broadcast one factor per row instead
    # of one per column (assumes a 2-D orig_weight -- TODO confirm).
    scale = scale if self.on_input else scale.reshape(-1, 1)
    assert orig_weight is not None
    return orig_weight * scale
344
317
345
- def get_parameters (self , orig_module : Optional [torch .nn .Module ]) -> Dict [str , torch .Tensor ]:
346
- return {"weight" : self .get_weight (orig_module .weight )}
347
-
348
318
def calc_size (self ) -> int :
349
319
model_size = super ().calc_size ()
350
320
model_size += self .weight .nelement () * self .weight .element_size ()
0 commit comments