-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathrun-openai.py
More file actions
133 lines (124 loc) · 4.46 KB
/
run-openai.py
File metadata and controls
133 lines (124 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import os
import typing
from functools import partial
import openai
import pfgen
def callback(
    tasks: list[dict[str, str]],
    params: dict[str, typing.Any],
    extra_eos_tokens: list[str] | None,
    add_no_think: bool,
) -> typing.Iterator[str | None]:
    """Generate one answer per task through an OpenAI-compatible API.

    The client is configured from the ``OPENAI_BASE_URL`` and
    ``OPENAI_API_KEY`` environment variables.

    Args:
        tasks: Task dictionaries. In ``"chat"`` mode each task must provide
            ``"system_prompt"`` and ``"user_prompt"``; in ``"qa"`` and
            ``"completion"`` modes each task must provide ``"prompt"``.
        params: Generation parameters. The keys ``"mode"``, ``"model"``,
            ``"temperature"``, ``"top_p"`` and ``"max_tokens"`` are required;
            ``"stop"`` is optional. This mapping is never modified.
        extra_eos_tokens: Additional stop strings merged with
            ``params["stop"]``, or ``None`` to add nothing.
        add_no_think: When true, ``"/no_think\\n"`` is prepended to the system
            prompt (chat mode) or to the prompt (qa / completion modes).

    Yields:
        The stripped generated text for each task, in task order, or ``None``
        when the request fails with ``openai.OpenAIError`` or the response
        carries no text.

    Raises:
        ValueError: If ``params["mode"]`` is not ``"chat"``, ``"qa"`` or
            ``"completion"``.
        KeyError: If a required key is missing from ``params`` or a task.
    """
    mode = params["mode"]

    client = openai.OpenAI(
        base_url=os.getenv("OPENAI_BASE_URL"),
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    # Build the stop list once, on a private copy: extending the list stored in
    # ``params`` would leak the extra tokens back to the caller and grow that
    # list again on every task.
    stop = list(params.get("stop") or [])
    if extra_eos_tokens is not None:
        stop.extend(extra_eos_tokens)
    # De-duplicate while keeping first-seen order so requests are reproducible.
    stop = list(dict.fromkeys(stop))

    # Arguments shared by every request, whatever the endpoint.
    sampling: dict[str, typing.Any] = {
        "model": params["model"],
        "max_tokens": params["max_tokens"],
        "temperature": params["temperature"],
        "top_p": params["top_p"],
        "stop": stop,
    }
    prefix = "/no_think\n" if add_no_think else ""

    for task in tasks:
        request: dict[str, typing.Any]
        if mode == "chat":
            request = {
                "messages": [
                    {"role": "system", "content": prefix + task["system_prompt"]},
                    {"role": "user", "content": task["user_prompt"]},
                ]
            }
        elif mode == "qa":
            request = {"messages": [{"role": "user", "content": prefix + task["prompt"]}]}
        elif mode == "completion":
            request = {"prompt": prefix + task["prompt"]}
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        try:
            if mode == "completion":
                results = client.completions.create(stream=False, **sampling, **request)
                text = results.choices[0].text if results.choices else None
            else:
                results = client.chat.completions.create(**sampling, **request)
                text = results.choices[0].message.content if results.choices else None
        except openai.OpenAIError as e:
            print(f"API Error: {e}")
            yield None
            continue

        if text is None:
            # A response may come back without choices or with ``content`` set
            # to None; report the task as failed instead of crashing the
            # generator with an AttributeError/IndexError.
            print("API Error: response contained no text")
            yield None
            continue

        if mode != "completion":
            # Chat-style answers may echo a leading "A:" marker; drop it.
            text = text.removeprefix("A:")
        yield text.strip()
if __name__ == "__main__":
    # Command-line entry point: parse the options, bind the CLI-only settings
    # onto ``callback`` and hand everything to ``pfgen.run_tasks``.

    # Option table, listed in the order the flags are registered (and hence
    # shown by --help).
    option_table = [
        (
            "--mode",
            dict(
                type=str,
                default="qa",
                choices=["chat", "qa", "completion"],
                help="Which chat template to use.",
            ),
        ),
        (
            "--model",
            dict(type=str, default="openai/gpt-4o", help="OpenAI model name."),
        ),
        (
            "--temperature",
            dict(type=float, default=0.7, help="Temperature for sampling."),
        ),
        (
            "--num-trials",
            dict(type=int, default=10, help="Number of trials to run."),
        ),
        (
            "--top-p",
            dict(type=float, default=0.98, help="Top-p for sampling."),
        ),
        (
            "--max-tokens",
            dict(type=int, help="Maximum tokens to generate (overrides default)."),
        ),
        (
            "--extra-eos-tokens",
            dict(type=str, nargs="+", help="Extra EOS strings"),
        ),
        (
            "--disable-thinking",
            dict(
                action="store_true",
                help="Disable reasoning when generation by Qwen3 models",
            ),
        ),
        (
            "--num-retries",
            dict(type=int, default=10, help="Number of retries."),
        ),
        (
            "--ignore-failure",
            dict(
                action="store_true",
                default=False,
                help="Do not throw an exception if answer generation fails.",
            ),
        ),
    ]

    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    # Pre-bind the two settings that only exist on the command line.
    task_callback = partial(
        callback,
        extra_eos_tokens=options.extra_eos_tokens,
        add_no_think=options.disable_thinking,
    )

    run_settings = dict(
        engine="openai-api",
        model=options.model,
        temperature=options.temperature,
        top_p=options.top_p,
        num_trials=options.num_trials,
        enable_thinking=not options.disable_thinking,
        num_retries=options.num_retries,
        ignore_failure=options.ignore_failure,
    )
    # max_tokens is forwarded only when given explicitly on the command line.
    if options.max_tokens is not None:
        run_settings["max_tokens"] = options.max_tokens

    pfgen.run_tasks(options.mode, task_callback, **run_settings)