Commit 0b5cbad

Create model.py
This is more of a just-for-fun pull request. I collect interesting repositories and sometimes experiment with small changes. The tweaks may or may not be useful, so feel free to ignore them :) Either way, thanks for the great project!
1 parent 729963d commit 0b5cbad

File tree

1 file changed (+36, −0 lines)
  • PyTorch/CustomStuff/MyAwesomeModel

@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
from config import Config


def generate_causal_mask(size):
    # Additive attention mask: -inf above the diagonal hides future tokens,
    # zeros on and below the diagonal leave past tokens visible.
    mask = torch.triu(torch.ones(size, size) * float('-inf'), diagonal=1)
    return mask
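
# For example, generate_causal_mask(3) yields
#   tensor([[0., -inf, -inf],
#           [0., 0., -inf],
#           [0., 0., 0.]])
# so position i attends only to positions <= i. PyTorch ships the same
# pattern as nn.Transformer.generate_square_subsequent_mask(size).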


class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(Config.vocab_size, Config.d_model)
        # Learned absolute positional embeddings, one vector per position.
        self.pos_emb = nn.Parameter(torch.zeros(1, Config.seq_len, Config.d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=Config.d_model,
            nhead=Config.n_heads,
            dim_feedforward=4 * Config.d_model,
            dropout=0.1,
            activation='gelu'
        )
        # An encoder stack plus a causal mask behaves like a decoder-only
        # transformer (no cross-attention), which is all GPT needs.
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=Config.n_layers)
        self.ln_f = nn.LayerNorm(Config.d_model)
        self.head = nn.Linear(Config.d_model, Config.vocab_size)

    def forward(self, idx):
        B, T = idx.size()
        tok = self.token_emb(idx)      # (B, T, d_model)
        pos = self.pos_emb[:, :T, :]   # (1, T, d_model)
        x = tok + pos
        # nn.TransformerEncoder defaults to batch_first=False, so move to
        # sequence-first layout for the stack and back afterwards.
        x = x.transpose(0, 1)          # (T, B, d_model)
        mask = generate_causal_mask(T).to(x.device)
        x = self.transformer(x, mask=mask)
        x = x.transpose(0, 1)          # (B, T, d_model)
        x = self.ln_f(x)
        logits = self.head(x)          # (B, T, vocab_size)
        return logits
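
To try the file end to end, here is a minimal sketch. config.py is not part of this commit, so the Config below is an assumption: its field names are taken from the usages in model.py, and its values are purely illustrative.

    # config.py (hypothetical; only the field names are implied by model.py)
    class Config:
        vocab_size = 50257
        d_model = 256
        seq_len = 128
        n_heads = 8    # must evenly divide d_model
        n_layers = 4

With that in place, a quick smoke test of the forward pass:

    # smoke_test.py
    import torch
    from config import Config
    from model import GPT

    model = GPT()
    idx = torch.randint(0, Config.vocab_size, (2, 16))  # batch of 2, 16 tokens each
    logits = model(idx)
    print(logits.shape)  # torch.Size([2, 16, 50257]), i.e. (B, T, vocab_size)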
