Commit d0a300a

evals: log reasoning and extend max_tokens for chat completions (#62)
1 parent 754a56b commit d0a300a

File tree: 4 files changed, +52 -7 lines

gpt_oss/evals/__main__.py

Lines changed: 6 additions & 2 deletions

@@ -3,6 +3,7 @@
 from datetime import datetime

 from . import report
+from .basic_eval import BasicEval
 from .gpqa_eval import GPQAEval
 from .aime_eval import AIME25Eval
 from .healthbench_eval import HealthBenchEval
@@ -81,6 +82,7 @@ def main():
         reasoning_effort=reasoning_effort,
         temperature=args.temperature,
         base_url=args.base_url,
+        max_tokens=131_072,
     )

     print(f"Running with args {args}")
@@ -98,9 +100,11 @@ def get_evals(eval_name, debug_mode):
     )
     # Set num_examples = None to reproduce full evals
     match eval_name:
+        case "basic":
+            return BasicEval()
         case "gpqa":
             return GPQAEval(
-                n_repeats=8,
+                n_repeats=1 if args.debug else 8,
                 num_examples=num_examples,
                 debug=debug_mode,
                 n_threads=args.n_threads or 1,
@@ -131,7 +135,7 @@ def get_evals(eval_name, debug_mode):
             )
         case "aime25":
             return AIME25Eval(
-                n_repeats=8,
+                n_repeats=1 if args.debug else 8,
                 num_examples=num_examples,
                 n_threads=args.n_threads or 1,
             )
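In summary: debug runs now use a single repeat for GPQA and AIME25 instead of eight (presumably to keep smoke tests fast), the new "basic" case wires in the one-example BasicEval added below, and the sampler constructed in main() now receives max_tokens=131_072.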

gpt_oss/evals/basic_eval.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+"""
+Basic eval
+"""
+from . import report
+
+from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
+
+class BasicEval(Eval):
+    def __init__(self,):
+        self.examples = [{
+            "question": "hi",
+            "answer": "hi, how can i help?",
+        }]
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            sampler_response = sampler([
+                sampler._pack_message(content=row["question"], role="user")
+            ])
+            response_text = sampler_response.response_text
+            extracted_answer = response_text
+            actual_queried_prompt_messages = sampler_response.actual_queried_message_list
+            score = 1.0 if len(extracted_answer) > 0 else 0.0
+            html = report.jinja_env.from_string(report.HTML_JINJA).render(
+                prompt_messages=actual_queried_prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=row["answer"],
+                extracted_answer=extracted_answer,
+            )
+            convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
+            )
+
+        results = report.map_with_progress(fn, self.examples, num_threads=1)
+        return report.aggregate_results(results)
+
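For orientation, here is a minimal sketch (not part of the commit) of how BasicEval could be exercised with a stub sampler. EchoSampler is hypothetical; it only duck-types the interface the eval relies on above (a _pack_message helper plus a call returning an object with response_text and actual_queried_message_list), and it assumes the aggregated result exposes a score attribute.

# Hypothetical smoke test for BasicEval; a sketch, not code from the repository.
from types import SimpleNamespace

from gpt_oss.evals.basic_eval import BasicEval


class EchoSampler:
    """Stub sampler that echoes the question back as the model response."""

    def _pack_message(self, role: str, content: str) -> dict:
        return {"role": role, "content": content}

    def __call__(self, message_list):
        # BasicEval scores 1.0 for any non-empty response text.
        return SimpleNamespace(
            response_text=message_list[-1]["content"],
            actual_queried_message_list=message_list,
        )


result = BasicEval()(EchoSampler())
print(result.score)  # expected: 1.0

In the real harness, the sampler is the chat-completions or Responses sampler constructed in __main__.py rather than a stub.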

gpt_oss/evals/chat_completions_sampler.py

Lines changed: 7 additions & 3 deletions

@@ -27,7 +27,6 @@ def __init__(
         reasoning_effort: str | None = None,
         base_url: str = "http://localhost:8000/v1",
     ):
-        self.api_key_name = "OPENAI_API_KEY"
         self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60)
         self.model = model
         self.system_message = system_message
@@ -63,8 +62,13 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
                     temperature=self.temperature,
                     max_tokens=self.max_tokens,
                 )
-                content = response.choices[0].message.content
-                if content is None:
+
+                choice = response.choices[0]
+                content = choice.message.content
+                if getattr(choice.message, "reasoning", None):
+                    message_list.append(self._pack_message("assistant", choice.message.reasoning))
+
+                if not content:
                     raise ValueError("OpenAI API returned empty response; retrying")
                 return SamplerResponse(
                     response_text=content,
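When a server returns a non-standard reasoning field on the chat-completions message, its text is now appended to message_list as an assistant message, presumably so the reasoning appears in the logged conversation; getattr keeps this safe for backends that never set the field. The empty-response check also tightens from "if content is None" to "if not content", so an empty string now triggers a retry as well.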

gpt_oss/evals/responses_sampler.py

Lines changed: 1 addition & 2 deletions

@@ -17,12 +17,11 @@ def __init__(
         model: str,
         developer_message: str | None = None,
         temperature: float = 1.0,
-        max_tokens: int = 1024,
+        max_tokens: int = 131_072,
         reasoning_model: bool = False,
         reasoning_effort: str | None = None,
         base_url: str = "http://localhost:8000/v1",
     ):
-        self.api_key_name = "OPENAI_API_KEY"
         self.client = OpenAI(base_url=base_url, timeout=24*60*60)
         self.model = model
         self.developer_message = developer_message
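The default max_tokens here rises from 1,024 to 131,072 (2**17, a 128K-token budget), matching the cap now passed explicitly to the chat-completions sampler in __main__.py. Dropping the unused api_key_name attribute from both samplers should be behavior-neutral, since the OpenAI client falls back to the OPENAI_API_KEY environment variable by default.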
