Commit d1e8daa

replace convo1D layers with linear
1 parent 50c3cc4 commit d1e8daa

File tree

docs/transformers/LoRA/GPT2.py
docs/transformers/LoRA/gpt2_state_dict.py

2 files changed: +13 -30 lines changed

docs/transformers/LoRA/GPT2.py

Lines changed: 4 additions & 30 deletions
@@ -14,37 +14,11 @@
 }
 
 
-# from transformers
-class Conv1D(nn.Module):
-    """
-    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-
-    Basically works like a linear layer but the weights are transposed.
-
-    Args:
-        nf (`int`): The number of output features.
-        nx (`int`): The number of input features.
-    """
-
-    def __init__(self, nf, nx):
-        super().__init__()
-        self.nf = nf
-        self.weight = nn.Parameter(torch.empty(nx, nf))
-        self.bias = nn.Parameter(torch.zeros(nf))
-        nn.init.normal_(self.weight, std=0.02)
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(size_out)
-        return x
-
-
 class HeadFFN(nn.Module): # todo rename
     def __init__(self, dim):
         super().__init__()
-        self.c_fc = Conv1D(dim, config['n_embd'])
-        self.c_proj = Conv1D(config['n_embd'], dim)
+        self.c_fc = nn.Linear(config['n_embd'], dim)
+        self.c_proj = nn.Linear(dim, config['n_embd'])
         self.act = nn.functional.gelu
 
     def forward(self, hidden_states):
@@ -62,8 +36,8 @@ def __init__(self):
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
 
-        self.c_att = Conv1D(config['n_embd'] * 3, config['n_embd'])
-        self.c_proj = Conv1D(config['n_embd'], config['n_embd'])
+        self.c_att = nn.Linear(config['n_embd'], config['n_embd'] * 3)
+        self.c_proj = nn.Linear(config['n_embd'], config['n_embd'])
 
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """

docs/transformers/LoRA/gpt2_state_dict.py

Lines changed: 9 additions & 0 deletions
@@ -32,4 +32,13 @@
     if old_key in state_dict:
         new_state_dict[new_key] = state_dict[old_key]
 
+# transpose weight matrices of convo 1d layers to use linear layers instead
+convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
+                [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
+                [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
+                [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])
+
+for layer in convo_layers:
+    new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
+
 torch.save(new_state_dict, 'transformed.pth')
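
After the transpose, each stored weight should match the (out_features, in_features) layout that nn.Linear expects. A hypothetical sanity check, assuming GPT-2 small (n_embd = 768) and the key names used above:

import torch
import torch.nn as nn

new_state_dict = torch.load('transformed.pth')

# nn.Linear.weight for c_att is (out_features, in_features) = (2304, 768);
# the original Conv1D weight was (768, 2304), hence the transpose above.
c_att = nn.Linear(768, 768 * 3)
assert new_state_dict['blocks.0.attn.c_att.weight'].shape == c_att.weight.shape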
