@@ -4,7 +4,12 @@
 import traceback
 
 import torch
-from modules import devices
+
+from ldm.util import default
+from modules import devices, shared
+import torch
+from torch import einsum
+from einops import rearrange, repeat
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
 
     return res
 
-def apply(self, x, context=None, mask=None, original=None):
 
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
 
-    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
-        if context.shape[1] == 77 and CrossAttention.noise_cond:
-            context = context + (torch.randn_like(context) * 0.1)
-        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
-        k = self.to_k(h_k(context))
-        v = self.to_v(h_v(context))
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
     else:
         k = self.to_k(context)
         v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
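A note on the new lookup: shared.selected_hypernetwork() and the object it returns are defined outside this hunk, so their exact shape is an assumption here, inferred from how the result is indexed above. The patched forward only needs an object whose layers dict maps the conditioning width (context.shape[2], e.g. 768 for SD 1.x text embeddings) to a pair of modules applied to the context before the usual to_k / to_v projections:

# Hedged sketch only -- not code from this commit. DummyHypernetwork stands in
# for whatever shared.selected_hypernetwork() actually returns.
import torch


class DummyHypernetwork:
    def __init__(self, width=768):
        # keyed by context.shape[2]; entry [0] feeds to_k, entry [1] feeds to_v
        self.layers = {
            width: (torch.nn.Identity(), torch.nn.Identity()),
        }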
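The rest of the function is the standard CrossAttention math from ldm, reproduced so the patched forward is self-contained. A standalone shape trace of the rearrange/einsum sequence, with illustrative sizes only:

# Shape trace of the attention math used above; sizes are made up for the demo.
import torch
from torch import einsum
from einops import rearrange

b, heads, dim_head = 2, 8, 40     # batch, attention heads, per-head width
n_q, n_k = 64, 77                 # query tokens (latent), key tokens (text context)
scale = dim_head ** -0.5

# After to_q/to_k/to_v and the 'b n (h d) -> (b h) n d' rearrange:
q = torch.randn(b * heads, n_q, dim_head)
k = torch.randn(b * heads, n_k, dim_head)
v = torch.randn(b * heads, n_k, dim_head)

sim = einsum('b i d, b j d -> b i j', q, k) * scale    # (b*h, n_q, n_k) similarity logits
attn = sim.softmax(dim=-1)                             # attention weights over keys
out = einsum('b i j, b j d -> b i d', attn, v)         # (b*h, n_q, dim_head)
out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)

print(out.shape)   # torch.Size([2, 64, 320]) -> fed to self.to_out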
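How this function gets attached to the model is not part of the diff. A minimal sketch of the usual monkey-patch, assuming the diffed file is importable as modules.hypernetwork and that the model's attention blocks are ldm.modules.attention.CrossAttention:

# Hedged sketch -- the real hook lives elsewhere in the repository.
import ldm.modules.attention
from modules import hypernetwork

# Every CrossAttention instance in the loaded model now runs the
# hypernetwork-aware forward defined above.
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward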