
Commit b323532

parameterize axial positional embeddings as MLPs for height and width, for testing image size extrapolation
1 parent 9062275 commit b323532

File tree

3 files changed, +32 -4 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ diffusion = GaussianDiffusion(
     scale = 1.
 )
 
-training_images = torch.randn(8, 3, 128, 128) # images are normalized from 0 to 1
+training_images = torch.randn(8, 3, 128, 128).cuda() # images are normalized from 0 to 1
 loss = diffusion(training_images)
 loss.backward()
 # after a lot of training

rin_pytorch/rin_pytorch.py

Lines changed: 30 additions & 2 deletions
@@ -389,7 +389,29 @@ def __init__(
             nn.LayerNorm(dim) if dual_patchnorm else None,
         )
 
-        self.axial_pos_emb = nn.Parameter(torch.randn(2, patch_height_width, dim) * 0.02)
+        # axial positional embeddings, parameterized by an MLP
+
+        pos_emb_dim = dim // 2
+
+        self.axial_pos_emb_height_mlp = nn.Sequential(
+            Rearrange('... -> ... 1'),
+            nn.Linear(1, pos_emb_dim),
+            nn.SiLU(),
+            nn.Linear(pos_emb_dim, pos_emb_dim),
+            nn.SiLU(),
+            nn.Linear(pos_emb_dim, dim)
+        )
+
+        self.axial_pos_emb_width_mlp = nn.Sequential(
+            Rearrange('... -> ... 1'),
+            nn.Linear(1, pos_emb_dim),
+            nn.SiLU(),
+            nn.Linear(pos_emb_dim, pos_emb_dim),
+            nn.SiLU(),
+            nn.Linear(pos_emb_dim, dim)
+        )
+
+        # nn.Parameter(torch.randn(2, patch_height_width, dim) * 0.02)
 
         self.to_pixels = nn.Sequential(
             LayerNorm(dim),

@@ -414,6 +436,10 @@ def __init__(
 
         self.blocks = nn.ModuleList([RINBlock(dim, dim_latent = dim_latent, latent_self_attn_depth = latent_self_attn_depth, **attn_kwargs) for _ in range(depth)])
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     def forward(
         self,
         x,

@@ -451,7 +477,9 @@ def forward(
 
         patches = self.to_patches(x)
 
-        pos_emb_h, pos_emb_w = self.axial_pos_emb
+        height_range = width_range = torch.linspace(0., 1., steps = int(math.sqrt(patches.shape[-2])), device = self.device)
+        pos_emb_h, pos_emb_w = self.axial_pos_emb_height_mlp(height_range), self.axial_pos_emb_width_mlp(width_range)
+
         pos_emb = rearrange(pos_emb_h, 'i d -> i 1 d') + rearrange(pos_emb_w, 'j d -> 1 j d')
         patches = patches + rearrange(pos_emb, 'i j d -> (i j) d')
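
The diff above is what enables image size extrapolation: instead of indexing a learned table of length patch_height_width, each axis embedding is now produced by an MLP over coordinates normalized to [0, 1], so the same weights can be queried on a denser grid at test time. Below is a minimal standalone sketch of that idea; the names (height_mlp, width_mlp, axial_pos_emb) are illustrative only and not part of the library.

import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange

dim = 64
pos_emb_dim = dim // 2

def coord_mlp():
    # scalar coordinate in [0, 1] -> embedding of size `dim`, same shape of MLP as in the diff
    return nn.Sequential(
        Rearrange('... -> ... 1'),          # (n,) -> (n, 1) so nn.Linear sees a feature axis
        nn.Linear(1, pos_emb_dim),
        nn.SiLU(),
        nn.Linear(pos_emb_dim, pos_emb_dim),
        nn.SiLU(),
        nn.Linear(pos_emb_dim, dim)
    )

height_mlp, width_mlp = coord_mlp(), coord_mlp()

def axial_pos_emb(patches_per_side):
    # because coordinates are normalized, any grid size can be queried with the same weights
    coords = torch.linspace(0., 1., steps = patches_per_side)
    pos_h = height_mlp(coords)              # (n, dim)
    pos_w = width_mlp(coords)               # (n, dim)
    pos = rearrange(pos_h, 'i d -> i 1 d') + rearrange(pos_w, 'j d -> 1 j d')
    return rearrange(pos, 'i j d -> (i j) d')   # (n * n, dim), one embedding per patch

print(axial_pos_emb(16).shape)   # e.g. 128x128 image with patch size 8 -> torch.Size([256, 64])
print(axial_pos_emb(32).shape)   # same weights at 256x256 -> torch.Size([1024, 64])

With the previous nn.Parameter table, the second call would require resizing or interpolating the learned embeddings; here the MLPs are simply evaluated at the new coordinates.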

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.4',
+  version = '0.6.0',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
