Commit cbc38bb

GPT 2 implementation
1 parent 89a3ae8

File tree

2 files changed: +274 -0 lines changed


docs/transformers/LoRA/GPT2.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
import math

import torch
import torch.nn as nn
from torch import Tensor
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# config from GPT-2 (the "gpt2" checkpoint)
config = {
    "_name_or_path": "gpt2",
    "activation_function": "gelu_new",
    "architectures": [
        "GPT2LMHeadModel"
    ],
    "attn_pdrop": 0.1,
    "bos_token_id": 50256,
    "embd_pdrop": 0.1,
    "eos_token_id": 0,
    "initializer_range": 0.02,
    "layer_norm_epsilon": 1e-05,
    "model_type": "gpt2",
    "n_ctx": 1024,
    "n_embd": 768,
    "n_head": 12,
    "n_inner": None,
    "n_layer": 12,
    "n_positions": 1024,
    "reorder_and_upcast_attn": False,
    "resid_pdrop": 0.1,
    "scale_attn_by_inverse_layer_idx": False,
    "scale_attn_weights": True,
    "summary_activation": None,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": True,
    "summary_type": "cls_index",
    "summary_use_proj": True,
    "task_specific_params": {
        "text-generation": {
            "do_sample": True,
            "max_length": 50
        }
    },
    "transformers_version": "4.42.4",
    "use_cache": True,
    "vocab_size": 50257
}


# from transformers
class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x
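
# Quick equivalence check (illustrative sketch; names prefixed with "_" are throwaway):
# Conv1D(nf, nx) computes x @ W + b with W stored as (nx, nf), i.e. the same map as
# nn.Linear(nx, nf), whose weight is kept transposed as (nf, nx). This is also why the
# GPT-2 checkpoint weights can later be copied over without transposing them.
_lin = nn.Linear(4, 8)
_conv = Conv1D(8, 4)
with torch.no_grad():
    _conv.weight.copy_(_lin.weight.T)
    _conv.bias.copy_(_lin.bias)
_x = torch.randn(2, 3, 4)
assert torch.allclose(_conv(_x), _lin(_x), atol=1e-6)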


# from transformers
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
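
# Sanity check (illustrative sketch): this is the tanh approximation of GELU and should
# match PyTorch's built-in gelu with approximate='tanh'.
_x = torch.randn(5)
assert torch.allclose(NewGELUActivation()(_x), nn.functional.gelu(_x, approximate='tanh'), atol=1e-6)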


class HeadFFN(nn.Module):  # todo rename
    def __init__(self, dim):
        super().__init__()
        self.c_fc = Conv1D(dim, config['n_embd'])
        self.c_proj = Conv1D(config['n_embd'], dim)
        self.act = NewGELUActivation()
        self.dropout = nn.Dropout(config['resid_pdrop'])

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class MultiHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = config['n_embd']
        self.num_heads = config['n_head']
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        self.c_att = Conv1D(config['n_embd'] * 3, config['n_embd'])
        self.c_proj = Conv1D(config['n_embd'], config['n_embd'])

        self.resid_dropout = nn.Dropout(config['resid_pdrop'])
        self.attn_dropout = nn.Dropout(config['attn_pdrop'])

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, hidden_states):
        batch_size, seq_length, _ = hidden_states.size()

        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=True,  # for the triangular mask
        )

        # merge the heads back: (batch, head, seq, head_dim) -> (batch, seq, head, head_dim)
        # -> (batch, seq, embed_dim); contiguous() is needed so view() can flatten the last
        # two dimensions after the transpose
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.attn = MultiHead()
        self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.ffn = HeadFFN(config['n_embd'] * 4)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.pre_norm(hidden_states)

        attn_output = self.attn(hidden_states)

        hidden_states = attn_output + residual
        residual = hidden_states
        hidden_states = self.post_norm(hidden_states)
        feed_forward_output = self.ffn(hidden_states)
        hidden_states = feed_forward_output + residual

        return hidden_states


class GPTModel(nn.Module):
    # todo ignored token type embeds, past key values
    def __init__(self):
        super().__init__()

        self.token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
        self.position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])

        self.dropout = nn.Dropout(p=config['embd_pdrop'], inplace=False)

        self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

        self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])

        self.lm_head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.size()

        token_embeddings = self.token_embedding(input_ids)  # (B, T, C)
        position_ids = torch.arange(seq_length, device=input_ids.device)  # (T,)
        position_embeddings = self.position_embedding(position_ids)  # (T, C), broadcasts over the batch

        embeddings = token_embeddings + position_embeddings

        hidden_states = self.dropout(embeddings)

        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.final_norm(hidden_states)

        logits = self.lm_head(hidden_states)

        return logits


model = GPTModel()

# load the GPT-2 weights converted by the mapping script below
state_dict = torch.load('transformed.pth')

missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
    print(f"Missing keys: {missing_keys}")
if unexpected_keys:
    print(f"Unexpected keys: {unexpected_keys}")

prompt = "hello how are you"
tokenized = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    model.eval()
    res = model(tokenized['input_ids'])
    print(res)

# greedy choice: for every prompt position, pick the most likely next token
output_ids = torch.argmax(res, dim=-1)

# decode the predicted token indices back to text and print them
output_text = tokenizer.decode(output_ids[0])
print(output_text)
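
# Optional extension (illustrative sketch): the forward pass above only predicts one token
# per prompt position. A minimal greedy decoding loop feeds the last position's argmax back
# into the model to generate a continuation (20 tokens here, chosen arbitrarily).
generated = tokenized['input_ids']
with torch.no_grad():
    for _ in range(20):
        logits = model(generated)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (B, 1)
        generated = torch.cat([generated, next_token], dim=1)
print(tokenizer.decode(generated[0]))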
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

state_dict = model.state_dict()

mapping = {
    'transformer.wte.weight': 'token_embedding.weight',
    'transformer.wpe.weight': 'position_embedding.weight',
    'transformer.ln_f.weight': 'final_norm.weight',
    'transformer.ln_f.bias': 'final_norm.bias',
    'lm_head.weight': 'lm_head.weight'
}

for i in range(12):
    mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
    mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
    mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
    mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
    mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
    mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
    mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
    mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
    mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
    mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
    mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
    mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

new_state_dict = {}
for old_key, new_key in mapping.items():
    if old_key in state_dict:
        new_state_dict[new_key] = state_dict[old_key]

torch.save(new_state_dict, 'transformed.pth')
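
# Optional check (illustrative sketch): report any Hugging Face keys the mapping does not
# cover. Depending on the transformers version, the checkpoint may also expose attention
# mask buffers ('attn.bias' / 'attn.masked_bias'); they can be ignored here because the
# re-implementation builds its causal mask via is_causal=True. Note also that no weight
# needs transposing, since the re-implementation keeps the same Conv1D (n_in, n_out)
# layout as the original checkpoint.
unmapped = [key for key in state_dict if key not in mapping]
print("Unmapped HF keys:", unmapped)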
