@@ -11,7 +11,6 @@

 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel
-from invokeai.backend.util.devices import TorchDevice


 class LoRALayerBase:
@@ -57,14 +56,9 @@ def calc_size(self) -> int:
                 model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.bias = self.bias.to(device=device, dtype=dtype)


 # TODO: find and debug lora/locon with bias
@@ -106,19 +100,14 @@ def calc_size(self) -> int:
                 model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
-        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        super().to(device=device, dtype=dtype)

-        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.up = self.up.to(device=device, dtype=dtype)
+        self.down = self.down.to(device=device, dtype=dtype)

         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.mid = self.mid.to(device=device, dtype=dtype)


 class LoHALayer(LoRALayerBase):
@@ -167,23 +156,18 @@ def calc_size(self) -> int:
                 model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

-        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t1 = self.t1.to(device=device, dtype=dtype)

-        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t2 = self.t2.to(device=device, dtype=dtype)


 class LoKRLayer(LoRALayerBase):
@@ -264,32 +248,27 @@ def calc_size(self) -> int:
                 model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

         if self.w1 is not None:
             self.w1 = self.w1.to(device=device, dtype=dtype)
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype)

         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2 = self.w2.to(device=device, dtype=dtype)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype)

         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t2 = self.t2.to(device=device, dtype=dtype)


 class FullLayer(LoRALayerBase):
@@ -319,15 +298,10 @@ def calc_size(self) -> int:
         model_size += self.weight.nelement() * self.weight.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.weight = self.weight.to(device=device, dtype=dtype)


 class IA3Layer(LoRALayerBase):
@@ -359,16 +333,11 @@ def calc_size(self) -> int:
         model_size += self.on_input.nelement() * self.on_input.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ):
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.on_input = self.on_input.to(device=device, dtype=dtype)


 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -390,15 +359,10 @@ def __init__(
     def name(self) -> str:
         return self._name

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            layer.to(device=device, dtype=dtype)

     def calc_size(self) -> int:
         model_size = 0
@@ -521,7 +485,7 @@ def from_checkpoint(
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()

-            layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
+            layer.to(device=device, dtype=dtype)
             model.layers[layer_key] = layer

         return model
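
For reference, a minimal usage sketch of the simplified API after this change. The import path, the checkpoint path, and the `from_checkpoint` keyword names below are assumptions inferred from the hunks above, not verified against the rest of the codebase:

import torch

# Assumed module path for the file shown in this diff.
from invokeai.backend.lora import LoRAModelRaw

# Load a LoRA checkpoint directly onto the target device/dtype
# (the file path is a placeholder).
lora = LoRAModelRaw.from_checkpoint(
    file_path="/path/to/lora.safetensors",
    device=torch.device("cuda"),
    dtype=torch.float16,
)

# With non_blocking removed, moving the model is a plain synchronous copy:
# every layer's tensors are transferred via Tensor.to(device=..., dtype=...).
lora.to(device=torch.device("cpu"), dtype=torch.float32)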