Commit 627fa4e

in this paper, attention is all we need
1 parent 8890468 commit 627fa4e

File tree: 1 file changed (+38, -76 lines)


rin_pytorch/rin_pytorch.py

Lines changed: 38 additions & 76 deletions
@@ -65,42 +65,20 @@ def __init__(self, fn):
     def forward(self, x, *args, **kwargs):
         return self.fn(x, *args, **kwargs) + x
 
-def Upsample(dim, dim_out = None):
-    return nn.Sequential(
-        nn.Upsample(scale_factor = 2, mode = 'nearest'),
-        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
-    )
-
-def Downsample(dim, dim_out = None):
-    return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
+# use layernorm without bias, more stable
 
 class LayerNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-
-    def forward(self, x):
-        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * var.clamp(min = eps).rsqrt() * self.g
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.fn = fn
-        self.norm = LayerNorm(dim)
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
 
     def forward(self, x):
-        x = self.norm(x)
-        return self.fn(x)
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
 
 # positional embeds
 
 class LearnedSinusoidalPosEmb(nn.Module):
-    """ following @crowsonkb 's lead with learned sinusoidal pos emb """
-    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
-
     def __init__(self, dim):
         super().__init__()
         assert (dim % 2) == 0
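
For reference, a minimal sketch of the bias-free LayerNorm this hunk switches to: `gamma` is the only learnable parameter, while `beta` is registered as a buffer so the shift stays fixed at zero. The class name and the comparison against `nn.LayerNorm` below are illustrative, not part of the repo.

```python
# hedged sketch: standalone copy of the bias-free LayerNorm from this commit,
# checked against nn.LayerNorm with its bias left at zero
import torch
import torch.nn as nn
import torch.nn.functional as F

class BiaslessLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))       # learnable per-feature scale
        self.register_buffer("beta", torch.zeros(dim))   # fixed zero shift, not a Parameter

    def forward(self, x):
        # normalize over the last (feature) dimension only
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

x = torch.randn(2, 16, 64)                               # (batch, tokens, dim)
norm = BiaslessLayerNorm(64)

reference = nn.LayerNorm(64)                             # weight = 1, bias = 0 at init
assert torch.allclose(norm(x), reference(x), atol = 1e-6)
```

Registering `beta` as a buffer rather than a `Parameter` keeps the shift untrained while still letting it move with `.to(device)` and appear in the state dict.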
@@ -114,54 +92,13 @@ def forward(self, x):
         fouriered = torch.cat((x, fouriered), dim = -1)
         return fouriered
 
-# building block modules
-
-class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups = 8):
-        super().__init__()
-        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
-        self.norm = nn.GroupNorm(groups, dim_out)
-        self.act = nn.SiLU()
-
-    def forward(self, x, scale_shift = None):
-        x = self.proj(x)
-        x = self.norm(x)
-
-        if exists(scale_shift):
-            scale, shift = scale_shift
-            x = x * (scale + 1) + shift
-
-        x = self.act(x)
-        return x
-
-class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
-        super().__init__()
-        self.mlp = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(time_emb_dim, dim_out * 2)
-        ) if exists(time_emb_dim) else None
-
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb = None):
-
-        scale_shift = None
-        if exists(self.mlp) and exists(time_emb):
-            time_emb = self.mlp(time_emb)
-            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
-            scale_shift = time_emb.chunk(2, dim = 1)
-
-        h = self.block1(x, scale_shift = scale_shift)
-
-        h = self.block2(h)
-
-        return h + self.res_conv(x)
-
 class LinearAttention(nn.Module):
-    def __init__(self, dim, heads = 4, dim_head = 32):
+    def __init__(
+        self,
+        dim,
+        heads = 4,
+        dim_head = 32
+    ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
@@ -182,7 +119,6 @@ def forward(self, x):
         k = k.softmax(dim = -1)
 
         q = q * self.scale
-        v = v / (h * w)
 
         context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
 
@@ -191,7 +127,12 @@ def forward(self, x):
         return self.to_out(out)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 4, dim_head = 32):
+    def __init__(
+        self,
+        dim,
+        heads = 4,
+        dim_head = 32
+    ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
@@ -212,6 +153,27 @@ def forward(self, x):
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
         return self.to_out(out)
 
+class FiLM(nn.Module):
+    def __init__(
+        self,
+        dim,
+        hidden_dim
+    ):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim * 4),
+            nn.SiLU(),
+            nn.Linear(hidden_dim * 4, hidden_dim * 2)
+        )
+
+        nn.init.zeros_(self.net[-1].weight)
+        nn.init.zeros_(self.net[-1].bias)
+
+    def forward(self, conditions, hiddens):
+        scale, shift = self.net(conditions).chunk(2, dim = -1)
+        scale, shift = map(lambda t: rearrange(t, 'b d -> b 1 d'), (scale, shift))
+        return hiddens * (scale + 1) + shift
+
 # model
 
 class RIN(nn.Module):
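
The new `FiLM` block zero-initializes its last projection, so it starts out as an identity map over the hidden tokens: `scale` and `shift` both begin at zero, and `hiddens * (0 + 1) + 0 == hiddens`. A usage sketch, with the class body copied from the hunk above and the tensor shapes chosen purely for illustration:

```python
# hedged usage sketch for the FiLM conditioning block added in this commit;
# the class body is from the diff, the surrounding shapes are illustrative
import torch
import torch.nn as nn
from einops import rearrange

class FiLM(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim * 2)
        )

        # zero-init the last projection: scale and shift start at 0,
        # so hiddens * (0 + 1) + 0 == hiddens (identity at initialization)
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, conditions, hiddens):
        scale, shift = self.net(conditions).chunk(2, dim = -1)
        scale, shift = map(lambda t: rearrange(t, 'b d -> b 1 d'), (scale, shift))
        return hiddens * (scale + 1) + shift

time_cond = torch.randn(4, 128)        # (batch, dim) conditioning vector
latents = torch.randn(4, 256, 512)     # (batch, num tokens, hidden_dim)

film = FiLM(dim = 128, hidden_dim = 512)
out = film(time_cond, latents)

assert out.shape == latents.shape
assert torch.allclose(out, latents)    # identity at init, thanks to the zero-init
```

Zero-initializing the conditioning path is a common trick: the network can learn how strongly to modulate on the conditioning signal without perturbing the residual stream at the start of training.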
@@ -352,7 +314,7 @@ def beta_linear_log_snr(t):
     return -log(expm1(1e-4 + 10 * (t ** 2)))
 
 def alpha_cosine_log_snr(t, s = 0.008):
-    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
+    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5)
 
 def gamma_sigmoid_log_snr(t, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
     v_start = torch.tensor(start / tau).sigmoid()
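
A sketch of the cosine log-SNR schedule from the hunk above. The `log` helper is an assumption here: the diff calls `log(..., eps = 1e-5)`, so it is presumably a clamped log defined elsewhere in the file. The `log_snr_to_alpha_sigma` conversion is the standard sigmoid-based parameterization and is not something this commit confirms.

```python
# hedged sketch of the cosine log-SNR schedule; `log` is an assumed clamped-log
# helper matching the eps-taking call in the diff, and the alpha/sigma
# conversion is the standard sigmoid(log_snr) form, not taken from the repo
import math
import torch

def log(t, eps = 1e-5):
    # clamp before taking the log so the schedule stays finite near t = 1
    return torch.log(t.clamp(min = eps))

def alpha_cosine_log_snr(t, s = 0.008):
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5)

def log_snr_to_alpha_sigma(log_snr):
    # continuous-time parameterization: alpha^2 = sigmoid(log_snr), sigma^2 = sigmoid(-log_snr)
    alpha = torch.sqrt(torch.sigmoid(log_snr))
    sigma = torch.sqrt(torch.sigmoid(-log_snr))
    return alpha, sigma

t = torch.linspace(0., 1., 5)
log_snr = alpha_cosine_log_snr(t)
alpha, sigma = log_snr_to_alpha_sigma(log_snr)

# alpha decays from ~1 toward 0 as t goes 0 -> 1, sigma does the opposite,
# and alpha**2 + sigma**2 == 1 at every t
assert torch.allclose(alpha ** 2 + sigma ** 2, torch.ones_like(t))
```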
