Skip to content

Commit aaee30d

Browse files
committed
Port back unload_model_clones
1 parent 66eeb4e commit aaee30d

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

lib_layerdiffusion/attention_sharing.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,42 @@ def __init__(self, layer_list):
324324
self.layers = torch.nn.ModuleList(layer_list)
325325

326326

327+
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
328+
current_loaded_models = model_management.current_loaded_models
329+
330+
to_unload = []
331+
for i, m in enumerate(current_loaded_models):
332+
if model.is_clone(m.model):
333+
to_unload = [i] + to_unload
334+
335+
if len(to_unload) == 0:
336+
return True
337+
338+
same_weights = 0
339+
for i in to_unload:
340+
if model.clone_has_same_weights(current_loaded_models[i].model):
341+
same_weights += 1
342+
343+
if same_weights == len(to_unload):
344+
unload_weight = False
345+
else:
346+
unload_weight = True
347+
348+
if not force_unload:
349+
if unload_weights_only and unload_weight == False:
350+
return None
351+
else:
352+
unload_weight = True
353+
354+
for i in to_unload:
355+
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
356+
357+
return unload_weight
358+
327359
class AttentionSharingPatcher(torch.nn.Module):
328360
def __init__(self, unet, frames=2, use_control=True, rank=256):
329361
super().__init__()
362+
unload_model_clones(unet)
330363

331364
units = []
332365
for i in range(32):

0 commit comments

Comments
 (0)