
Commit 3b39dfd

full lab (part 2 benchmark untested)
1 parent 778df4f

File tree

3 files changed: +96 / -10 lines


xtra_labs/llm_finetune/draft.py

Lines changed: 95 additions & 8 deletions
@@ -1,11 +1,16 @@
 """
 Drafting lab flow in script format using PyTorch
 """
-
+from datasets import load_dataset
+import math
 import numpy as np
 import pandas as pd
 import tensorflow as tf
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import transformers
+from trl import SFTTrainer
 
 from utils import run_benchmark, make_spider_plot
 
@@ -63,14 +68,14 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 # Benchmark smaller model
 model_name_350m = "facebook/opt-350m"
 model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
-tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_350m)
+tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)
 
 category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, dataset)
 
 # Benchmark larger model
 model_name_2700m = "facebook/opt-2.7b"
 model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
-tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_2700m)
+tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)
 
 category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, dataset)

@@ -81,16 +86,98 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 
 # Part 2
 
-# new LoRA linear layer class
-
-# new attention layer class
+# inspect current model
+print(model)
 
-# replace attention modules with new module
+# new LoRA linear layer class
+class LoRALinear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 8,
+        lora_alpha: int = 1,
+        **kwargs
+    ):
+        self.r = r
+        self.in_features = in_features
+        self.out_features = out_features
+        self.lora_alpha = lora_alpha
+
+        nn.Linear.__init__(self, in_features, out_features, **kwargs)
+
+        # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+        if r > 0:
+            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+            self.scaling = self.lora_alpha / self.r
+            # Freezing the pre-trained weight matrix
+            self.weight.requires_grad = False
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+        nn.Linear.reset_parameters(self)
+        if hasattr(self, 'lora_A'):
+            # initialize A the same way as the default for nn.Linear and B to zero
+            # this is different than what is described in the paper but should not affect performance
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
+
+    def forward(self, x: torch.Tensor):
+        if self.r > 0:
+            result = F.linear(x, self.weight, bias=self.bias)
+            result += (x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+            return result
+        else:
+            return F.linear(x, self.weight, bias=self.bias)
+
+# replace linear layers in model recursively
+def replace_linear_with_lora(module):
+    for name, child in module.named_children():
+        if isinstance(child, nn.Linear):
+            setattr(module, name, LoRALinear(child.in_features, child.out_features))
+        else:
+            replace_linear_with_lora(child)
+
+replace_linear_with_lora(model)
+
+# inspect new model
+print(model)
 
 # load chat dataset
+dataset_name = "timdettmers/openassistant-guanaco"
+ft_dataset = load_dataset(dataset_name, split="train")
 
 # train model
+log_dir = "/scratch/checkpoints/test-sft/opt1.3b_768/"
+batch_size = 4
+context_length = 768
+args = transformers.TrainingArguments(log_dir,
+    per_device_train_batch_size=batch_size,
+    logging_first_step=True,
+    logging_steps=20,
+    save_steps=100,
+)
+
+class PrinterCallback(transformers.TrainerCallback):
+    def on_log(self, args, state, control, model, logs=None, **kwargs):
+        start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
+        generate(start_text, model, tokenizer, num_steps=200, until="###")
+
+trainer = SFTTrainer(
+    model,
+    args=args,
+    train_dataset=ft_dataset,
+    dataset_text_field="text",
+    max_seq_length=context_length,
+    callbacks=[PrinterCallback()]
+)
+trainer.train()
 
 # evaluate finetuned model on benchmark
+category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, dataset)
 
-# add to spider plot
+# add to spider plot
+benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
+make_spider_plot(benchmark_data)
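
Two hedged sketches follow as editorial sanity checks; they are not part of the commit, and they assume the LoRALinear class and the objects defined above are in scope.

First, because lora_B is zero-initialized, a LoRALinear layer should reproduce its frozen base nn.Linear exactly at construction, and only the low-rank factors (plus the bias, which this implementation leaves unfrozen) should require gradients:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    layer = LoRALinear(16, 32, r=8, lora_alpha=1)
    x = torch.randn(4, 16)

    # lora_B starts at zero, so the update (x @ A^T @ B^T) * scaling vanishes
    # and the output equals the plain frozen linear map
    assert torch.allclose(layer(x), F.linear(x, layer.weight, layer.bias))

    # trainable parameters: bias, lora_A, lora_B; the base weight is frozen
    print([n for n, p in layer.named_parameters() if p.requires_grad])

Second, a caveat: as committed, replace_linear_with_lora swaps in freshly initialized LoRALinear modules without copying the pretrained weight or bias, so every replaced projection loses its OPT weights before finetuning starts. (The commit message does flag Part 2 as untested; note too that the generate signature in the hunk context above has no until keyword, so the callback's until="###" argument may need adjusting.) A sketch of a weight-preserving variant, hypothetical and unverified:

    def replace_linear_with_lora_preserving(module):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                lora = LoRALinear(child.in_features, child.out_features,
                                  bias=child.bias is not None,
                                  device=child.weight.device,
                                  dtype=child.weight.dtype)
                # copy pretrained parameters so the swap is a no-op until
                # training updates lora_A and lora_B
                lora.weight.data.copy_(child.weight.data)
                if child.bias is not None:
                    lora.bias.data.copy_(child.bias.data)
                setattr(module, name, lora)
            else:
                replace_linear_with_lora_preserving(child)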

xtra_labs/llm_finetune/spider.png

14 KB

xtra_labs/llm_finetune/utils.py

Lines changed: 1 addition & 2 deletions
@@ -66,6 +66,7 @@ def make_spider_plot(data):
     """
     colors = ['#1aaf6c', '#429bf4', '#d42cea']
     i = 0
+    fig, ax = plt.subplots(figsize=(8,6), subplot_kw=dict(polar=True))
     for k,v in data.items():
         labels = v.index.tolist()
         values = v.values.tolist()
@@ -74,8 +75,6 @@ def make_spider_plot(data):
         angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
         angles += angles[:1]
         values += values[:1]
-
-        fig, ax = plt.subplots(figsize=(8,6), subplot_kw=dict(polar=True))
 
         ax.plot(angles, values, color=colors[i], linewidth=1, label=k)
         ax.fill(angles, values, color=colors[i], alpha=0.25)
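
Moving the plt.subplots call out of the per-series loop is the right fix: every model's trace now lands on one shared polar axes instead of each iteration spawning a fresh figure and discarding the previous one. A hedged usage sketch, with made-up numbers, assuming each dict value is a pandas Series of per-category accuracies (which matches the v.index / v.values access above):

    import pandas as pd
    from utils import make_spider_plot

    cats = ["category_a", "category_b", "category_c"]  # placeholder names
    data = {
        "350M-Model": pd.Series([0.25, 0.31, 0.28], index=cats),
        "2700M-Model": pd.Series([0.33, 0.38, 0.36], index=cats),
    }
    make_spider_plot(data)

One thing to verify before running the full lab: colors holds three entries, while draft.py now passes four models to make_spider_plot, so a fourth color is likely needed once the loop indexes colors[i] for the last series.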
