Commit 46a2244

use axial positional embedding, and make sure to always show gratitude
1 parent 514b7c9 commit 46a2244

3 files changed: 10 additions and 3 deletions

README.md

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@ The big surprise is that the generations can reach this level of fidelity. Will
 
 Additionally, we will try adding an extra linear attention on the main branch, in addition to the full self attention on the latents. Self conditioning will also be applied to the non-latent images in pixel-space. Let us see how far we can push this approach.
 
+## Appreciation
+
+- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work on cutting edge artificial intelligence research
+
 ## Install
 
 ```bash

rin_pytorch/rin_pytorch.py

Lines changed: 5 additions & 2 deletions
@@ -293,7 +293,7 @@ def __init__(
             nn.Linear(pixel_patch_dim * 2, dim)
         )
 
-        self.pos_emb = nn.Parameter(torch.randn(num_patches, dim))
+        self.axial_pos_emb = nn.Parameter(torch.randn(2, patch_height_width, dim) * 0.02)
 
         self.to_pixels = nn.Sequential(
             LayerNorm(dim),
@@ -302,6 +302,7 @@ def __init__(
         )
 
         self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        nn.init.normal_(self.latents, std = 0.02)
 
         self.init_self_cond_latents = nn.Sequential(
             FeedForward(dim),
@@ -362,7 +363,9 @@ def forward(
 
         patches = self.to_patches(x)
 
-        patches = patches + self.pos_emb
+        pos_emb_h, pos_emb_w = self.axial_pos_emb
+        pos_emb = rearrange(pos_emb_h, 'i d -> i 1 d') + rearrange(pos_emb_w, 'j d -> 1 j d')
+        patches = patches + rearrange(pos_emb, 'i j d -> (i j) d')
 
         # the recurrent interface network body
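Note on the change above: the single `(num_patches, dim)` table is replaced by two smaller tables, one per grid axis, whose broadcast sum gives the embedding for each patch position. The sketch below isolates that idea in a standalone module so the shapes and parameter counts can be checked. The class name `AxialPositionalEmbedding` and the sizes are hypothetical, plain broadcasting stands in for the `rearrange` calls in the diff, and it assumes a square patch grid flattened in row-major order.

```python
import torch
from torch import nn

class AxialPositionalEmbedding(nn.Module):
    """Hypothetical standalone version of the axial embedding in the diff."""

    def __init__(self, patch_height_width: int, dim: int):
        super().__init__()
        # index 0 holds the per-row table, index 1 the per-column table,
        # both scaled to a small standard deviation as in the commit
        self.axial_pos_emb = nn.Parameter(
            torch.randn(2, patch_height_width, dim) * 0.02
        )

    def forward(self) -> torch.Tensor:
        pos_emb_h, pos_emb_w = self.axial_pos_emb  # each (n, dim)
        # (n, 1, dim) + (1, n, dim) broadcasts to (n, n, dim)
        pos_emb = pos_emb_h[:, None, :] + pos_emb_w[None, :, :]
        # flatten the grid row-major to line up with the patch sequence
        return pos_emb.reshape(-1, pos_emb.shape[-1])  # (n * n, dim)

# illustrative sizes, not taken from the repository
patch_height_width, dim = 8, 16

axial = AxialPositionalEmbedding(patch_height_width, dim)
pos_emb = axial()

# stand-in for the output of to_patches: (batch, num_patches, dim)
patches = torch.zeros(2, patch_height_width ** 2, dim)
out = patches + pos_emb  # broadcasts over the batch dimension

# parameter count: 2 * n * dim for axial vs n * n * dim for the full table
axial_params = sum(p.numel() for p in axial.parameters())
full_params = patch_height_width ** 2 * dim
```

With these sizes the axial tables hold 256 parameters against 1024 for a full per-patch table. The trade-off is that each position's embedding is constrained to be a sum of a row term and a column term.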

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
