-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembedding_module.py
More file actions
59 lines (45 loc) · 1.67 KB
/
embedding_module.py
File metadata and controls
59 lines (45 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
from dataclasses import dataclass
@dataclass
class EmbeddingConfig:
    """Configuration for :class:`TokenPositionalEmbedding`.

    Attributes:
        vocab_size: Number of distinct token ids (rows of the token table).
        d_model: Embedding dimension shared by token and position tables.
        block_size: Maximum supported sequence length (rows of the position table).
        dropout: Dropout probability applied to the combined embeddings.
    """
    vocab_size: int
    d_model: int
    block_size: int
    dropout: float = 0.1

    def __post_init__(self) -> None:
        # Fail fast on configs that would build degenerate embedding tables
        # or an invalid Dropout layer, instead of erroring later in nn code.
        if self.vocab_size <= 0:
            raise ValueError(f"vocab_size must be positive, got {self.vocab_size}")
        if self.d_model <= 0:
            raise ValueError(f"d_model must be positive, got {self.d_model}")
        if self.block_size <= 0:
            raise ValueError(f"block_size must be positive, got {self.block_size}")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")
class TokenPositionalEmbedding(nn.Module):
    """Learned token + learned absolute positional embeddings with dropout.

    Maps integer token ids of shape (B, T) to dense vectors of shape
    (B, T, d_model) by summing a per-token embedding and a per-position
    embedding, then applying dropout.
    """

    def __init__(self, cfg: "EmbeddingConfig"):
        super().__init__()
        self.cfg = cfg
        # Token embedding table: [vocab_size, d_model]
        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        # Positional embedding table: [block_size, d_model]
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
        # Dropout for regularization
        self.drop = nn.Dropout(cfg.dropout)
        # Initialize both tables with a small gaussian (GPT-style std=0.02)
        nn.init.normal_(self.token_emb.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.pos_emb.weight, mean=0.0, std=0.02)

    def forward(self, input_ids):
        """Embed a batch of token id sequences.

        Args:
            input_ids: Long tensor of shape (B, T) with values in
                [0, vocab_size).

        Returns:
            Float tensor of shape (B, T, d_model).

        Raises:
            ValueError: If T exceeds ``cfg.block_size`` (no positional row
                exists for those positions).
        """
        B, T = input_ids.shape
        # Use a real exception, not `assert`: asserts are stripped under
        # `python -O`, silently allowing an out-of-range position lookup.
        if T > self.cfg.block_size:
            raise ValueError(
                f"Sequence length {T} exceeds block_size {self.cfg.block_size}"
            )
        # Token embeddings: (B, T, d_model)
        tok_emb = self.token_emb(input_ids)
        # Positional indices [0..T-1] on the same device as the input;
        # shape (1, T) so the lookup broadcasts across the batch.
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        pos_emb = self.pos_emb(pos)  # (1, T, d_model)
        # Combine + dropout
        x = tok_emb + pos_emb
        return self.drop(x)
if __name__ == '__main__':
    # Smoke test: embed a small random batch and report the output shape.
    demo_cfg = EmbeddingConfig(vocab_size=5000, d_model=256, block_size=128)
    embedder = TokenPositionalEmbedding(demo_cfg)
    token_batch = torch.randint(0, demo_cfg.vocab_size, (2, 64))
    embedded = embedder(token_batch)
    print("Output shape:", embedded.shape)