Commit d69f1c1

LoRA experiment notes
1 parent 64959fd commit d69f1c1

2 files changed: +77 -28 lines


labml_nn/lora/experiment.py

Lines changed: 53 additions & 13 deletions

@@ -1,3 +1,16 @@
+"""
+---
+title: Finetune GPT-2 with LoRA
+summary: This is training code with notes for fine-tuning a pre-trained GPT-2 model with LoRA.
+---
+
+# Finetune GPT-2 with [LoRA](index.html)
+
+Here's a Colab notebook for fine-tuning GPT-2 with LoRA on the Tiny Shakespeare dataset.
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/lora/experiment.ipynb)
+"""
+
 import torch
 from labml import lab, monit, tracker
 from labml.configs import BaseConfigs, option
@@ -9,19 +22,31 @@
 from labml_nn.lora.gpt2 import GPTModel
 
 
-class Configs(BaseConfigs):
+class Trainer(BaseConfigs):
+    """
+    ## Trainer configurations and the training loop
+
+    The default configs can and will be overridden when we start the experiment
+    """
     device: torch.device = DeviceConfigs()
+
+    # GPT-2 configs
     layer_norm_epsilon: float = 1e-05
     n_embed: int = 768
     n_layer: int = 12
     n_positions: int = 1024
     vocab_size: int = 50257
+
+    # Training configs
     epochs: int = 10
     batch_size: int = 32
     learning_rate: float = 1e-4
     context_len: int = 512
-    r: int = 32
 
+    # LoRA rank
+    lora_r: int = 32
+
+    # Dataset
     text: TensorDataset = "tiny_shakespeare"
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
     model: GPTModel
@@ -30,10 +55,15 @@ class Configs(BaseConfigs):
     data_loader: DataLoader
 
     def _load_pretrained_weights(self):
-        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
+        """
+        ### Load pre-trained [GPT-2 from huggingface](https://huggingface.co/openai-community/gpt2)
+        """
 
+        # Load the huggingface model and get the parameters
+        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
         state_dict = hf_model.state_dict()
 
+        # Transformer embedding and prediction layer parameter mapping (`hf: ours`)
         mapping = {
             'transformer.wte.weight': 'token_embedding.weight',
             'transformer.wpe.weight': 'position_embedding.weight',
@@ -42,6 +72,7 @@ def _load_pretrained_weights(self):
             'lm_head.weight': 'lm_head.weight'
         }
 
+        # Mapping (`hf: ours`) of decoder layers
         for i in range(12):
             mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
             mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
@@ -56,12 +87,13 @@ def _load_pretrained_weights(self):
             mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
             mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'
 
+        # Move the parameters based on mapping
         new_state_dict = {}
         for old_key, new_key in mapping.items():
             if old_key in state_dict:
                 new_state_dict[new_key] = state_dict[old_key]
 
-        # transpose weight matrices of convo 1d layers to use linear layers instead
+        # Hugging Face's GPT-2 uses 1D convolution layers. We need to transpose those weights since we use linear layers
         convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
                         [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
                         [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
@@ -70,29 +102,37 @@ def _load_pretrained_weights(self):
         for layer in convo_layers:
             new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
 
+        # Load our model
         self.model.load_state_dict(new_state_dict, strict=False)  # state dict does not have lora weights
 
-        del hf_model
-        del state_dict
-        del new_state_dict
-
     def initialize(self):
+        """
+        ### Initialize the model, optimizer and dataloader
+        """
+        # Initialize the model
         self.model = GPTModel(
             layer_norm_epsilon=self.layer_norm_epsilon,
             n_embd=self.n_embed,
             n_layer=self.n_layer,
             n_positions=self.n_positions,
             vocab_size=self.vocab_size,
-            r=self.r,
-            device=self.device
-        ).to(self.device)
+            r=self.lora_r,
+        )
+        self.model.to(self.device)
+        # Load pre-trained model weights
         self._load_pretrained_weights()
 
+        # Initialize the optimizer
         self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)
 
+        # Initialize the data loader
         self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)
 
     def run(self):
+        """
+        ### Training loop
+        """
+
         for _ in monit.loop(self.epochs):
             for i, batch in monit.enum('Train', self.data_loader):
                 inputs = batch[0]
@@ -117,8 +157,8 @@ def run(self):
             tracker.new_line()
 
 
-@option(Configs.text)
-def tiny_shakespeare(c: Configs):
+@option(Trainer.text)
+def tiny_shakespeare(c: Trainer):
     """
     ### Tiny Shakespeare dataset
 
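
A note on the transpose step in `_load_pretrained_weights` above: Hugging Face's GPT-2 stores its attention and FFN projections in Conv1D modules whose weights have shape (in_features, out_features), whereas torch.nn.Linear expects (out_features, in_features). The following standalone sketch is not part of the commit (the 768/3072 sizes are simply the usual GPT-2 FFN dimensions, assumed here for illustration); it shows why the transposed weight gives an equivalent linear layer:

import torch
import torch.nn as nn

# Shapes of a GPT-2 FFN projection, assumed here purely for illustration
in_features, out_features = 768, 3072
conv1d_weight = torch.randn(in_features, out_features)  # HF Conv1D stores (in, out)
x = torch.randn(4, in_features)

# Hugging Face's Conv1D computes x @ W (+ b)
y_conv = x @ conv1d_weight

# An equivalent nn.Linear needs the transposed weight, as in torch.transpose(..., 0, 1) above
linear = nn.Linear(in_features, out_features, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv1d_weight.t())
y_linear = linear(x)

assert torch.allclose(y_conv, y_linear, atol=1e-5)
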
labml_nn/lora/gpt2.py

Lines changed: 24 additions & 15 deletions

@@ -6,7 +6,9 @@
 class FFN(nn.Module):
     def __init__(self, dim: int, n_embed: int, r: int):
         super().__init__()
+        # lin1
         self.c_fc = Linear(n_embed, dim, r=r, bias=True)
+        # lin2
         self.c_proj = Linear(dim, n_embed, r=r, bias=True)
         self.act = nn.functional.gelu
 
@@ -25,7 +27,9 @@ def __init__(self, n_embed: int, r: int):
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
 
+        # qkv
         self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
+        # out
         self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
 
     def _split_heads(self, tensor, num_heads, attn_head_size):
@@ -87,7 +91,7 @@ def forward(self, hidden_states):
 
 class GPTModel(nn.Module):
     def __init__(self, layer_norm_epsilon: float, n_embd: int, n_layer: int, n_positions: int,
-                 vocab_size: int, r: int, device: torch.device):
+                 vocab_size: int, r: int):
         super().__init__()
 
         self.token_embedding = Embedding(vocab_size, n_embd, r=r)
@@ -100,22 +104,27 @@ def __init__(self, layer_norm_epsilon: float, n_embd: int, n_layer: int, n_posit
 
         self.lm_head = Linear(n_embd, vocab_size, r=r, bias=False)
 
-        self.device = device
-
-    def forward(self, input_ids):
-        batch_size, input_shape = input_ids.size()
+    def forward(self, input_ids: torch.Tensor):
+        """
+        :param input_ids: has shape `[batch_size, seq_len]`
+        """
+        batch_size, seq_len = input_ids.shape
 
-        token_embeddings = self.token_embedding(input_ids)  # B T C
-        position_ids = torch.arange(input_shape, device=self.device)  # T C
-        position_embeddings = self.position_embedding(position_ids)  # B T C
+        # Get token embeddings
+        token_embeddings = self.token_embedding(input_ids)
+        # Get position ids
+        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
+        # Get position embeddings
+        position_embeddings = self.position_embedding(position_ids)
 
-        hidden_states = token_embeddings + position_embeddings
+        # Add position embeddings
+        x = token_embeddings + position_embeddings
 
+        # Run through transformer blocks
         for block in self.blocks:
-            hidden_states = block(hidden_states)
-
-        hidden_states = self.final_norm(hidden_states)
-
-        logits = self.lm_head(hidden_states)
+            x = block(x)
 
-        return logits
+        # Final normalization
+        x = self.final_norm(x)
+        # Get logits from projection layer
+        return self.lm_head(x)
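
The Linear and Embedding layers used throughout gpt2.py come from labml_nn.lora and take the LoRA rank r; their implementation is not shown in this commit. For orientation only, here is a rough sketch of a LoRA linear layer in the standard formulation (a hypothetical LoRALinear, not the labml class): the pre-trained weight stays frozen and only the low-rank factors A and B are trained, with the update scaled by alpha / r.

import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Sketch of a LoRA linear layer: y = x W^T + b + (alpha / r) * x A^T B^T."""

    def __init__(self, in_features: int, out_features: int, r: int, bias: bool = True, alpha: int = None):
        super().__init__()
        # Frozen pre-trained weight and bias (would be filled from the GPT-2 checkpoint)
        self.weight = nn.Parameter(torch.zeros(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) if bias else None
        # Trainable low-rank factors; B starts at zero so the layer initially matches the base model
        self.lora_a = nn.Parameter(torch.empty(r, in_features))
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        self.scaling = (alpha if alpha is not None else r) / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base projection plus the scaled low-rank update B(A x)
        y = nn.functional.linear(x, self.weight, self.bias)
        return y + self.scaling * nn.functional.linear(nn.functional.linear(x, self.lora_a), self.lora_b)

With a layer like this, load_state_dict(..., strict=False) in experiment.py fills in weight and bias from the checkpoint while leaving the LoRA factors at their initial values, which is consistent with the "state dict does not have lora weights" comment above.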

0 commit comments
