Skip to content

Commit 6294228

Browse files
committed
update medusa eval
1 parent 1593482 commit 6294228

File tree

6 files changed

+1984
-501
lines changed

6 files changed

+1984
-501
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,5 @@ notebooks/test*.ipynb
172172
notebooks/*.pdf
173173
llm_judge/*.sh
174174
llm_judge/data/mt_bench_test
175-
data
175+
data
176+
medusa/eval/*.sh

medusa/eval/README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
We use the [AlpacaEval](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/blob/0cd24d711fe90d0c1aae5bde03fe98ee48ae52f8/alpaca_eval.json) dataset to evaluate each head's accuracy during generation in `heads_accuracy.py`.
3+
4+
```
5+
python heads_accuracy.py --model_path 'FasterDecoding/medusa-vicuna-7b-v1.3' --model_name 'medusa-vicuna-7b-v1.3' --medusa_num_heads 5 --data_path '../../data/alpaca_eval.json'
6+
```
7+
8+
9+
To create the tree and plot the tree (requires `pygraphviz` package), please run:
10+
11+
```
12+
python gen_results.py --accuracy-path '../../data/medusa-vicuna-7b-v1.3_heads_accuracy.pt' --output-path '../../data/graph.jpg'
13+
```
14+
15+
If you want to use the tree, please add the generated tree (printed as a list of tuples under "Accepted Nodes") to `../model/medusa_choices.py`.
16+
17+
Citation:
18+
19+
```
20+
@misc{alpaca_eval,
21+
author = {Xuechen Li and Tianyi Zhang and Yann Dubois and Rohan Taori and Ishaan Gulrajani and Carlos Guestrin and Percy Liang and Tatsunori B. Hashimoto },
22+
title = {AlpacaEval: An Automatic Evaluator of Instruction-following Models},
23+
year = {2023},
24+
publisher = {GitHub},
25+
journal = {GitHub repository},
26+
howpublished = {\url{https://github.com/tatsu-lab/alpaca_eval}}
27+
}
```

medusa/eval/gen_results.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import matplotlib.pyplot as plt
2+
import copy
3+
import networkx as nx
4+
import torch
5+
import argparse
6+
7+
def load_accuracy_table(path, num_samples=16100):
    """Load saved per-head hit masks and reduce them to an accuracy table.

    The object stored at ``path`` is read with ``torch.load`` and iterated;
    every element is summed over its first dimension and divided by
    ``num_samples``. The resulting rows are stacked into one tensor.

    Args:
        path: Location of the serialized results (presumably the list written
            by ``heads_accuracy.py`` -- one tensor per head).
        num_samples: Divisor that turns the summed hit counts into a rate.
            Defaults to 16100, the value that used to be hard-coded here;
            pass the number of evaluated positions when the evaluation was
            run with a different dataset size or step count.

    Returns:
        A tensor with one row per element of the loaded sequence.
    """
    per_head_hits = torch.load(path)
    return torch.stack([hits.sum(0) / num_samples for hits in per_head_hits])
13+
14+
def get_node_expectation(accuracies, node):
    """Return the product of the per-depth accuracies along ``node``.

    Args:
        accuracies: Table indexed as ``accuracies[depth, choice]``.
        node: Sequence of choice indices, one per depth, starting at depth 0.

    Returns:
        ``accuracies[0, node[0]] * accuracies[1, node[1]] * ...``. The result
        is always a freshly computed value, so modifying it never touches
        ``accuracies``. An empty ``node`` yields the empty product ``1``.
    """
    # Starting from 1 and multiplying out of place creates a new value at
    # every step; this replaces the former deep copy + in-place multiply,
    # which was slower and is unsupported for non-leaf autograd tensors.
    expectation = 1
    for depth, choice in enumerate(node):
        expectation = expectation * accuracies[depth, choice]
    return expectation
19+
20+
def explore_graph(accuracies, max_depth, max_child, num_iterations):
    """Greedily grow a set of paths, always adding the best-scoring neighbor.

    The search starts from the single path ``(0,)``. On every iteration it
    gathers the neighbors of all accepted paths and accepts the one whose
    expectation (see ``get_node_expectation``) is highest. A path has up to
    two neighbors:

    * its next sibling (last index + 1), as long as the last index is below
      ``max_child[len(path) - 1] - 1``;
    * its first child (the path extended by ``0``), as long as the path is
      shorter than ``max_depth``.

    Args:
        accuracies: Table passed through to ``get_node_expectation``.
        max_depth: Maximum length of an accepted path.
        max_child: Per-depth limit on the child index; must have at least
            ``max_depth`` entries.
        num_iterations: Maximum number of paths to add after the root path.

    Returns:
        List of accepted paths (tuples of ints) in acceptance order, starting
        with ``(0,)``. The list may hold fewer than ``num_iterations + 1``
        entries when no remaining neighbor has a positive expectation.
    """
    root = (0,)
    accept_nodes = [root]
    # Set mirror of accept_nodes so membership tests are O(1).
    accepted = {root}
    # Cache of expectations already computed, keyed by path.
    explored_nodes = {root: get_node_expectation(accuracies, root)}

    for _ in range(num_iterations):
        best_neighbor = None
        best_neighbor_expectation = 0

        # Visit neighbors in acceptance order (sibling before child) so that
        # ties are resolved in favour of the earliest candidate.
        for node in accept_nodes:
            candidates = []
            if node[-1] < max_child[len(node) - 1] - 1:
                candidates.append(node[:-1] + (node[-1] + 1,))
            if len(node) < max_depth:
                candidates.append(node + (0,))

            for candidate in candidates:
                if candidate in accepted:
                    continue
                if candidate in explored_nodes:
                    candidate_expectation = explored_nodes[candidate]
                else:
                    candidate_expectation = get_node_expectation(accuracies, candidate)
                    explored_nodes[candidate] = candidate_expectation
                if candidate_expectation > best_neighbor_expectation:
                    best_neighbor = candidate
                    best_neighbor_expectation = candidate_expectation

        if best_neighbor is None:
            # Nothing left to add (tree exhausted or every remaining
            # candidate has zero expectation): stop instead of crashing.
            break

        accept_nodes.append(best_neighbor)
        accepted.add(best_neighbor)

    return accept_nodes
57+
58+
def plot_and_save_graph(accept_nodes, output_path):
    """Draw the tree formed by ``accept_nodes`` and write it to ``output_path``.

    Every prefix of every path becomes a graph node; the first element of a
    path hangs off a synthetic ``'root'`` node. The layout is computed with
    Graphviz ``dot`` through ``networkx.nx_agraph``.

    Args:
        accept_nodes: Iterable of paths (sequences of ints).
        output_path: File the figure is saved to.

    Raises:
        Whatever the layout, drawing or saving step raises (for example when
        pygraphviz is not installed). The figure is closed in every case.
    """
    G = nx.DiGraph()
    for path in accept_nodes:
        parent = 'root'
        for depth in range(1, len(path) + 1):
            child = tuple(path[:depth])
            G.add_edge(parent, child)
            parent = child

    fig = plt.figure(figsize=(40, 20))
    try:
        pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
        nx.draw(G, pos, with_labels=True, node_size=500, node_color="skyblue", font_size=10, width=2, edge_color="gray")
        plt.savefig(output_path)
    finally:
        # Release the large figure even if layout or saving failed.
        plt.close(fig)
75+
76+
def main():
    """Command-line entry point: build the tree from an accuracy file and plot it.

    Parses the command-line options, loads the accuracy table, runs the
    greedy exploration, prints the accepted nodes and then tries to save a
    plot of the tree. A plotting failure is reported on stdout but does not
    abort the program, because the printed node list is still usable.
    """
    parser = argparse.ArgumentParser(description="Generate Results.")
    parser.add_argument('--accuracy-path', type=str, required=True, help="Path to load accuracy tensor.")
    parser.add_argument('--output-path', type=str, required=True, help="Path to save the generated graph.")
    parser.add_argument('--max-depth', type=int, default=5, help="Maximum depth of the graph.")
    parser.add_argument('--num-iterations', type=int, default=62, help="Number of exploration iterations.")
    parser.add_argument('--max-child', nargs='+', type=int, default=[10, 10, 10, 10, 10], help="Maximum number of children per depth.")

    args = parser.parse_args()

    # The exploration indexes max_child by depth, so a short list would only
    # surface later as an IndexError; reject it up front with a usage error.
    if len(args.max_child) < args.max_depth:
        parser.error(
            f"--max-child needs at least --max-depth ({args.max_depth}) values, "
            f"got {len(args.max_child)}."
        )

    accuracies = load_accuracy_table(args.accuracy_path)
    accept_nodes = explore_graph(accuracies, args.max_depth, args.max_child, args.num_iterations)

    print("Accepted Nodes:", accept_nodes)

    # Plotting is best-effort: it depends on optional Graphviz tooling.
    try:
        plot_and_save_graph(accept_nodes, args.output_path)
        print(f"Graph saved to {args.output_path}.")
    except Exception as e:
        print(f"Failed to save the graph due to the following error: {e}")
        print("Ensure that Graphviz and pygraphviz are installed and set up correctly.")


if __name__ == "__main__":
    main()

medusa/eval/heads_accuracy.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
import torch
3+
import json
4+
from contextlib import contextmanager
5+
import numpy as np
6+
from medusa.model.medusa_model import MedusaModel
7+
from medusa.model.kv_cache import *
8+
from medusa.model.utils import *
9+
from medusa.model.medusa_choices import *
10+
from copy import deepcopy
11+
import matplotlib.pyplot as plt
12+
import torch.nn.functional as F
13+
from fastchat.model.model_adapter import get_conversation_template
14+
from tqdm import tqdm
15+
import argparse
16+
17+
def get_accuracies(medusa, logit):
    """Compare each head's candidates with the tokens that were generated.

    For head index ``h`` the candidates recorded at step ``t`` are compared
    with the token recorded at step ``t + h + 1``; the last ``h + 1`` steps
    have no matching token and are dropped.

    Args:
        medusa: Candidate token ids indexed as ``[step, head, k]``.
        logit: Generated token ids indexed by step along the first axis;
            only index 0 of the second axis is read (the caller in this file
            passes a tensor built by stacking ``(1, 1)`` tensors).

    Returns:
        A list with one boolean tensor per head, the element-wise result of
        ``eq`` between the aligned candidates and tokens.
    """
    num_heads = medusa.shape[1]
    return [
        medusa[:-head - 1, head].eq(logit[head + 1:, 0])
        for head in range(num_heads)
    ]
24+
25+
26+
27+
def main(args):
    """Greedy-decode every sample and save the per-head hit masks.

    For each entry of the JSON file at ``args.data_path`` a "vicuna"
    conversation prompt is built from ``sample["instruction"]``. The model is
    run once on the prompt and then ``args.steps`` more times on its own
    argmax token. At every pass the argmax token and the top-20 candidates of
    each Medusa head are recorded; ``get_accuracies`` turns them into
    per-head hit masks. The masks of all samples are concatenated per head
    and written to ``<args.save_dir>/<args.model_name>_heads_accuracy.pt``.

    Args:
        args: Parsed command-line namespace providing ``model_path``,
            ``medusa_num_heads``, ``data_path``, ``steps``, ``save_dir`` and
            ``model_name``.
    """
    model = MedusaModel.from_pretrained(
        args.model_path,
        medusa_num_heads=args.medusa_num_heads,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    tokenizer = model.get_tokenizer()

    # Context manager so the dataset file handle is closed right away.
    with open(args.data_path) as data_file:
        data = json.load(data_file)

    past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(
        model.base_model, model.medusa_num_decoder_layers
    )
    model.past_key_values = past_key_values
    model.past_key_values_data = past_key_values_data
    model.current_length_data = current_length_data

    topk = 20  # candidates kept per head at every pass
    # One pass over the prompt plus args.steps single-token passes.
    num_passes = max(args.steps, 0) + 1
    # per_head_chunks[h] collects the hit masks of head h, one per sample;
    # they are concatenated once at the end instead of after every sample.
    per_head_chunks = None

    for sample in tqdm(data):
        conv = get_conversation_template("vicuna")
        conv.messages = []
        conv.append_message(conv.roles[0], sample["instruction"])
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()
        logits_ids = []
        medusa_topk_ids = []

        with torch.inference_mode():
            input_ids = tokenizer([prompt]).input_ids
            next_input = torch.as_tensor(input_ids).cuda()
            model.current_length_data.zero_()  # this is for rerun
            reset_medusa_mode(model)
            for _ in range(num_passes):
                medusa_logits, _, logits = model(
                    next_input, past_key_values=past_key_values, output_orig=True
                )
                _, medusa_topk = medusa_logits[..., -1, :].topk(topk, dim=-1)
                # Greedy choice; it is also the input of the next pass.
                next_input = logits[:, -1:].argmax(dim=-1)
                logits_ids.append(next_input.detach().cpu())
                medusa_topk_ids.append(medusa_topk.detach().cpu())
            logits_ids = torch.stack(logits_ids, dim=0)
            medusa_topk_ids = torch.stack(medusa_topk_ids, dim=0).squeeze(2)
            sample_results = get_accuracies(medusa_topk_ids, logits_ids)

        if per_head_chunks is None:
            per_head_chunks = [[] for _ in sample_results]
        for chunks, hits in zip(per_head_chunks, sample_results):
            chunks.append(hits)

    # With no samples nothing was collected; None is saved, as before.
    results = None
    if per_head_chunks is not None:
        results = [torch.cat(chunks, dim=0) for chunks in per_head_chunks]

    save_path = os.path.join(args.save_dir, args.model_name + "_heads_accuracy.pt")
    torch.save(results, save_path)
87+
88+
if __name__ == "__main__":
    # Script entry point: parse the options, make sure the output directory
    # exists, then run the evaluation.
    parser = argparse.ArgumentParser(description="Medusa Model Evaluator")

    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the pre-trained Medusa model.")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name of the model.")
    parser.add_argument("--medusa_num_heads", type=int, default=5,
                        help="Number of medusa heads.")
    parser.add_argument("--data_path", type=str, required=True,
                        help="Path to the evaluation data in JSON format.")
    parser.add_argument("--save_dir", type=str, default="../../data",
                        help="Directory to save the results.")
    parser.add_argument("--steps", type=int, default=20,
                        help="Number of steps to run the model.")
    args = parser.parse_args()

    # exist_ok=True creates the directory when missing without the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(args.save_dir, exist_ok=True)
    main(args)

0 commit comments

Comments
 (0)