Skip to content

Commit 2c204e7

Browse files
authored
Replace util.get_attr with ModelPatcher.get_model_object (#120)
1 parent 8c745b7 commit 2c204e7

File tree

1 file changed

+6
-39
lines changed

1 file changed

+6
-39
lines changed

lib_layerdiffusion/attention_sharing.py

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import torch
55
import einops
66

7-
from comfy import model_management, utils
7+
from comfy import model_management
88
from comfy.ldm.modules.attention import optimized_attention
9+
from comfy.model_patcher import ModelPatcher
910

1011

1112
module_mapping_sd15 = {
@@ -324,53 +325,19 @@ def __init__(self, layer_list):
324325
self.layers = torch.nn.ModuleList(layer_list)
325326

326327

327-
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
328-
current_loaded_models = model_management.current_loaded_models
329-
330-
to_unload = []
331-
for i, m in enumerate(current_loaded_models):
332-
if model.is_clone(m.model):
333-
to_unload = [i] + to_unload
334-
335-
if len(to_unload) == 0:
336-
return True
337-
338-
same_weights = 0
339-
for i in to_unload:
340-
if model.clone_has_same_weights(current_loaded_models[i].model):
341-
same_weights += 1
342-
343-
if same_weights == len(to_unload):
344-
unload_weight = False
345-
else:
346-
unload_weight = True
347-
348-
if not force_unload:
349-
if unload_weights_only and unload_weight is False:
350-
return None
351-
else:
352-
unload_weight = True
353-
354-
for i in to_unload:
355-
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
356-
357-
return unload_weight
358-
359-
360328
class AttentionSharingPatcher(torch.nn.Module):
361-
def __init__(self, unet, frames=2, use_control=True, rank=256):
329+
def __init__(self, unet: ModelPatcher, frames=2, use_control=True, rank=256):
362330
super().__init__()
363-
unload_model_clones(unet)
364331

365332
units = []
366333
for i in range(32):
367-
real_key = module_mapping_sd15[i]
368-
attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
334+
key = "diffusion_model." + module_mapping_sd15[i]
335+
attn_module = unet.get_model_object(key)
369336
u = AttentionSharingUnit(
370337
attn_module, frames=frames, use_control=use_control, rank=rank
371338
)
372339
units.append(u)
373-
unet.add_object_patch("diffusion_model." + real_key, u)
340+
unet.add_object_patch(key, u)
374341

375342
self.hookers = HookerLayers(units)
376343

0 commit comments

Comments
 (0)