@@ -4,8 +4,9 @@
 import torch
 import einops

-from comfy import model_management, utils
+from comfy import model_management
 from comfy.ldm.modules.attention import optimized_attention
+from comfy.model_patcher import ModelPatcher


 module_mapping_sd15 = {
@@ -324,53 +325,19 @@ def __init__(self, layer_list):
         self.layers = torch.nn.ModuleList(layer_list)


-def unload_model_clones(model, unload_weights_only=True, force_unload=True):
-    current_loaded_models = model_management.current_loaded_models
-
-    to_unload = []
-    for i, m in enumerate(current_loaded_models):
-        if model.is_clone(m.model):
-            to_unload = [i] + to_unload
-
-    if len(to_unload) == 0:
-        return True
-
-    same_weights = 0
-    for i in to_unload:
-        if model.clone_has_same_weights(current_loaded_models[i].model):
-            same_weights += 1
-
-    if same_weights == len(to_unload):
-        unload_weight = False
-    else:
-        unload_weight = True
-
-    if not force_unload:
-        if unload_weights_only and unload_weight is False:
-            return None
-        else:
-            unload_weight = True
-
-    for i in to_unload:
-        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
-
-    return unload_weight
-
-
 class AttentionSharingPatcher(torch.nn.Module):
-    def __init__(self, unet, frames=2, use_control=True, rank=256):
+    def __init__(self, unet: ModelPatcher, frames=2, use_control=True, rank=256):
         super().__init__()
-        unload_model_clones(unet)

         units = []
         for i in range(32):
-            real_key = module_mapping_sd15[i]
-            attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
+            key = "diffusion_model." + module_mapping_sd15[i]
+            attn_module = unet.get_model_object(key)
             u = AttentionSharingUnit(
                 attn_module, frames=frames, use_control=use_control, rank=rank
             )
             units.append(u)
-            unet.add_object_patch("diffusion_model." + real_key, u)
+            unet.add_object_patch(key, u)

         self.hookers = HookerLayers(units)

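For context on the replacement: instead of reading attention modules straight off unet.model.diffusion_model with utils.get_attr and manually unloading model clones, the patcher now resolves each module through ModelPatcher.get_model_object and registers the wrapped unit with unet.add_object_patch, letting ComfyUI's own model management apply and revert the swap. The snippet below is a minimal, self-contained sketch of that object-patch pattern; ToyPatcher, Wrapper, and the tiny model are illustrative stand-ins, not ComfyUI classes, and the real ModelPatcher also handles weight patches, clone tracking, and device placement.

import torch


class ToyPatcher:
    # Illustrative stand-in for the ModelPatcher object-patch pattern
    # used above; not ComfyUI's actual implementation.
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.object_patches = {}

    def get_model_object(self, key: str) -> torch.nn.Module:
        # Prefer a pending patch, otherwise walk the dotted path
        # (e.g. "diffusion_model.attn") on the wrapped model.
        if key in self.object_patches:
            return self.object_patches[key]
        obj = self.model
        for name in key.split("."):
            obj = getattr(obj, name)
        return obj

    def add_object_patch(self, key: str, obj: torch.nn.Module):
        # Record the replacement; a real patcher installs it when the
        # model is loaded and restores the original afterwards.
        self.object_patches[key] = obj


class Wrapper(torch.nn.Module):
    # Minimal stand-in for AttentionSharingUnit: wraps an existing module.
    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x)


if __name__ == "__main__":
    base = torch.nn.Module()
    base.diffusion_model = torch.nn.Module()
    base.diffusion_model.attn = torch.nn.Linear(4, 4)

    patcher = ToyPatcher(base)
    key = "diffusion_model.attn"
    attn = patcher.get_model_object(key)          # resolve by dotted key
    patcher.add_object_patch(key, Wrapper(attn))  # register the wrapper
    print(type(patcher.get_model_object(key)).__name__)  # -> Wrapper

Because nothing is mutated on the underlying model until the patcher is applied, the diff presumably no longer needs the local unload_model_clones helper that the old code called before patching.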