Commit 67649a9

added bias + no Linear inheritance
1 parent 2795ce3 commit 67649a9

File tree

1 file changed: +73 -38 lines changed

xtra_labs/llm_finetune/draft.py

Lines changed: 73 additions & 38 deletions
@@ -5,12 +5,16 @@
 import math
 import numpy as np
 import pandas as pd
+import random
 import tensorflow as tf
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from torch.optim import Adam
 import transformers
 from trl import SFTTrainer
+from tqdm import tqdm
 
 from utils import run_benchmark, make_spider_plot
 
@@ -61,8 +65,8 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 
 # TEXT: some background on LLM benchmarking
 # Load benchmark dataset and evaluate model
-dataset = pd.read_csv("benchmark.csv")
-# category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, dataset)
+benchmark_dataset = pd.read_csv("benchmark.csv")
+# category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, benchmark_dataset)
 
 # TEXT: ask them to make a prediction on how accuracy will be affected by different model sizes
 
@@ -71,14 +75,14 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 # model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
 # tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)
 
-# category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, dataset)
+# category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, benchmark_dataset)
 
 # Benchmark larger model
 # model_name_2700m = "facebook/opt-2.7b"
 # model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
 # tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)
 
-# category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, dataset)
+# category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, benchmark_dataset)
 
 # Spider plot
 
@@ -87,88 +91,119 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 
 # Part 2
 
+def count_grad_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
 # inspect current model
 # print(model)
+first_lin_layer = model.model.decoder.layers[0].self_attn.k_proj
+print(count_grad_parameters(model))
 
 # new LoRA linear layer class
-class LoRALinear(nn.Linear):
+class LoRALinear(nn.Module):
     def __init__(
         self,
         in_features: int,
         out_features: int,
         pretrained_weight: torch.Tensor,
+        pretrained_bias: torch.Tensor,
         r: int = 8,
         lora_alpha: int = 1,
         **kwargs
     ):
+        super(LoRALinear, self).__init__()
+
        self.r = r
        self.in_features = in_features
        self.out_features = out_features
        self.lora_alpha = lora_alpha
 
-        nn.Linear.__init__(self, in_features, out_features, **kwargs)
-        self.weight.data = pretrained_weight
+        self.weight = nn.Parameter(pretrained_weight)
+        self.weight.requires_grad = False
 
-        # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
-        if r > 0:
-            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
-            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
-            self.scaling = self.lora_alpha / self.r
-            # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
+        self.bias = nn.Parameter(pretrained_bias)
+        self.bias.requires_grad = False
 
+        # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+        self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+        self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+        self.scaling = self.lora_alpha / self.r
+
     def forward(self, x: torch.Tensor):
-        if self.r > 0:
-            result = F.linear(x, self.weight, bias=self.bias)
-            result += (x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
-            return result
-        else:
-            return F.linear(x, self.weight, bias=self.bias)
+        result = F.linear(x, self.weight, bias=self.bias)
+        result += self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1) * self.scaling
+        return result
 
 # replace linear layers in model recursively
 def replace_linear_with_lora(module):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            setattr(module, name, LoRALinear(child.in_features, child.out_features, child.weight))
+            setattr(module, name, LoRALinear(child.in_features, child.out_features, child.weight, child.bias))
             break
         else:
             replace_linear_with_lora(child)
 
 replace_linear_with_lora(model)
 
 # inspect new model
-# print(model)
+first_lin_layer = model.model.decoder.layers[0].self_attn.k_proj
+print(count_grad_parameters(model))
+exit()
 
 # load chat dataset
 dataset_name = "timdettmers/openassistant-guanaco"
 ft_dataset = load_dataset(dataset_name, split="train")
 
 # train model (barebones loop)
-batch_size = 4
 context_length = 768
+loss_fn = CrossEntropyLoss()
+
+learning_rate = 1e-4
+optimizer = Adam(model.parameters(), lr=learning_rate)
+num_epochs = 5
 
 model = model.to("cuda")
-for batch in ft_dataset:
-    prompt = batch["text"]
-
-    # encode with tokenizer
-    x = tokenizer.encode(prompt)
-    x_tensor = torch.tensor(x).view(1, -1).to("cuda")
-    input_tensor = x_tensor[:,:context_length]
-    target_next_word = x_tensor[:,context_length]
 
-    # run through model
-    logits = model(input_tensor).logits
+for epoch in range(num_epochs):
+    total_loss = 0
+    num_batches = 0
 
-    probs = F.softmax(logits, dim=-1)[0, -1, :].cpu().detach()
-    new_token = np.random.choice(len(probs), p=probs.numpy())
-    print(tokenizer.decode(new_token), end='', flush=True)
+    for batch in tqdm(ft_dataset):
+        prompt = batch["text"]
+
+        # encode with tokenizer
+        x = tokenizer.encode(prompt)
+        x_tensor = torch.tensor(x).view(1, -1).to("cuda")
+        max_len = min(context_length, x_tensor.shape[1]-1)
+        selected_len = random.randint(1,max_len)
+
+        input_tensor = x_tensor[:,:selected_len]
+        target_tensor = x_tensor[0,1:selected_len+1]
+
+        # zero gradients
+        optimizer.zero_grad()
+
+        # run through model
+        logits = model(input_tensor).logits[0]
+
+        # apply loss
+        loss = loss_fn(logits, target_tensor)
+
+        # backpropagation
+        loss.backward()
+
+        # optimizer step
+        optimizer.step()
 
-    # apply loss
+        total_loss += loss.item()
+        num_batches += 1
 
+    # Print average loss for the epoch
+    average_loss = total_loss / num_batches
+    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")
 
 # evaluate finetuned model on benchmark
-category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, dataset)
+category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)
 
 # add to spider plot
 # benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
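
A note on the rewritten forward: the new low-rank term self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1) * self.scaling no longer involves the input x, whereas both the deleted code and the referenced loralib layers.py compute (x @ lora_A^T @ lora_B^T) * scaling. Below is a minimal, hypothetical sketch (class name and init choices are illustrative, not part of this commit) that keeps the frozen weight/bias handling introduced here but routes the input through the low-rank matrices, and gives lora_A the random init loralib uses so the adapter can actually learn:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinearSketch(nn.Module):
    """Hypothetical variant of the LoRALinear above: frozen pretrained weight and
    bias, plus a low-rank update that is applied to the input x."""

    def __init__(self, in_features, out_features, pretrained_weight, pretrained_bias,
                 r: int = 8, lora_alpha: int = 1):
        super().__init__()
        # frozen pretrained parameters, as in the committed class
        self.weight = nn.Parameter(pretrained_weight, requires_grad=False)
        self.bias = nn.Parameter(pretrained_bias, requires_grad=False)
        # trainable low-rank factors; loralib zero-inits lora_B but randomizes lora_A
        self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
        self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # frozen base projection
        result = F.linear(x, self.weight, bias=self.bias)
        # input-dependent low-rank update: (x A^T) B^T, scaled by alpha / r
        lora_update = x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
        return result + lora_update * self.scaling

Because lora_B starts at zero, the layer initially reproduces the frozen pretrained projection exactly; only the lora_A/lora_B factors receive gradients.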

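For the new training loop, the loss call leans on CrossEntropyLoss's convention of (N, C) logits scored against (N,) integer targets, with the targets being the input tokens shifted left by one. A tiny self-contained illustration with made-up shapes (only the variable names echo draft.py; nothing here is from the commit):

import torch
from torch.nn import CrossEntropyLoss

# Stand-in for a tokenized prompt of selected_len + 1 tokens.
vocab_size, selected_len = 32, 5
x_tensor = torch.randint(0, vocab_size, (1, selected_len + 1))

input_tensor = x_tensor[:, :selected_len]        # shape (1, selected_len): tokens 0..L-1
target_tensor = x_tensor[0, 1:selected_len + 1]  # shape (selected_len,): tokens 1..L

# Stand-in for model(input_tensor).logits[0]: one row of logits per input position.
logits = torch.randn(selected_len, vocab_size)

loss_fn = CrossEntropyLoss()
loss = loss_fn(logits, target_tensor)            # (N, C) logits vs. (N,) class indices
print(loss.item())

Each position's logits are scored against the token that immediately follows it, which is the usual causal language-modeling objective.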