evaluate_model_api.py — 51 lines (46 loc) · 2.52 KB
import argparse
from llmtf.model import ApiVLLMModel, ApiVLLMModelReasoning
from llmtf.evaluator import Evaluator
import os
import torch
if __name__ == '__main__':
    # Entry point: evaluate a model served behind a vLLM OpenAI-compatible API
    # on the llmtf benchmark datasets. Reasoning ("thinking") mode is on by
    # default and uses a dedicated model wrapper with its own token budget.
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_url')
    parser.add_argument('--model_name_or_path')
    parser.add_argument('--api_key')
    parser.add_argument('--output_dir')
    parser.add_argument('--disable_thinking', action='store_true')
    parser.add_argument('--dataset_names', nargs='+', default='all')
    parser.add_argument('--few_shot_count', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_sample_per_dataset', type=int, default=10000000000000)
    parser.add_argument('--max_prompt_len', type=int, default=4000)
    parser.add_argument('--max_new_tokens_reasoning', type=int, default=3000)
    parser.add_argument('--force_recalc', action='store_true')
    parser.add_argument('--name_suffix', type=str, default=None)
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--repetition_penalty', type=float, default=1.0)
    parser.add_argument('--presence_penalty', type=float, default=0.0)
    parser.add_argument('--num_return_sequences', type=int, default=1)
    parser.add_argument('--end_thinking_token_id', type=int, default=None)
    args = parser.parse_args()

    # BUG FIX: os.environ values must be strings; the original unconditional
    # assignment raised TypeError when --api_key was omitted (default None).
    # Export the key only when the caller actually supplied one.
    if args.api_key is not None:
        os.environ['OPENAI_API_KEY'] = args.api_key

    evaluator = Evaluator()

    if args.disable_thinking:
        # Plain API model: no separate reasoning phase.
        model = ApiVLLMModel(api_base=args.base_url)
        model.from_pretrained(args.model_name_or_path)
    else:
        # Reasoning wrapper: gets an extra token budget for the thinking phase
        # and an optional token id that terminates it.
        model = ApiVLLMModelReasoning(api_base=args.base_url)
        model.from_pretrained(
            args.model_name_or_path,
            max_new_tokens_reasoning=args.max_new_tokens_reasoning,
            end_thinking_token_id=args.end_thinking_token_id
        )

    model.generation_config.temperature = args.temperature
    model.generation_config.repetition_penalty = args.repetition_penalty
    model.generation_config.presence_penalty = args.presence_penalty
    model.generation_config.num_return_sequences = args.num_return_sequences
    # Greedy decoding when temperature is exactly 0.0, sampling otherwise
    # (replaces the original set-True-then-maybe-reset-False two-step).
    model.generation_config.do_sample = args.temperature != 0.0

    evaluator.evaluate(
        model,
        args.output_dir,
        args.dataset_names,
        args.max_prompt_len,
        args.few_shot_count,
        batch_size=args.batch_size,
        max_sample_per_dataset=args.max_sample_per_dataset,
        force_recalc=args.force_recalc,
        name_suffix=args.name_suffix,
        enable_thinking=not args.disable_thinking
    )