
Commit 422ae99

Fix SD15 memory sharing
1 parent ab409ee commit 422ae99

2 files changed: 14 additions, 24 deletions

layered_diffusion.py

Lines changed: 1 addition & 3 deletions
@@ -286,11 +286,9 @@ def apply_layered_diffusion_attn_sharing(
         layer_lora_state_dict = load_layer_model_state_dict(model_path)
         work_model = model.clone()
         patcher = AttentionSharingPatcher(
-            work_model, self.frames, use_control=control_img is not None
+            work_model, self.frames, control_img=control_img
         )
         patcher.load_state_dict(layer_lora_state_dict, strict=True)
-        if control_img is not None:
-            patcher.set_control(control_img)
         return (work_model,)

lib_layerdiffusion/attention_sharing.py

Lines changed: 13 additions & 21 deletions
@@ -76,7 +76,7 @@ class AttentionSharingUnit(torch.nn.Module):
     # call.
     transformer_options: dict = {}

-    def __init__(self, module, frames=2, use_control=True, rank=256):
+    def __init__(self, module, frames=2, control_signals=None, rank=256):
         super().__init__()

         self.heads = module.heads
@@ -142,9 +142,9 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
             in_features=hidden_size, out_features=hidden_size
         )

+        self.control_signals = control_signals
         self.control_convs = None
-
-        if use_control:
+        if control_signals is not None:
             self.control_convs = [
                 torch.nn.Sequential(
                     torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
@@ -155,7 +155,6 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
             ]
             self.control_convs = torch.nn.ModuleList(self.control_convs)

-        self.control_signals = None

     def forward(self, h, context=None, value=None):
         transformer_options = self.transformer_options
@@ -325,36 +324,29 @@ def __init__(self, layer_list):


 class AttentionSharingPatcher(torch.nn.Module):
-    def __init__(self, unet, frames=2, use_control=True, rank=256):
+    def __init__(self, unet, frames=2, control_img=None, rank=256):
         super().__init__()
-        model_management.unload_model_clones(unet)
+        control_signals = (
+            AdditionalAttentionCondsEncoder()(control_img.cpu().float() * 2.0 - 1.0)
+            if control_img is not None
+            else None
+        )

         units = []
         for i in range(32):
             real_key = module_mapping_sd15[i]
             attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
             u = AttentionSharingUnit(
-                attn_module, frames=frames, use_control=use_control, rank=rank
+                attn_module,
+                frames=frames,
+                control_signals=control_signals,
+                rank=rank,
             )
             units.append(u)
             unet.add_object_patch("diffusion_model." + real_key, u)
-
         self.hookers = HookerLayers(units)

-        if use_control:
-            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
-        else:
-            self.kwargs_encoder = None
-
         self.dtype = torch.float32
         if model_management.should_use_fp16(model_management.get_torch_device()):
             self.dtype = torch.float16
             self.hookers.half()
-        return
-
-    def set_control(self, img):
-        img = img.cpu().float() * 2.0 - 1.0
-        signals = self.kwargs_encoder(img)
-        for m in self.hookers.layers:
-            m.control_signals = signals
-        return

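For context, a minimal usage sketch (not part of the commit) of how the patcher is driven after this change. The names model, control_img, and layer_lora_state_dict stand in for values the calling node already has in scope, and the import path simply mirrors the file path above; all of these are assumptions for illustration. Only the AttentionSharingPatcher call reflects the new signature.

    # Hypothetical caller, mirroring apply_layered_diffusion_attn_sharing above.
    from lib_layerdiffusion.attention_sharing import AttentionSharingPatcher

    work_model = model.clone()  # model: a clonable ComfyUI model patcher (assumed in scope)
    # Control encoding now happens once inside __init__, and the resulting signals are
    # shared by all 32 patched attention units; there is no separate set_control() step.
    patcher = AttentionSharingPatcher(work_model, frames=2, control_img=control_img)
    patcher.load_state_dict(layer_lora_state_dict, strict=True)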