# T1_model.py
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
# Hyperparameters
context_length = 128 # Number of tokens in each training chunk (the context window)
d_model = 512 # The size of our model token embeddings
num_blocks = 12 # Number of transformer blocks
num_heads = 8 # Number of heads in Multi-head attention
dropout = 0.1 # Dropout rate
device = 'cuda' if torch.cuda.is_available() else 'cpu' # Use GPU if it's available.
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)
# Define feed forward network
class FeedForwardNetwork(nn.Module):
def __init__(self):
super().__init__()
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.ReLU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(dropout)
)
def forward(self, x):
return self.ffn(x)
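# Note on FeedForwardNetwork: the hidden layer expands to 4 * d_model (2048 here),
# the expansion ratio used for the position-wise feed-forward network in the
# original Transformer, before projecting back down to d_model.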
# Define Scaled Dot Product Attention
class Attention(nn.Module):
def __init__(self):
super().__init__()
self.Wq = nn.Linear(d_model, d_model // num_heads, bias=False)
self.Wk = nn.Linear(d_model, d_model // num_heads, bias=False)
self.Wv = nn.Linear(d_model, d_model // num_heads, bias=False)
self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, T, C = x.shape
q = self.Wq(x)
k = self.Wk(x)
v = self.Wv(x)
weights = (q @ k.transpose(-2, -1)) / math.sqrt(d_model // num_heads)
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
weights = self.dropout(weights)
output = weights @ v
return output
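# Shape walk-through for one head (d_model=512, num_heads=8, so head size 64):
#   x:       (B, T, 512)
#   q, k, v: (B, T, 64)
#   weights: (B, T, T)   causally masked, softmaxed attention scores
#   output:  (B, T, 64)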
# Define Multi-head Attention
class MultiHeadAttention(nn.Module):
def __init__(self):
super().__init__()
self.heads = nn.ModuleList([Attention() for _ in range(num_heads)])
self.projection_layer = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
head_outputs = [head(x) for head in self.heads]
head_outputs = torch.cat(head_outputs, dim=-1)
out = self.dropout(self.projection_layer(head_outputs))
return out
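# The num_heads per-head outputs of size d_model // num_heads are concatenated back
# to (B, T, d_model) before the output projection, so the block preserves its input shape.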
# Define Transformer Block
class TransformerBlock(nn.Module):
def __init__(self):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.mha = MultiHeadAttention()
self.ffn = FeedForwardNetwork()
def forward(self, x):
x = x + self.mha(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
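# Note: this is the pre-LayerNorm layout (normalize, then attend / feed forward, then
# add the residual), which is generally more stable to train in deep stacks than the
# post-LayerNorm layout of the original Transformer.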
# Define the model
class Model(nn.Module):
    def __init__(self, max_token_value=100256):  # defaults to the tiktoken cl100k_base vocab size when not passed
super().__init__()
self.token_embedding_lookup_table = nn.Embedding(max_token_value, d_model)
self.transformer_blocks = nn.Sequential(*(
[TransformerBlock() for _ in range(num_blocks)] +
[nn.LayerNorm(d_model)]
))
self.model_out_linear_layer = nn.Linear(d_model, max_token_value)
def forward(self, idx, targets=None):
B, T = idx.shape
position_encoding_lookup_table = torch.zeros(context_length, d_model, device=device)
position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
        # Slice the (context_length, d_model) positional table down to the first T positions -> (T, d_model)
position_embedding = position_encoding_lookup_table[:T, :].to(device)
x = self.token_embedding_lookup_table(idx) + position_embedding
x = self.transformer_blocks(x)
# get the final logits
logits = self.model_out_linear_layer(x)
if targets is not None:
B, T, C = logits.shape
logits_reshaped = logits.view(B * T, C)
targets_reshaped = targets.view(B * T)
loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
else:
loss = None
return logits, loss
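    # Note: F.cross_entropy expects (N, C) logits and (N,) integer class targets,
    # which is why the (B, T, C) logits and (B, T) targets are flattened above.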
def generate(self, idx, max_new_tokens=100):
# idx is (B,T) array of indices in the current context
for _ in range(max_new_tokens):
# Crop idx to the max size of our positional embeddings table
idx_crop = idx[:, -context_length:]
# Get predictions
            logits, _ = self(idx_crop)  # calling the module invokes forward(); loss is None without targets
# Get the last time step from logits where the dimensions of the logits are (B,T,C)
logits_last_timestep = logits[:, -1, :]
# Apply softmax to get probabilities
probs = F.softmax(input=logits_last_timestep, dim=-1)
            # Sample the next token id from the probability distribution
idx_next = torch.multinomial(input=probs, num_samples=1)
            # Append the sampled token to the running sequence
idx = torch.cat((idx, idx_next), dim=1)
return idx
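# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original training pipeline).
# Assumptions: the tiktoken cl100k_base encoding matches the default
# max_token_value, and the weights are untrained, so the generated
# continuation is effectively random tokens.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    encoding = tiktoken.get_encoding('cl100k_base')
    model = Model().to(device)
    model.eval()
    # Encode a short prompt into a (1, T) batch of token ids on the chosen device.
    prompt_ids = encoding.encode('Hello, world')
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        generated = model.generate(idx, max_new_tokens=20)
    print(encoding.decode(generated[0].tolist()))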