import torch
import torch.nn as nn
from ..nn.transformer import ViTBlock


class ViTTrunk(nn.Module):
    """
    Vision Transformer (ViT) trunk that processes a sequence of patch embeddings with positional encoding
    and optional learnable registers, using a stack of ViTBlock layers.

    Args:
        seq_len (int): Length of the input sequence (number of patches).
        d_model (int): Dimension of the model.
        num_layers (int): Number of transformer blocks.
        num_heads (int): Number of attention heads.
        num_registers (int, optional): Number of learnable registers to prepend to the sequence. Default: 10.

    Forward Args:
        x (Tensor): Input tensor of shape [B, C, H/P, W/P], where P is the patch size.

    Returns:
        Tensor: Output tensor of shape [B, C, H/P, W/P].
    """

    def __init__(self, seq_len, d_model, num_layers, num_heads, num_registers=10):
        super().__init__()
        self.trunk = nn.ModuleList(
            [ViTBlock(d_model, num_heads) for _ in range(num_layers)]
        )
        # Learnable positional encoding for the patch sequence, initialized to zeros
        self.pos_enc = nn.Parameter(torch.zeros(seq_len, d_model))
        # Learnable register tokens, initialized with variance-scaled Gaussian noise
        self.registers = nn.Parameter(
            torch.randn(num_registers, d_model) / (d_model**0.5)
        )

    def forward(self, x):
        """
        Forward pass for the ViTTrunk.

        Args:
            x (Tensor): Input tensor of shape [B, C, H/P, W/P].

        Returns:
            Tensor: Output tensor of shape [B, C, H/P, W/P].
        """
        # x: [B, C, H/P, W/P]
        B, C, Hp, Wp = x.shape
        # Flatten the spatial grid into a token sequence and add positional encoding
        x = x.view(B, C, Hp * Wp).permute(0, 2, 1)
        x = x + self.pos_enc
        # x: [B, L, C]
        # Prepend the learnable register tokens to the patch tokens
        x = torch.cat([self.registers.unsqueeze(0).expand(B, -1, -1), x], dim=1)
        for block in self.trunk:
            x = block(x)

        # Drop the register tokens, keeping only the patch tokens
        x = x[:, len(self.registers) :, :]
        # x = F.gelu(x)
        x = x.permute(0, 2, 1).reshape(B, C, Hp, Wp)
        # x: [B, C, H/P, W/P]
        return x
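
# Usage sketch (illustrative; the sizes below are assumptions, not part of the
# module): a 224x224 image with patch size 16 yields a 14x14 grid of patch
# embeddings, so seq_len = 14 * 14 = 196, and the channel dimension of the
# patch embedding must equal d_model.
#
#   trunk = ViTTrunk(seq_len=196, d_model=384, num_layers=12, num_heads=6)
#   patches = torch.randn(2, 384, 14, 14)  # [B, C, H/P, W/P]
#   out = trunk(patches)                   # -> [2, 384, 14, 14]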