Commit 16601b8

ready cross attention for attending to and from latents

1 parent 3517d37 commit 16601b8

1 file changed (+11, -3 lines)

rin_pytorch/rin_pytorch.py

Lines changed: 11 additions & 3 deletions
@@ -137,18 +137,26 @@ def __init__(
         self.scale = dim_head ** -0.5
         self.heads = heads
         hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias = False)
+        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias = False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        context = None
+    ):
         b, c, h, w = x.shape
-        qkv = self.to_qkv(x).chunk(3, dim = 1)
+        context = default(context, x)
+
+        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = 1))
         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
 
         q = q * self.scale
 
         sim = einsum('b h d i, b h d j -> b h i j', q, k)
         attn = sim.softmax(dim = -1)
+
         out = einsum('b h i j, b h d j -> b h i d', attn, v)
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
         return self.to_out(out)
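The effect of this hunk is that the attention block can now draw its keys and values from a separate `context` tensor while the queries still come from `x`, which is what the commit message refers to as being ready for cross attention to and from latents. Below is a minimal runnable sketch of what the module looks like after the change. Only the lines in the hunk above are verbatim; the imports, the `default` helper, and the exact `__init__` signature are not shown in this diff, so they are assumptions modeled on the usual conventions in this repository.

```python
import torch
from torch import nn, einsum
from einops import rearrange

def default(val, d):
    # assumed helper: fall back to `d` when `val` is not provided
    return val if val is not None else d

class Attention(nn.Module):
    # __init__ signature is an assumption; the hunk only shows its tail end
    def __init__(self, dim, dim_head = 32, heads = 4):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # queries come from `x`; keys / values come from the (optional) context
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, context = None):
        b, c, h, w = x.shape

        # with no context this reduces to plain self-attention over `x`
        context = default(context, x)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q * self.scale

        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

if __name__ == '__main__':
    # hypothetical usage: shapes chosen only for illustration
    attn = Attention(dim = 64)

    feats = torch.randn(1, 64, 16, 16)    # e.g. image feature map
    latents = torch.randn(1, 64, 8, 8)    # e.g. latents laid out on a small grid

    self_attended  = attn(feats)                     # context defaults to x: self-attention
    cross_attended = attn(feats, context = latents)  # queries from feats, keys/values from latents

    print(self_attended.shape, cross_attended.shape)  # both (1, 64, 16, 16)
```

Splitting the old fused `to_qkv` projection into separate `to_q` and `to_kv` projections is what makes this possible: queries and keys/values no longer have to come from the same tensor, so the same module can serve both directions of cross-attention between latents and the image features.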
