forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_longbench_v2.py
More file actions
781 lines (645 loc) · 28.2 KB
/
eval_longbench_v2.py
File metadata and controls
781 lines (645 loc) · 28.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
#!/usr/bin/env python3
"""
LongBench v2 evaluation script with TensorRT-LLM and sparse attention.
Usage:
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/
# Run all LongBench v2 samples
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/
# Enable CoT reasoning
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --cot
# Run with different difficulty levels
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --difficulty easy
"""
import argparse
import json
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from transformers import AutoTokenizer
# Add tensorrt_llm imports
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
RocketSparseAttentionConfig)
from tensorrt_llm.logger import logger
# Chat templates mapping
CHAT_TEMPLATES = {
"llama3.1-8b-instruct": "llama3",
"llama3-8b-instruct": "llama3",
"mistral-7b-instruct-v0.2": "mistral",
"longchat-7b-v1.5-32k": "vicuna"
}
def parse_arguments() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="LongBench v2 evaluation with TensorRT-LLM and RocketKV")
# Model and data arguments
parser.add_argument('--model_path',
type=str,
required=True,
help='Path to model (HF model name or local path)')
parser.add_argument('--longbench_path',
type=str,
default='./LongBench',
help='Path to LongBench directory')
# Output arguments
parser.add_argument('--output_dir',
type=str,
required=True,
help='Directory to save results')
parser.add_argument('--exp_name',
type=str,
default=None,
help='Experiment name (auto-generated if not provided)')
# Model configuration
parser.add_argument('--attention_backend',
type=str,
default='VANILLA',
choices=['VANILLA', 'TRTLLM', 'FLASHINFER'],
help='Attention backend to use')
parser.add_argument('--backend',
type=str,
default='pytorch',
choices=['pytorch', 'tensorrt'],
help='LLM backend to use')
parser.add_argument('--chat_template',
type=str,
default='auto',
help='Chat template to use (auto-detect if "auto")')
# Sequence and batch configuration
parser.add_argument('--max_seq_len',
type=int,
default=133120,
help='Maximum sequence length')
parser.add_argument('--max_batch_size',
type=int,
default=1,
help='Maximum batch size')
parser.add_argument('--max_new_tokens',
type=int,
default=256,
help='Maximum new tokens to generate')
parser.add_argument(
'--max_num_tokens',
type=int,
default=133120,
help='Maximum total tokens across all sequences in a batch')
parser.add_argument('--tensor_parallel_size',
type=int,
default=1,
help='Tensor parallel size')
# RocketKV configuration
parser.add_argument('--rocket_sparse',
action='store_true',
help='Use rocket sparse attention')
parser.add_argument('--token_budget',
type=int,
default=2048,
help='Token budget for RocketKV (prompt_budget)')
parser.add_argument('--window_size',
type=int,
default=32,
help='Window size for RocketKV')
parser.add_argument('--kernel_size',
type=int,
default=63,
help='Kernel size for RocketKV')
parser.add_argument('--topk',
type=int,
default=64,
help='Top-k for RocketKV')
parser.add_argument('--kt_cache_dtype',
type=str,
default='float8_e5m2',
choices=['bfloat16', 'float8_e5m2'],
help='KT cache data type')
# KV cache configuration
parser.add_argument('--kv_cache_dtype',
type=str,
default='auto',
help='KV cache data type')
parser.add_argument('--kv_cache_fraction',
type=float,
default=0.7,
help='Fraction of GPU memory for KV cache')
# LongBench v2 specific arguments
parser.add_argument('--cot',
action='store_true',
help='Enable Chain-of-Thought reasoning')
parser.add_argument('--no_context',
action='store_true',
help='Test without long context (pure memorization)')
parser.add_argument('--rag',
type=int,
default=0,
help='Use top-N retrieved contexts (0 to disable)')
# Evaluation parameters
parser.add_argument('--num_samples',
type=int,
default=None,
help='Number of samples to evaluate (None for all)')
parser.add_argument('--start_idx',
type=int,
default=0,
help='Start index for evaluation')
parser.add_argument('--difficulty',
type=str,
choices=['easy', 'hard'],
default=None,
help='Filter by difficulty level')
parser.add_argument('--length',
type=str,
choices=['short', 'medium', 'long'],
default=None,
help='Filter by length category')
parser.add_argument('--domain',
type=str,
default=None,
help='Filter by domain')
parser.add_argument('--max_len',
type=int,
default=120000,
help='Maximum prompt length in tokens for truncation')
# System arguments
parser.add_argument('--log_level',
type=str,
default='info',
choices=['debug', 'info', 'warning', 'error'],
help='Logging level')
parser.add_argument('--seed', type=int, default=42, help='Random seed')
return parser.parse_args()
def load_longbench_v2_config(longbench_path: str) -> Dict[str, Any]:
"""Load LongBench v2 configuration files."""
config_dir = os.path.join(longbench_path, "config")
# Load model2maxlen.json for v2
maxlen_file = os.path.join(config_dir, "model2maxlen.json")
with open(maxlen_file, 'r', encoding='utf-8') as f:
model2maxlen = json.load(f)
# Load prompt templates
prompts_dir = os.path.join(longbench_path, "prompts")
templates = {}
template_files = {
'0shot': '0shot.txt',
'0shot_cot': '0shot_cot.txt',
'0shot_cot_ans': '0shot_cot_ans.txt',
'0shot_no_context': '0shot_no_context.txt',
'0shot_rag': '0shot_rag.txt'
}
for template_name, filename in template_files.items():
template_path = os.path.join(prompts_dir, filename)
if os.path.exists(template_path):
with open(template_path, 'r', encoding='utf-8') as f:
templates[template_name] = f.read()
return {'model2maxlen': model2maxlen, 'templates': templates}
def build_chat(tokenizer, prompt, chat_template):
"""Build chat prompt following LongBench's approach."""
if chat_template == "vicuna" or chat_template == "longchat":
try:
from fastchat.model import get_conversation_template
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
except ImportError:
# Fallback if fastchat is not available
prompt = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {prompt}\nASSISTANT:"
elif chat_template == "llama2":
prompt = f"[INST]{prompt}[/INST]"
elif chat_template == "llama3":
prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
elif chat_template == "mistral":
prompt = f"[INST] {prompt} [/INST]"
# For other templates or "none", return prompt as-is
return prompt
def determine_chat_template(model_path: str, chat_template: str) -> str:
"""Determine chat template based on model path."""
if chat_template != 'auto':
return chat_template
model_path_lower = model_path.lower()
for model_key, template in CHAT_TEMPLATES.items():
if model_key.replace('-', '').replace('.',
'') in model_path_lower.replace(
'-', '').replace('.', ''):
return template
# Default fallback
if 'llama' in model_path_lower:
return 'llama3'
elif 'mistral' in model_path_lower:
return 'mistral'
else:
return 'none' # No special formatting
def extract_answer(response: str) -> Optional[str]:
"""Extract answer from response following LongBench v2's approach."""
response = response.replace('*', '')
# Try to extract answer in format "The correct answer is (X)"
match = re.search(r'The correct answer is \(([A-D])\)', response)
if match:
return match.group(1)
# Try to extract answer in format "The correct answer is X"
match = re.search(r'The correct answer is ([A-D])', response)
if match:
return match.group(1)
# Try to extract any single letter A-D
match = re.search(r'\b([A-D])\b', response)
if match:
return match.group(1)
return None
def post_process(pred: str, chat_template: str) -> str:
"""Post-process prediction following LongBench v2's approach."""
pred = pred.split("</s")[0].strip()
if chat_template == "qwen":
pred = pred.split("<|im_end|>")[0]
elif "llama2" in chat_template.lower():
pred = (pred.split("(Document")[0].split("\n\nQuestion")[0].split(
"\n\nAnswer")[0].split("[INST]")[0].split("[/INST]")[0].split(
"(Passage")[0].strip())
return pred
def truncate_prompt(prompt: str, tokenizer: AutoTokenizer, max_len: int) -> str:
"""Truncate prompt following LongBench v2's approach."""
# Encode the prompt using the tokenizer
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
# If prompt exceeds max_len, truncate by taking first half and last half
if len(input_ids) > max_len:
half = max_len // 2
truncated_ids = input_ids[:half] + input_ids[-half:]
# Decode back to text
prompt = tokenizer.decode(truncated_ids, skip_special_tokens=True)
return prompt
def format_prompt(sample: Dict[str, Any], template: str,
args: argparse.Namespace) -> str:
"""Format prompt for LongBench v2."""
context = sample['context']
# Handle RAG mode
if args.rag > 0 and 'retrieved_context' in sample:
retrieved = sample["retrieved_context"][:args.rag]
retrieved = sorted(retrieved, key=lambda x: x.get('c_idx', 0))
context = '\n\n'.join([
f"Retrieved chunk {idx+1}: {x['content']}"
for idx, x in enumerate(retrieved)
])
# Handle no context mode
if args.no_context:
context = ""
# Format the prompt using the template
prompt = template.replace('$DOC$', context.strip())
prompt = prompt.replace('$Q$', sample['question'].strip())
prompt = prompt.replace('$C_A$', sample['choice_A'].strip())
prompt = prompt.replace('$C_B$', sample['choice_B'].strip())
prompt = prompt.replace('$C_C$', sample['choice_C'].strip())
prompt = prompt.replace('$C_D$', sample['choice_D'].strip())
return prompt
def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
"""Initialize LLM and tokenizer."""
logger.info(f"Initializing LLM with model: {args.model_path}")
try:
# Configure KV cache
kv_cache_config = KvCacheConfig(
enable_block_reuse=False, # RocketKV doesn't support KV cache reuse
)
if args.rocket_sparse:
# Configure RocketKV sparse attention
sparse_attention_config = RocketSparseAttentionConfig(
window_size=args.window_size,
kernel_size=args.kernel_size,
prompt_budget=args.token_budget,
topk=args.topk,
kt_cache_dtype=args.kt_cache_dtype,
)
logger.info(f"Using RocketKV sparse attention")
else:
sparse_attention_config = None
logger.info("Using standard attention")
cuda_graph_config = CudaGraphConfig(
max_batch_size=args.max_batch_size
) if args.attention_backend == "TRTLLM" else None
# Initialize LLM
llm = LLM(
model=args.model_path,
backend=args.backend,
kv_cache_config=kv_cache_config,
attn_backend=args.attention_backend,
sparse_attention_config=sparse_attention_config,
tensor_parallel_size=args.tensor_parallel_size,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
cuda_graph_config=cuda_graph_config,
)
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
logger.info("LLM and tokenizer initialized successfully")
return llm, tokenizer
except Exception as e:
logger.error(f"Failed to initialize LLM: {e}")
raise e
def evaluate_longbench_v2(llm: LLM, tokenizer: AutoTokenizer,
args: argparse.Namespace) -> Tuple[List[Dict], float]:
"""Evaluate on LongBench v2 dataset."""
# Load LongBench v2 configuration
config = load_longbench_v2_config(args.longbench_path)
# Determine max_len for the model if not explicitly set via arguments
model_name = os.path.basename(args.model_path)
if model_name in config[
'model2maxlen']: # Use default from config if available
max_len = config['model2maxlen'][model_name]
logger.info(f"Using model-specific max_len: {max_len} for {model_name}")
else:
max_len = args.max_len
logger.info(f"Using max_len: {max_len}")
# Update args with the determined max_len
args.max_len = max_len
# Load dataset
logger.info(f"Loading LongBench v2 dataset...")
dataset = load_dataset('THUDM/LongBench-v2',
split='train',
trust_remote_code=True)
data = [item for item in dataset]
# Apply filters
filtered_data = data
if args.difficulty:
filtered_data = [
item for item in filtered_data
if item['difficulty'] == args.difficulty
]
logger.info(
f"Filtered by difficulty '{args.difficulty}': {len(filtered_data)} samples"
)
if args.length:
filtered_data = [
item for item in filtered_data if item['length'] == args.length
]
logger.info(
f"Filtered by length '{args.length}': {len(filtered_data)} samples")
if args.domain:
filtered_data = [
item for item in filtered_data if item['domain'] == args.domain
]
logger.info(
f"Filtered by domain '{args.domain}': {len(filtered_data)} samples")
# Apply start_idx and num_samples
if args.num_samples:
end_idx = min(args.start_idx + args.num_samples, len(filtered_data))
filtered_data = filtered_data[args.start_idx:end_idx]
else:
filtered_data = filtered_data[args.start_idx:]
logger.info(f"Final dataset size: {len(filtered_data)} samples")
# Determine chat template
chat_template = determine_chat_template(args.model_path, args.chat_template)
logger.info(f"Using chat template: {chat_template}")
logger.info(f"Prepare and truncate prompts...")
# Select appropriate template
if args.no_context:
template_key = '0shot_no_context'
elif args.rag > 0:
template_key = '0shot_rag'
elif args.cot:
template_key = '0shot_cot'
else:
template_key = '0shot'
template = config['templates'][template_key]
logger.info(f"Using template: {template_key}")
# Set up extra end token ids
extra_end_token_ids = []
if chat_template == "llama3":
eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
extra_end_token_ids.append(eot_id)
logger.info(f"Added llama3 end token: {eot_id}")
if chat_template == "qwen":
im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
extra_end_token_ids.append(im_end_id)
logger.info(f"Added qwen end token: {im_end_id}")
# Prepare prompts
prompts = []
for sample in filtered_data:
formatted_prompt = format_prompt(sample, template, args)
# Apply chat formatting if needed
if chat_template != 'none':
formatted_prompt = build_chat(tokenizer, formatted_prompt,
chat_template)
# Apply truncation if prompt is too long
formatted_prompt = truncate_prompt(formatted_prompt, tokenizer,
args.max_len)
prompts.append(formatted_prompt)
if len(prompts) == 0:
logger.warning(f"No prompts to evaluate")
return [], 0.0
# Run inference
logger.info(f"Starting inference for {len(prompts)} samples...")
start_time = time.time()
# Set sampling parameters
max_new_tokens = 1024 if args.cot else 256
sampling_params = SamplingParams(
max_tokens=max_new_tokens,
temperature=0.1,
top_p=0.95,
stop_token_ids=extra_end_token_ids if extra_end_token_ids else None,
)
outputs = llm.generate(prompts, sampling_params)
inference_time = time.time() - start_time
logger.info(
f"Inference completed in {inference_time:.2f} seconds, average time per sample: {inference_time/len(prompts):.3f} seconds"
)
# Process outputs
results = []
for i, (sample, output) in enumerate(zip(filtered_data, outputs)):
prediction = output.outputs[0].text.strip()
processed_prediction = post_process(prediction, chat_template)
# Handle CoT mode
if args.cot:
# For CoT, we need to do a second inference to extract the final answer
cot_response = processed_prediction
# Format the CoT answer extraction prompt
cot_ans_template = config['templates']['0shot_cot_ans']
cot_ans_prompt = format_prompt(sample, cot_ans_template, args)
cot_ans_prompt = cot_ans_prompt.replace('$COT$', cot_response)
if chat_template != 'none':
cot_ans_prompt = build_chat(tokenizer, cot_ans_prompt,
chat_template)
# Apply truncation to CoT answer extraction prompt
cot_ans_prompt = truncate_prompt(cot_ans_prompt, tokenizer,
args.max_len)
# Generate final answer
ans_sampling_params = SamplingParams(
max_tokens=128,
temperature=0.1,
top_p=0.95,
stop_token_ids=extra_end_token_ids
if extra_end_token_ids else None,
)
ans_output = llm.generate([cot_ans_prompt], ans_sampling_params)[0]
final_prediction = post_process(ans_output.outputs[0].text.strip(),
chat_template)
extracted_answer = extract_answer(final_prediction)
else:
extracted_answer = extract_answer(processed_prediction)
# Calculate accuracy
is_correct = extracted_answer == sample[
'answer'] if extracted_answer else False
result = {
'_id': sample['_id'],
'domain': sample['domain'],
'sub_domain': sample['sub_domain'],
'difficulty': sample['difficulty'],
'length': sample['length'],
'question': sample['question'],
'choice_A': sample['choice_A'],
'choice_B': sample['choice_B'],
'choice_C': sample['choice_C'],
'choice_D': sample['choice_D'],
'answer': sample['answer'],
'prediction': processed_prediction,
'extracted_answer': extracted_answer,
'is_correct': is_correct,
'context_length': len(sample['context']),
'prompt_length': len(output.prompt_token_ids),
'output_length': len(output.outputs[0].token_ids),
}
if args.cot:
result['cot_response'] = cot_response
result['final_prediction'] = final_prediction
results.append(result)
return results, inference_time
def calculate_metrics(results: List[Dict]) -> Dict[str, Any]:
"""Calculate evaluation metrics for LongBench v2."""
if not results:
return {}
total_samples = len(results)
correct_samples = sum(1 for r in results if r['is_correct'])
overall_accuracy = correct_samples / total_samples
metrics = {
'overall_accuracy': round(overall_accuracy * 100, 2),
'total_samples': total_samples,
'correct_samples': correct_samples
}
# Calculate metrics by difficulty
difficulties = ['easy', 'hard']
for difficulty in difficulties:
diff_results = [r for r in results if r['difficulty'] == difficulty]
if diff_results:
diff_correct = sum(1 for r in diff_results if r['is_correct'])
metrics[f'{difficulty}_accuracy'] = round(
(diff_correct / len(diff_results)) * 100, 2)
# Calculate metrics by length
lengths = ['short', 'medium', 'long']
for length in lengths:
len_results = [r for r in results if r['length'] == length]
if len_results:
len_correct = sum(1 for r in len_results if r['is_correct'])
metrics[f'{length}_accuracy'] = round(
(len_correct / len(len_results)) * 100, 2)
# Calculate metrics by domain
domains = list(set(r['domain'] for r in results))
for domain in domains:
domain_results = [r for r in results if r['domain'] == domain]
if domain_results:
domain_correct = sum(1 for r in domain_results if r['is_correct'])
metrics[f'{domain}_accuracy'] = round(
(domain_correct / len(domain_results)) * 100, 2)
return metrics
def save_results(results: List[Dict], args: argparse.Namespace,
inference_time: float, output_dir: str):
"""Save evaluation results in format compatible with LongBench v2."""
os.makedirs(output_dir, exist_ok=True)
# Calculate metrics
metrics = calculate_metrics(results)
logger.info(f"Evaluation metrics: {metrics}")
# Save detailed results
results_file = os.path.join(output_dir, "longbench_v2_results.jsonl")
with open(results_file, 'w', encoding='utf-8') as f:
for result in results:
json.dump(result, f, ensure_ascii=False)
f.write('\n')
# Save prediction results in LongBench v2 format
pred_file = os.path.join(output_dir, "predictions.jsonl")
with open(pred_file, 'w', encoding='utf-8') as f:
for result in results:
pred_data = {
"_id": result['_id'],
"prediction": result['extracted_answer'],
"response": result['prediction'],
"judge": result['is_correct']
}
if args.cot:
pred_data['cot_response'] = result.get('cot_response', '')
pred_data['final_prediction'] = result.get(
'final_prediction', '')
json.dump(pred_data, f, ensure_ascii=False)
f.write('\n')
# Create summary
summary = {
'experiment_config': {
'model_path': args.model_path,
'attention_backend': args.attention_backend,
'rocket_sparse': args.rocket_sparse,
'token_budget': args.token_budget,
'cot': args.cot,
'no_context': args.no_context,
'rag': args.rag,
'difficulty_filter': args.difficulty,
'length_filter': args.length,
'domain_filter': args.domain,
'max_seq_len': args.max_seq_len,
'max_new_tokens': args.max_new_tokens
},
'evaluation_results': metrics,
'timing': {
'total_inference_time': inference_time,
'avg_inference_time':
inference_time / len(results) if results else 0,
'evaluation_timestamp': datetime.now().isoformat()
}
}
# Save summary
summary_file = os.path.join(output_dir, "summary.json")
with open(summary_file, 'w', encoding='utf-8') as f:
json.dump(summary, f, indent=2, ensure_ascii=False)
logger.info(f"Results saved to {output_dir}")
return metrics
def main():
"""Main evaluation function."""
args = parse_arguments()
logger.set_level(args.log_level)
# Setup experiment name
if not args.exp_name:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_name = os.path.basename(args.model_path).replace('/', '_')
args.exp_name = f"longbench_v2_{model_name}_{timestamp}"
output_dir = os.path.join(args.output_dir, args.exp_name)
logger.info(
"=========== LongBench v2 Evaluation with TensorRT-LLM ===========")
os.makedirs(output_dir, exist_ok=True)
# Save configuration
config_file = os.path.join(output_dir, "config.json")
with open(config_file, 'w') as f:
json.dump(vars(args), f, indent=2)
logger.info(f"Configuration saved to {config_file}")
# Initialize LLM and tokenizer
llm, tokenizer = initialize_llm(args)
# Run evaluation
logger.info(f"Starting LongBench v2 evaluation...")
results, inference_time = evaluate_longbench_v2(llm, tokenizer, args)
# Save results and get metrics
metrics = save_results(results, args, inference_time, output_dir)
logger.info(f"{'-'*80}")
logger.info(f"Evaluation completed successfully!")
logger.info(f"Total time: {inference_time:.2f} seconds")
logger.info(f"Total samples: {len(results)}")
if metrics:
logger.info(
f"Overall accuracy: {metrics.get('overall_accuracy', 'N/A')}%")
if 'easy_accuracy' in metrics:
logger.info(
f"Easy difficulty accuracy: {metrics['easy_accuracy']}% ({metrics.get('easy_samples', 0)} samples)"
)
if 'hard_accuracy' in metrics:
logger.info(
f"Hard difficulty accuracy: {metrics['hard_accuracy']}% ({metrics.get('hard_samples', 0)} samples)"
)
for length in ['short', 'medium', 'long']:
if f'{length}_accuracy' in metrics:
logger.info(
f"{length.capitalize()} length accuracy: {metrics[f'{length}_accuracy']}% ({metrics.get(f'{length}_samples', 0)} samples)"
)
if __name__ == '__main__':
main()