
Commit 1068af1

update eval
1 parent f22d72f commit 1068af1


2 files changed: +558 -9 lines changed


llm_judge/gen_model_answer_medusa.py

Lines changed: 53 additions & 9 deletions
@@ -24,7 +24,7 @@
 from medusa.model.kv_cache import initialize_past_key_values
 from medusa.model.medusa_choices import *
 
-def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
+def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, top_p=0.8, sampling = 'typical', fast = True, max_steps = 512):
     assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
     # Avoid modifying the input_ids in-place
     input_ids = input_ids.clone()
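medusa_forward now also takes a top_p cutoff, a sampling mode (defaulting to 'typical'), and a fast flag, all threaded through to candidate generation and verification below. For orientation only, this is roughly what a top-p (nucleus) cutoff does to a logit vector; it is a generic sketch, not code from this repo:

import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.8) -> torch.Tensor:
    # Illustrative nucleus filtering: keep the smallest set of tokens whose
    # cumulative probability exceeds top_p and mask the rest to -inf.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    # Drop a token only if the nucleus is already covered before it,
    # so the highest-probability token is always kept.
    drop = (cum_probs - probs) > top_p
    sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
    # Undo the sort so the filtered logits line up with the original vocab order.
    return sorted_logits.gather(-1, sorted_idx.argsort(dim=-1))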
@@ -71,6 +71,7 @@ def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, pos
             logits,
             medusa_buffers["tree_indices"],
             medusa_buffers["retrieve_indices"],
+            temperature, posterior_threshold, posterior_alpha, top_p, sampling, fast
         )
         medusa_logits, logits, outputs = tree_decoding(
             model,
@@ -81,7 +82,7 @@ def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, pos
             medusa_buffers["retrieve_indices"],
         )
         best_candidate, accept_length = evaluate_posterior(
-            logits, candidates, temperature, posterior_threshold, posterior_alpha
+            logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p, sampling, fast
         )
         input_ids, logits, medusa_logits, new_token = update_inference_inputs(
             input_ids,
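evaluate_posterior now receives top_p, sampling, and fast as well. With sampling set to 'typical', acceptance follows the Medusa-style typical-acceptance idea: a speculated token is kept if its probability under the base model beats an entropy-dependent threshold built from posterior_threshold and posterior_alpha. A minimal per-token sketch of that criterion, under those assumptions (the repo's actual evaluate_posterior additionally scores whole candidate paths and has a fast variant):

import torch

def typical_accept(step_logits, candidate_token, temperature,
                   posterior_threshold, posterior_alpha):
    # Sketch only: accept the speculated token when its probability exceeds
    # min(posterior_threshold, posterior_alpha * exp(-entropy)).
    probs = torch.softmax(step_logits / temperature, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-5)).sum(dim=-1)
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy),
    )
    return probs[..., candidate_token] > threshold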
@@ -117,6 +118,9 @@ def run_eval(
     temperature,
     posterior_threshold,
     posterior_alpha,
+    top_p,
+    sampling,
+    fast,
     medusa_choices,
 ):
     questions = load_questions(question_file, question_begin, question_end)
@@ -153,6 +157,9 @@ def run_eval(
                 temperature,
                 posterior_threshold,
                 posterior_alpha,
+                sampling,
+                top_p,
+                fast,
                 medusa_choices,
             )
         )
@@ -174,15 +181,22 @@ def get_model_answers(
     temperature,
     posterior_threshold,
     posterior_alpha,
+    sampling,
+    top_p,
+    fast,
     medusa_choices,
 ):
 
     # Medusa model setup
-    num_heads = 4
+
+    num_heads = -1
+    for choice in medusa_choices:
+        if len(choice) > num_heads:
+            num_heads = len(choice)
 
     model = MedusaModel.from_pretrained(
         model_path,
-        medusa_num_heads = num_heads,
+        # medusa_num_heads = num_heads,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
         device_map="auto"
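Instead of hard-coding num_heads = 4, the head count is now derived from the longest tuple in medusa_choices, and the explicit medusa_num_heads override passed to from_pretrained is commented out (presumably so the head count comes from the checkpoint's own config). The new loop is equivalent to a one-line max; with a small made-up choices list:

# Hypothetical medusa_choices: each tuple addresses one speculated token per
# Medusa head, so the longest tuple bounds how many heads a path can use.
medusa_choices = [(0,), (0, 0), (1,), (0, 0, 0), (0, 1)]

num_heads = max(len(choice) for choice in medusa_choices)
assert num_heads == 3  # same value the explicit loop in the diff computes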
@@ -200,7 +214,7 @@ def get_model_answers(
 
     # warmup
     for _ in range(3):
-        torch.manual_seed(0)
+        # torch.manual_seed(0)
         conv = get_conversation_template(model_id)
         turns = []
         idxs = []
@@ -227,9 +241,12 @@ def get_model_answers(
                 model,
                 tokenizer,
                 medusa_choices,
-                temperature,
+                0.7,
                 posterior_threshold,
                 posterior_alpha,
+                top_p=top_p,
+                sampling=sampling,
+                fast=fast,
             )
             torch.cuda.synchronize()
             total_time = time.time() - start_time
@@ -261,6 +278,7 @@ def get_model_answers(
                 if conv.name == "xgen" and output.startswith("Assistant:"):
                     output = output.replace("Assistant:", "", 1).strip()
             except RuntimeError as e:
+                print(e)
                 print("ERROR question ID: ", question["question_id"])
                 output = "ERROR"
 
@@ -280,7 +298,7 @@ def get_model_answers(
 
         choices = []
         for i in range(num_choices):
-            torch.manual_seed(i)
+            # torch.manual_seed(i)
             conv = get_conversation_template(model_id)
             turns = []
             idxs = []
@@ -310,6 +328,9 @@ def get_model_answers(
                     temperature,
                     posterior_threshold,
                     posterior_alpha,
+                    top_p=top_p,
+                    sampling=sampling,
+                    fast=fast,
                 )
                 torch.cuda.synchronize()
                 total_time = time.time() - start_time
@@ -456,19 +477,39 @@ def reorg_answer_file(answer_file):
         help="The posterior alpha for medusa sampling.",
     )
 
+    parser.add_argument(
+        "--top-p",
+        type=float,
+        default=0.8,
+        help="The top-p for medusa sampling.",
+    )
+
+    parser.add_argument(
+        "--sampling",
+        type=str,
+        default="typical",
+        help="The sampling method for medusa sampling.",
+    )
+
+    parser.add_argument(
+        "--fast",
+        action="store_true",
+        help="Whether to use fast decoding.",
+    )
+
     parser.add_argument(
         "--medusa-choices",
         type=str,
         default="mc_sim_7b_63",
         help="The medusa choices for medusa sampling.",
     )
 
-
+
 
 
     args = parser.parse_args()
 
-    args.model_id = args.model_id+"-temperature-"+str(args.temperature)+"-posterior_threshold-"+str(args.posterior_threshold)+"-posterior_alpha-"+str(args.posterior_alpha)
+    args.model_id = args.model_id+"-temperature-"+str(args.temperature)+"-posterior_threshold-"+str(args.posterior_threshold)+"-posterior_alpha-"+str(args.posterior_alpha)+"-top_p-"+str(args.top_p)+"-sampling-"+args.sampling+"-fast-"+str(args.fast)
     args.medusa_choices = eval(args.medusa_choices)
     if args.num_gpus_total // args.num_gpus_per_model > 1:
         import ray
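The model_id that tags the answer file now records top_p, sampling, and fast alongside the existing knobs, so runs with different settings end up in different files. With the new argparse defaults and placeholder values for the pre-existing options (the model id and the 0.7 / 0.09 / 0.3 below are illustrative, not taken from this diff), the concatenation yields:

# Placeholders for illustration; only the top_p / sampling / fast defaults
# are taken from the new argparse options above.
model_id = "medusa-vicuna-7b-v1.3"
temperature, posterior_threshold, posterior_alpha = 0.7, 0.09, 0.3
top_p, sampling, fast = 0.8, "typical", False

tag = (model_id
       + "-temperature-" + str(temperature)
       + "-posterior_threshold-" + str(posterior_threshold)
       + "-posterior_alpha-" + str(posterior_alpha)
       + "-top_p-" + str(top_p)
       + "-sampling-" + sampling
       + "-fast-" + str(fast))
# 'medusa-vicuna-7b-v1.3-temperature-0.7-posterior_threshold-0.09-posterior_alpha-0.3-top_p-0.8-sampling-typical-fast-False'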
@@ -499,6 +540,9 @@ def reorg_answer_file(answer_file):
         args.temperature,
         args.posterior_threshold,
         args.posterior_alpha,
+        args.top_p,
+        args.sampling,
+        args.fast,
         args.medusa_choices,
     )
 
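With the flags wired through to run_eval, a run exercising the new options could be launched along the lines of python gen_model_answer_medusa.py ... --top-p 0.8 --sampling typical --fast; only --top-p, --sampling, and --fast are introduced in this commit, and the remaining flags are assumed to follow the script's existing FastChat-style options.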
