Skip to content

Commit 25db5d9

Browse files
committed
add the PEG module from vit literature
1 parent 46a2244 commit 25db5d9

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,22 @@ def forward(
215215
out = rearrange(out, 'b h n d -> b n (h d)')
216216
return self.to_out(out)
217217

218+
class PEG(nn.Module):
219+
def __init__(
220+
self,
221+
dim
222+
):
223+
super().__init__()
224+
self.ds_conv = nn.Conv2d(dim, dim, 3, padding = 1, groups = dim)
225+
226+
def forward(self, x):
227+
b, n, d = x.shape
228+
hw = int(math.sqrt(n))
229+
x = rearrange(x, 'b (h w) d -> b d h w', h = hw)
230+
x = self.ds_conv(x)
231+
x = rearrange(x, 'b d h w -> b (h w) d')
232+
return x
233+
218234
class FeedForward(nn.Module):
219235
def __init__(self, dim, mult = 4, time_cond_dim = None):
220236
super().__init__()
@@ -326,6 +342,7 @@ def __init__(
326342
FeedForward(dim)
327343
]))
328344

345+
self.patches_peg = PEG(dim)
329346
self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
330347
self.patches_self_attn_ff = FeedForward(dim)
331348

@@ -370,6 +387,8 @@ def forward(
370387
# the recurrent interface network body
371388

372389
for _ in range(self.depth):
390+
patches = self.patches_peg(patches) + patches
391+
373392
# latents extract or cluster information from the patches
374393

375394
latents = self.latents_attend_to_patches(latents, patches, time = t) + latents

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.3',
6+
version = '0.0.5',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)