Skip to content

Commit 49e52c0

Browse files
committed
add a final norm to the latents after all the latent self-attention layers
1 parent ed7287a commit 49e52c0

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def __init__(
278278
dim,
279279
latent_self_attn_depth,
280280
dim_latent = None,
281+
final_norm = True,
281282
**attn_kwargs
282283
):
283284
super().__init__()
@@ -292,6 +293,8 @@ def __init__(
292293
FeedForward(dim_latent)
293294
]))
294295

296+
self.latent_final_norm = LayerNorm(dim_latent) if final_norm else nn.Identity()
297+
295298
self.patches_peg = PEG(dim)
296299
self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
297300
self.patches_self_attn_ff = FeedForward(dim)
@@ -323,6 +326,7 @@ def forward(self, patches, latents, t):
323326

324327
patches = self.patches_cross_attn_ff(patches, time = t) + patches
325328

329+
latents = self.latent_final_norm(latents)
326330
return patches, latents
327331

328332
class RIN(nn.Module):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.0',
6+
version = '0.1.1',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)