|
| 1 | +import torch |
| 2 | +from torch import nn |
| 3 | + |
| 4 | +from einops import rearrange, repeat |
| 5 | +from einops.layers.torch import Rearrange |
| 6 | + |
| 7 | +# helpers |
| 8 | + |
| 9 | +def pair(t): |
| 10 | + return t if isinstance(t, tuple) else (t, t) |
| 11 | + |
| 12 | +# classes |
| 13 | + |
| 14 | +class PreNorm(nn.Module): |
| 15 | + def __init__(self, dim, fn): |
| 16 | + super().__init__() |
| 17 | + self.norm = nn.LayerNorm(dim) |
| 18 | + self.fn = fn |
| 19 | + def forward(self, x, **kwargs): |
| 20 | + return self.fn(self.norm(x), **kwargs) |
| 21 | + |
| 22 | +class FeedForward(nn.Module): |
| 23 | + def __init__(self, dim, hidden_dim, dropout = 0.): |
| 24 | + super().__init__() |
| 25 | + self.net = nn.Sequential( |
| 26 | + nn.Linear(dim, hidden_dim), |
| 27 | + nn.GELU(), |
| 28 | + nn.Dropout(dropout), |
| 29 | + nn.Linear(hidden_dim, dim), |
| 30 | + nn.Dropout(dropout) |
| 31 | + ) |
| 32 | + def forward(self, x): |
| 33 | + return self.net(x) |
| 34 | + |
| 35 | +class Attention(nn.Module): |
| 36 | + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): |
| 37 | + super().__init__() |
| 38 | + inner_dim = dim_head * heads |
| 39 | + project_out = not (heads == 1 and dim_head == dim) |
| 40 | + |
| 41 | + self.heads = heads |
| 42 | + self.scale = dim_head ** -0.5 |
| 43 | + |
| 44 | + self.attend = nn.Softmax(dim = -1) |
| 45 | + self.dropout = nn.Dropout(dropout) |
| 46 | + |
| 47 | + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) |
| 48 | + |
| 49 | + self.to_out = nn.Sequential( |
| 50 | + nn.Linear(inner_dim, dim), |
| 51 | + nn.Dropout(dropout) |
| 52 | + ) if project_out else nn.Identity() |
| 53 | + |
| 54 | + def forward(self, x): |
| 55 | + qkv = self.to_qkv(x).chunk(3, dim = -1) |
| 56 | + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) |
| 57 | + |
| 58 | + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale |
| 59 | + |
| 60 | + attn = self.attend(dots) |
| 61 | + attn = self.dropout(attn) |
| 62 | + |
| 63 | + out = torch.matmul(attn, v) |
| 64 | + out = rearrange(out, 'b h n d -> b n (h d)') |
| 65 | + return self.to_out(out) |
| 66 | + |
| 67 | +class Transformer(nn.Module): |
| 68 | + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): |
| 69 | + super().__init__() |
| 70 | + self.layers = nn.ModuleList([]) |
| 71 | + for _ in range(depth): |
| 72 | + self.layers.append(nn.ModuleList([ |
| 73 | + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), |
| 74 | + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) |
| 75 | + ])) |
| 76 | + def forward(self, x): |
| 77 | + for attn, ff in self.layers: |
| 78 | + x = attn(x) + x |
| 79 | + x = ff(x) + x |
| 80 | + return x |
| 81 | + |
| 82 | +class ViT(nn.Module): |
| 83 | + def __init__( |
| 84 | + self, |
| 85 | + *, |
| 86 | + image_size, |
| 87 | + image_patch_size, |
| 88 | + frames, |
| 89 | + frame_patch_size, |
| 90 | + num_classes, |
| 91 | + dim, |
| 92 | + spatial_depth, |
| 93 | + temporal_depth, |
| 94 | + heads, |
| 95 | + mlp_dim, |
| 96 | + pool = 'cls', |
| 97 | + channels = 3, |
| 98 | + dim_head = 64, |
| 99 | + dropout = 0., |
| 100 | + emb_dropout = 0. |
| 101 | + ): |
| 102 | + super().__init__() |
| 103 | + image_height, image_width = pair(image_size) |
| 104 | + patch_height, patch_width = pair(image_patch_size) |
| 105 | + |
| 106 | + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' |
| 107 | + assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size' |
| 108 | + |
| 109 | + num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) |
| 110 | + patch_dim = channels * patch_height * patch_width * frame_patch_size |
| 111 | + |
| 112 | + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' |
| 113 | + |
| 114 | + self.to_patch_embedding = nn.Sequential( |
| 115 | + Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), |
| 116 | + nn.Linear(patch_dim, dim), |
| 117 | + ) |
| 118 | + |
| 119 | + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) |
| 120 | + self.dropout = nn.Dropout(emb_dropout) |
| 121 | + self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) |
| 122 | + self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) |
| 123 | + |
| 124 | + self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout) |
| 125 | + self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout) |
| 126 | + |
| 127 | + self.pool = pool |
| 128 | + self.to_latent = nn.Identity() |
| 129 | + |
| 130 | + self.mlp_head = nn.Sequential( |
| 131 | + nn.LayerNorm(dim), |
| 132 | + nn.Linear(dim, num_classes) |
| 133 | + ) |
| 134 | + |
| 135 | + def forward(self, img): |
| 136 | + x = self.to_patch_embedding(img) |
| 137 | + b, f, n, _ = x.shape |
| 138 | + |
| 139 | + spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f) |
| 140 | + x = torch.cat((spatial_cls_tokens, x), dim = 2) |
| 141 | + x += self.pos_embedding[:, :(n + 1)] |
| 142 | + x = self.dropout(x) |
| 143 | + |
| 144 | + x = rearrange(x, 'b f n d -> (b f) n d') |
| 145 | + |
| 146 | + # attend across space |
| 147 | + |
| 148 | + x = self.spatial_transformer(x) |
| 149 | + |
| 150 | + x = rearrange(x, '(b f) n d -> b f n d', b = b) |
| 151 | + |
| 152 | + # excise out the spatial cls tokens for temporal attention |
| 153 | + |
| 154 | + x = x[:, :, 0] |
| 155 | + |
| 156 | + # append temporal CLS tokens |
| 157 | + |
| 158 | + temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b) |
| 159 | + |
| 160 | + x = torch.cat((temporal_cls_tokens, x), dim = 1) |
| 161 | + |
| 162 | + # attend across time |
| 163 | + |
| 164 | + x = self.temporal_transformer(x) |
| 165 | + |
| 166 | + x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] |
| 167 | + |
| 168 | + x = self.to_latent(x) |
| 169 | + return self.mlp_head(x) |
0 commit comments