Commit 778df4f

spider plot code
1 parent 5b48c67 commit 778df4f

File tree

xtra_labs/llm_finetune/draft.py
xtra_labs/llm_finetune/spider.png
xtra_labs/llm_finetune/utils.py

3 files changed: +65 -28 lines changed

xtra_labs/llm_finetune/draft.py

Lines changed: 15 additions & 19 deletions
@@ -7,32 +7,31 @@
 import tensorflow as tf
 import transformers
 
-from utils import run_benchmark
+from utils import run_benchmark, make_spider_plot
 
 # Part 1
 
 # TEXT: overview of LLM lab
 # Load pretrained LLM (medium size model)
 
-model_name = "facebook/opt-1.3b"
-model = transformers.TFAutoModelForCausalLM.from_pretrained(model_name)
+model_name = "facebook/opt-1.3b"
+# had to load the non-TF version to run the benchmarking code
+model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
 tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
 
 # TEXT: explain tokenizer
 # Include cell for tokenizer inspection
 
 # TEXT: explain how LLMs are trained for next token prediction
 # Write a function to predict next token
-
-def predict_next_token(probs):
+def predict_next_token(probs, tokenizer):
     new_token = np.random.choice(len(probs), p=probs.numpy())
     print(tokenizer.decode(new_token), end='', flush=True)
     return new_token
 
 # TEXT: explain that next token prediction must be called multiple times for inference
 # Call in loop for autoregressive inference
-
-def generate(start_text, num_steps=20, temp=1.):
+def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
     print(start_text, end="")
     x = tokenizer.encode(start_text)
     num_start = len(x)
@@ -42,46 +41,43 @@ def generate(start_text, num_steps=20, temp=1.):
         logits = model(input_tensor).logits
         probs = tf.nn.softmax(logits/temp)[0, -1, :]
 
-        new_token = predict_next_token(probs)
+        new_token = predict_next_token(probs, tokenizer)
         x.append(new_token)
 
     output = tokenizer.decode(x[num_start:])
     return output
 
 # Test autoregressive generation
-
 # while True:
 #     print("\n\n\n\n\n")
 #     input_text = input("Prompt: ")
-#     output = generate(input_text)
+#     output = generate(input_text, model, tokenizer)
 
 # TEXT: some background on LLM benchmarking
 # Load benchmark dataset and evaluate model
-
 dataset = pd.read_csv("benchmark.csv")
-category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer)
+category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, dataset)
 
 # TEXT: ask them to make a prediction on how accuracy will be affected by different model sizes
 
 # Benchmark smaller model
-
 model_name_350m = "facebook/opt-350m"
-model_350m = transformers.TFAutoModelForCausalLM.from_pretrained(model_name_350m)
+model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
 tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)
 
-category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m)
+category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, dataset)
 
 # Benchmark larger model
-
 model_name_2700m = "facebook/opt-2.7b"
-model_2700m = transformers.TFAutoModelForCausalLM.from_pretrained(model_name_2700m)
+model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
 tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)
 
-category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m)
+category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, dataset)
 
 # Spider plot
 
-print(category_accs_1300m)
+benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "2700M-Model": category_accs_2700m}
+make_spider_plot(benchmark_data)
 
 # Part 2
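Aside: the sampling step in generate() above scales logits by temperature before the softmax (tf.nn.softmax(logits/temp)) and then draws from the resulting distribution. A minimal standalone sketch of that step, using plain NumPy and a toy vocabulary invented for illustration (none of these names are part of the commit):

import numpy as np

def sample_with_temperature(logits, temp=1.0):
    # temp < 1 sharpens the distribution; temp > 1 flattens it toward uniform
    scaled = logits / temp
    probs = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    probs /= probs.sum()
    # draw a token id, as np.random.choice(len(probs), p=probs) does in the lab code
    return np.random.choice(len(probs), p=probs)

toy_logits = np.array([2.0, 1.0, 0.5, -1.0])
toy_vocab = ["the", "a", "an", "zebra"]
print(toy_vocab[sample_with_temperature(toy_logits, temp=0.7)])

At temp=1.0 this reproduces the raw softmax; as temp approaches 0 it approaches greedy argmax decoding.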

xtra_labs/llm_finetune/spider.png

70.4 KB

xtra_labs/llm_finetune/utils.py

Lines changed: 50 additions & 9 deletions
@@ -1,18 +1,16 @@
 """
 Contains functions that the students will not interface with
 """
-
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import tensorflow as tf
 import torch
 import torch.nn.functional as F
 from tqdm import tqdm
 
-dataset = pd.read_csv("benchmark.csv")
-
-def run_benchmark(model, tokenizer, few_shot=7, num_steps=500, verbose=False):
-    # device = model.device
+def run_benchmark(model, tokenizer, dataset, few_shot=7, num_steps=500, verbose=False):
+    device = model.device
     dataset["Correct"] = 0.0
 
     # Loop through every question in the benchmark
@@ -32,9 +30,8 @@ def run_benchmark(model, tokenizer, few_shot=7, num_steps=500, verbose=False):
 
         # Run the model
         with torch.no_grad():
-            # converting to tensorflow tensor for model input
-            x = tokenizer.encode(text, return_tensors="pt")
-            logits = model(x).logits.numpy()
+            x = tokenizer.encode(text, return_tensors="pt").to(device)
+            logits = model(x).logits
         probs = F.softmax(logits, dim=-1)[0, :-1, :]  # shape: [seq_len-1, vocab_size]
         y = x[0, 1:]  # shape: [seq_len-1]
 
@@ -60,4 +57,48 @@ def run_benchmark(model, tokenizer, few_shot=7, num_steps=500, verbose=False):
     sorted_accs = accs.sort_values()
     print(sorted_accs)
 
-    return sorted_accs, dataset["Correct"].mean()
+    return accs, dataset["Correct"].mean()
+
+def make_spider_plot(data):
+    """
+    Data is a dictionary where keys are different entities.
+    Values are pd Series where series indices are plot labels and series values show performance.
+    """
+    colors = ['#1aaf6c', '#429bf4', '#d42cea']
+    # create the figure once, outside the loop, so all entities share one set of axes
+    fig, ax = plt.subplots(figsize=(8, 6), subplot_kw=dict(polar=True))
+
+    i = 0
+    for k, v in data.items():
+        labels = v.index.tolist()
+        values = v.values.tolist()
+
+        num_vars = len(labels)
+        angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
+        # repeat the first point to close the polygon
+        angles += angles[:1]
+        values += values[:1]
+
+        ax.plot(angles, values, color=colors[i], linewidth=1, label=k)
+        ax.fill(angles, values, color=colors[i], alpha=0.25)
+
+        i += 1
+
+    ax.set_theta_offset(np.pi / 2)
+    ax.set_theta_direction(-1)
+    ax.set_thetagrids(np.degrees(angles[:-1]), labels)
+    for label, angle in zip(ax.get_xticklabels(), angles):
+        if angle in (0, np.pi):
+            label.set_horizontalalignment('center')
+        elif 0 < angle < np.pi:
+            label.set_horizontalalignment('left')
+        else:
+            label.set_horizontalalignment('right')
+
+    ax.set_ylim(0, 1)
+    ax.set_rlabel_position(180 / num_vars)
+
+    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
+
+    plt.savefig("spider.png")
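For intuition on the scoring logic in run_benchmark: the softmax output at position t is compared against the token observed at position t+1, hence the [:-1] / [1:] shift in the hunk above. A self-contained sketch with toy tensors (shapes and values invented for illustration):

import torch
import torch.nn.functional as F

x = torch.tensor([[3, 1, 4, 1, 5]])  # token ids, shape [1, seq_len]
logits = torch.randn(1, 5, 10)       # model output, shape [1, seq_len, vocab_size]

probs = F.softmax(logits, dim=-1)[0, :-1, :]  # predictions for each next position
y = x[0, 1:]                                  # the tokens that actually followed
token_probs = probs[torch.arange(len(y)), y]  # probability assigned to each observed next token
print(token_probs)  # shape: [seq_len-1]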

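A minimal usage sketch for the new make_spider_plot, with synthetic accuracy Series standing in for real run_benchmark output (category names and numbers are made up for illustration):

import pandas as pd
from utils import make_spider_plot

categories = ["math", "history", "science", "law", "coding"]
benchmark_data = {
    "350M-Model": pd.Series([0.31, 0.42, 0.38, 0.29, 0.25], index=categories),
    "1300M-Model": pd.Series([0.45, 0.55, 0.51, 0.40, 0.37], index=categories),
    "2700M-Model": pd.Series([0.52, 0.63, 0.58, 0.47, 0.44], index=categories),
}
make_spider_plot(benchmark_data)  # writes spider.png to the working directory

Each model traces one closed polygon over the shared category axes, so per-category gaps between model sizes are visible at a glance.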