
Commit f7c787e

make it possible to use hypernetworks without opt split attention
1 parent 97bc0b9 · commit f7c787e

2 files changed, +38 -10 lines

modules/hypernetwork.py

Lines changed: 34 additions & 8 deletions
@@ -4,7 +4,12 @@
 import traceback
 
 import torch
-from modules import devices
+
+from ldm.util import default
+from modules import devices, shared
+import torch
+from torch import einsum
+from einops import rearrange, repeat
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
 
     return res
 
-def apply(self, x, context=None, mask=None, original=None):
 
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
 
-    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
-        if context.shape[1] == 77 and CrossAttention.noise_cond:
-            context = context + (torch.randn_like(context) * 0.1)
-        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
-        k = self.to_k(h_k(context))
-        v = self.to_v(h_v(context))
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
     else:
         k = self.to_k(context)
         v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
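
For context, here is a minimal sketch (not repository code; ToyHypernetwork and modulate_context are hypothetical names) of the lookup the new forward relies on: a hypernetwork is assumed to expose a layers dict keyed by the context embedding width (context.shape[2]), mapping to a (k_module, v_module) pair that transforms the context before it is projected to keys and values.

import torch
import torch.nn as nn

class ToyHypernetwork:
    # hypothetical stand-in for the object returned by shared.selected_hypernetwork()
    def __init__(self, dims=(768,)):
        # one (k_module, v_module) pair per supported context width
        self.layers = {d: (nn.Linear(d, d), nn.Linear(d, d)) for d in dims}

def modulate_context(hypernetwork, context):
    # same lookup as the patched forward: fall back to the plain context
    # when no layers exist for this embedding width
    layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
    if layers is None:
        return context, context
    return layers[0](context), layers[1](context)

context = torch.randn(2, 77, 768)                     # (batch, tokens, width)
k_ctx, v_ctx = modulate_context(ToyHypernetwork(), context)
print(k_ctx.shape, v_ctx.shape)                       # both torch.Size([2, 77, 768])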

modules/sd_hijack.py

Lines changed: 4 additions & 2 deletions
@@ -8,7 +8,7 @@
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
 from modules.shared import opts, device, cmd_opts
 
 import ldm.modules.attention
@@ -20,6 +20,8 @@
 
 
 def apply_optimizations():
+    undo_optimizations()
+
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
     if cmd_opts.opt_split_attention_v1:
@@ -30,7 +32,7 @@ def apply_optimizations():
 
 
 def undo_optimizations():
-    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
 
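
The undo_optimizations() call added at the top of apply_optimizations() is what makes hypernetworks usable without an opt split attention flag: the hypernetwork-aware forward becomes the baseline CrossAttention.forward, and an optimized forward only replaces it when its flag is set. A minimal sketch of that patching order follows (stand-in names and flag, not repository code):

class CrossAttention:
    # stand-in for ldm.modules.attention.CrossAttention
    def forward(self, x):
        return "original forward"

def hypernetwork_forward(self, x):
    # stands in for hypernetwork.attention_CrossAttention_forward
    return "hypernetwork-aware forward"

def split_attention_forward(self, x):
    # stands in for an opt-split-attention forward
    return "split-attention forward"

def undo_optimizations():
    # reset to the hypernetwork-aware baseline
    CrossAttention.forward = hypernetwork_forward

def apply_optimizations(opt_split_attention=False):
    undo_optimizations()                  # always start from the baseline
    if opt_split_attention:
        CrossAttention.forward = split_attention_forward

apply_optimizations(opt_split_attention=False)
print(CrossAttention().forward(None))     # hypernetwork-aware forward
apply_optimizations(opt_split_attention=True)
print(CrossAttention().forward(None))     # split-attention forward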
