From bb4841ece439bac8b81be8171b54a5f577889139 Mon Sep 17 00:00:00 2001
From: Ryan Kim
Date: Fri, 24 May 2024 22:30:53 -0500
Subject: [PATCH] Layer Norm modification

Modified the layer norms in the encoder layers to align more closely with
the original ViT paper.
---
 vit_pytorch/vit.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index 5b34a44a..84159043 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -65,20 +65,23 @@ def forward(self, x):
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         super().__init__()
-        self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
+        self.layer_norm = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
+            self.layer_norm.append(nn.ModuleList([nn.LayerNorm(dim),
+                nn.LayerNorm(dim)
+            ]))
 
     def forward(self, x):
-        for attn, ff in self.layers:
-            x = attn(x) + x
-            x = ff(x) + x
+        for i, (attn, ff) in enumerate(self.layers):
+            x = attn(self.layer_norm[i][0](x)) + x
+            x = ff(self.layer_norm[i][1](x)) + x
 
-        return self.norm(x)
+        return x
 
 class ViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
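
Note (for reviewers, not part of the diff above): a minimal runnable sketch of the
pre-norm residual pattern the patched Transformer.forward follows. The nn.Linear
stand-ins for Attention and FeedForward are assumed placeholders so the snippet runs
without the rest of vit_pytorch; only the LayerNorm placement mirrors the patch.

    import torch
    import torch.nn as nn

    dim, depth, tokens = 8, 2, 4

    # per-layer pairs of LayerNorms, mirroring self.layer_norm in the patch
    layer_norm = nn.ModuleList([
        nn.ModuleList([nn.LayerNorm(dim), nn.LayerNorm(dim)]) for _ in range(depth)
    ])
    # placeholder sub-modules standing in for Attention / FeedForward
    layers = nn.ModuleList([
        nn.ModuleList([nn.Linear(dim, dim), nn.Linear(dim, dim)]) for _ in range(depth)
    ])

    x = torch.randn(1, tokens, dim)            # (batch, tokens, dim)
    for i, (attn, ff) in enumerate(layers):
        x = attn(layer_norm[i][0](x)) + x      # norm before attention, residual after
        x = ff(layer_norm[i][1](x)) + x        # norm before feed-forward, residual after
    # no trailing self.norm(x): the patched forward returns x directly
    print(x.shape)                             # torch.Size([1, 4, 8])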