diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index 5b34a44a..bfcad3c0 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -48,19 +48,25 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x, mask = None):
+        assert x.dim() == 3, "Input x must have three dimensions: (batch_size, sequence_length, embedding_dim)"
+
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
+        if mask is not None:
+            mask = mask.unsqueeze(1).expand(dots.size(0), self.heads, dots.size(2), dots.size(3))
+            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
+
         attn = self.attend(dots)
         attn = self.dropout(attn)
 
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
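
A minimal usage sketch of the change above, not part of the patch: it assumes the fixed forward(self, x, mask = None) signature shown in the hunk and a boolean mask of shape (batch, seq_len, seq_len) where True marks positions that may be attended to. The Transformer and ViT classes in this file do not forward a mask, so the sketch calls Attention directly; the dim/heads/dim_head values are arbitrary examples.

    import torch
    from vit_pytorch.vit import Attention

    # constructor arguments follow the signature shown in the hunk context
    attn = Attention(dim = 128, heads = 8, dim_head = 64, dropout = 0.)

    x = torch.randn(2, 65, 128)                        # (batch, tokens, dim)
    mask = torch.ones(2, 65, 65, dtype = torch.bool)   # True = may attend
    mask[:, :, -1] = False                             # hypothetical: block attention to the last token

    out = attn(x, mask = mask)                         # -> shape (2, 65, 128)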