
Commit ed7287a

allow one to set a greater dimension for the latents (dim_latent) than the image dimension (dim); fix a big bug

Parent: 40c0e2d

File tree: 3 files changed (README.md, rin_pytorch/rin_pytorch.py, setup.py), +23 −19 lines


README.md

Lines changed: 2 additions & 7 deletions
@@ -10,13 +10,7 @@ The big surprise is that the generations can reach this level of fidelity. Will
 
 Additionally, we will try adding an extra linear attention on the main branch as well as self conditioning in the pixel-space.
 
-Update:
-
-<img src="./images/sample.png" width="300px"></img>
-
-*130k steps*
-
-It works but the more I think about the paper, the less excited I am. There are a number of issues with the RIN / ISAB architecture. However I think the new sigmoid noise schedule remains interesting as well as the new concept of being able to self-condition on any hidden state of the network.
+The insight of being able to self-condition on any hidden state of the network as well as the newly proposed sigmoid noise schedule are the two main findings.
 
 ## Appreciation
 
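Since the rewritten paragraph above calls out the sigmoid noise schedule, here is a rough sketch of one common parameterization of that schedule, taken from the noise-scheduling line of work the README alludes to rather than from this commit's code; the start, end, and tau defaults are illustrative assumptions.

import torch

def sigmoid_gamma_schedule(t, start = -3., end = 3., tau = 1.):
    # t in [0, 1]; returns the signal level gamma(t), with gamma(0) = 1 and gamma(1) = 0
    # start / end / tau control where and how sharply the schedule transitions
    v_start = torch.sigmoid(torch.tensor(start / tau))
    v_end = torch.sigmoid(torch.tensor(end / tau))
    gamma = (v_end - torch.sigmoid((t * (end - start) + start) / tau)) / (v_end - v_start)
    return gamma.clamp(min = 0., max = 1.)

t = torch.linspace(0., 1., 5)
print(sigmoid_gamma_schedule(t))   # monotonically decreasing from ~1 to ~0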

@@ -39,6 +33,7 @@ model = RIN(
     patch_size = 8,               # patch size
     depth = 6,                    # depth
     num_latents = 128,            # number of latents. they used 256 in the paper
+    dim_latent = 512,             # can be greater than the image dimension (dim) for greater capacity
     latent_self_attn_depth = 4,   # number of latent self attention blocks per recurrent step, K in the paper
 ).cuda()
 
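As a quick sanity check of the new option (a hedged sketch: the dim and image_size values are assumed from the rest of the README, and the top-level RIN import follows the README's usage), the learned latents are now allocated at dim_latent rather than dim:

from rin_pytorch import RIN

model = RIN(
    dim = 256,                   # assumed values from the rest of the README
    image_size = 128,
    patch_size = 8,
    depth = 6,
    num_latents = 128,
    dim_latent = 512,            # greater than dim
    latent_self_attn_depth = 4
)

assert model.latents.shape == (128, 512)   # latents live at dim_latent

# omitting dim_latent falls back to dim, matching the previous behavior
model_default = RIN(
    dim = 256,
    image_size = 128,
    patch_size = 8,
    depth = 6,
    num_latents = 128,
    latent_self_attn_depth = 4
)
assert model_default.latents.shape == (128, 256)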

rin_pytorch/rin_pytorch.py

Lines changed: 20 additions & 11 deletions
@@ -151,6 +151,7 @@ class Attention(nn.Module):
     def __init__(
         self,
         dim,
+        dim_context = None,
         heads = 4,
         dim_head = 32,
         norm = False,
@@ -159,6 +160,7 @@ def __init__(
     ):
         super().__init__()
         hidden_dim = dim_head * heads
+        dim_context = default(dim_context, dim)
 
         self.time_cond = None
 
@@ -176,10 +178,10 @@ def __init__(
         self.heads = heads
 
         self.norm = LayerNorm(dim) if norm else nn.Identity()
-        self.norm_context = LayerNorm(dim) if norm_context else nn.Identity()
+        self.norm_context = LayerNorm(dim_context) if norm_context else nn.Identity()
 
         self.to_q = nn.Linear(dim, hidden_dim, bias = False)
-        self.to_kv = nn.Linear(dim, hidden_dim * 2, bias = False)
+        self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias = False)
         self.to_out = nn.Linear(hidden_dim, dim, bias = False)
 
     def forward(
@@ -193,6 +195,9 @@ def forward(
 
         context = default(context, x)
 
+        if x.shape[-1] != self.norm.gamma.shape[-1]:
+            print(context.shape, x.shape, self.norm.gamma.shape)
+
         x = self.norm(x)
 
         if exists(self.time_cond):
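To make the dim_context change concrete, here is a standalone, stripped-down cross-attention sketch (no time conditioning or norms, so it is not the repo's Attention class): queries come from the dim-sized stream while keys and values are projected from the dim_context-sized stream, which is what lets latents at dim_latent = 512 attend to patches at dim = 256 and vice versa.

import torch
from torch import nn, einsum
from einops import rearrange

class CrossAttentionSketch(nn.Module):
    def __init__(self, dim, dim_context, heads = 4, dim_head = 32):
        super().__init__()
        hidden_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(dim, hidden_dim, bias = False)                # queries from the dim-sized stream
        self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias = False)   # keys / values from the dim_context-sized stream
        self.to_out = nn.Linear(hidden_dim, dim, bias = False)

    def forward(self, x, context):
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

latents = torch.randn(1, 128, 512)   # queries at dim_latent = 512
patches = torch.randn(1, 256, 256)   # context at dim = 256

latent_reader = CrossAttentionSketch(dim = 512, dim_context = 256)   # latents attending to patches
assert latent_reader(latents, patches).shape == (1, 128, 512)

patch_reader = CrossAttentionSketch(dim = 256, dim_context = 512)    # patches attending to latents
assert patch_reader(patches, latents).shape == (1, 256, 256)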
@@ -272,24 +277,26 @@ def __init__(
         self,
         dim,
         latent_self_attn_depth,
+        dim_latent = None,
         **attn_kwargs
     ):
         super().__init__()
+        dim_latent = default(dim_latent, dim)
 
-        self.latents_attend_to_patches = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
+        self.latents_attend_to_patches = Attention(dim_latent, dim_context = dim, norm = True, norm_context = True, **attn_kwargs)
 
         self.latent_self_attns = nn.ModuleList([])
         for _ in range(latent_self_attn_depth):
             self.latent_self_attns.append(nn.ModuleList([
-                Attention(dim, norm = True, **attn_kwargs),
-                FeedForward(dim)
+                Attention(dim_latent, norm = True, **attn_kwargs),
+                FeedForward(dim_latent)
             ]))
 
         self.patches_peg = PEG(dim)
         self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
         self.patches_self_attn_ff = FeedForward(dim)
 
-        self.patches_attend_to_latents = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
+        self.patches_attend_to_latents = Attention(dim, dim_context = dim_latent, norm = True, norm_context = True, **attn_kwargs)
         self.patches_cross_attn_ff = FeedForward(dim)
 
     def forward(self, patches, latents, t):
@@ -312,7 +319,7 @@ def forward(self, patches, latents, t):
 
         # patches attend to the latents
 
-        patches = self.latents_attend_to_patches(patches, latents, time = t) + patches
+        patches = self.patches_attend_to_latents(patches, latents, time = t) + patches
 
         patches = self.patches_cross_attn_ff(patches, time = t) + patches
 
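Putting the RINBlock changes together, the corrected read / process / write flow looks roughly like the sketch below. The names and call order are illustrative and inferred from the hunks above, not copied from the repo's forward, which also runs the PEG and the linear self-attention on the patch stream.

def rin_block_sketch(block, patches, latents, t):
    # read: latents (width dim_latent) cross-attend to patches (width dim)
    latents = block.latents_attend_to_patches(latents, patches, time = t) + latents

    # process: K rounds of latent self-attention and feedforward, all at dim_latent
    for attn, ff in block.latent_self_attns:
        latents = attn(latents, time = t) + latents
        latents = ff(latents, time = t) + latents

    # write: patches cross-attend back to the latents -- this is the call the
    # bug fix above reroutes through patches_attend_to_latents
    patches = block.patches_attend_to_latents(patches, latents, time = t) + patches
    patches = block.patches_cross_attn_ff(patches, time = t) + patches

    return patches, latents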

@@ -327,12 +334,14 @@ def __init__(
         channels = 3,
         depth = 6,                    # number of RIN blocks
         latent_self_attn_depth = 2,   # how many self attentions for the latent per each round of cross attending from pixel space to latents and back
+        dim_latent = None,            # will default to image dim (dim)
         num_latents = 256,            # they still had to use a fair amount of latents for good results (256), in line with the Perceiver line of papers from Deepmind
         learned_sinusoidal_dim = 16,
         **attn_kwargs
     ):
         super().__init__()
         assert divisible_by(image_size, patch_size)
+        dim_latent = default(dim_latent, dim)
 
         self.channels = channels # times 2 due to self-conditioning
 
@@ -368,12 +377,12 @@ def __init__(
             Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = patch_height_width)
         )
 
-        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        self.latents = nn.Parameter(torch.randn(num_latents, dim_latent))
         nn.init.normal_(self.latents, std = 0.02)
 
         self.init_self_cond_latents = nn.Sequential(
-            FeedForward(dim),
-            LayerNorm(dim)
+            FeedForward(dim_latent),
+            LayerNorm(dim_latent)
         )
 
         nn.init.zeros_(self.init_self_cond_latents[-1].gamma)
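The zero-initialized gamma on the final LayerNorm above is what keeps the latent self-conditioning pathway silent at the start of training. A minimal standalone illustration of that gating pattern, using torch's built-in LayerNorm and a plain Linear as stand-ins for the repo's custom LayerNorm and FeedForward:

import torch
from torch import nn

dim_latent = 512

gate = nn.Sequential(
    nn.Linear(dim_latent, dim_latent),   # stand-in for FeedForward(dim_latent)
    nn.LayerNorm(dim_latent)             # stand-in for the repo's custom LayerNorm
)
nn.init.zeros_(gate[-1].weight)          # analogous to zeroing .gamma above
nn.init.zeros_(gate[-1].bias)

prev_latents = torch.randn(2, 128, dim_latent)   # hypothetical latents carried over for self-conditioning
assert torch.allclose(gate(prev_latents), torch.zeros_like(prev_latents))
# the self-conditioning branch contributes nothing until training moves the gain off zero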
@@ -382,7 +391,7 @@
 
         attn_kwargs = {**attn_kwargs, 'time_cond_dim': time_dim}
 
-        self.blocks = nn.ModuleList([RINBlock(dim, latent_self_attn_depth = latent_self_attn_depth, **attn_kwargs) for _ in range(depth)])
+        self.blocks = nn.ModuleList([RINBlock(dim, dim_latent = dim_latent, latent_self_attn_depth = latent_self_attn_depth, **attn_kwargs) for _ in range(depth)])
 
     def forward(
         self,

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.1.0',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
