Skip to content

Commit 42d10b4

Browse files
authored
Fix openai client logprobs calculation (#347)
1 parent 7964099 commit 42d10b4

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

tests/common/vllm_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,28 @@ async def test_api(self):
327327
)
328328
exps = self.model_wrapper.extract_experience_from_history()
329329
self.assertEqual(len(exps), 4)
330+
for exp in exps:
331+
self.assertTrue(len(exp.tokens) > 0)
332+
self.assertTrue(len(exp.logprobs) > 0)
333+
self.assertTrue(exp.prompt_length + len(exp.logprobs) == len(exp.tokens))
330334
self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
335+
response = openai_client.chat.completions.create(
336+
model=model_id,
337+
messages=messages,
338+
)
339+
exps = self.model_wrapper.extract_experience_from_history()
340+
self.assertEqual(len(exps), 1)
341+
self.assertTrue(len(exps[0].tokens) > 0)
342+
self.assertTrue(len(exps[0].logprobs) > 0)
343+
self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
344+
response = openai_client.chat.completions.create(
345+
model=model_id,
346+
messages=messages,
347+
logprobs=False,
348+
)
349+
exps = self.model_wrapper.extract_experience_from_history()
350+
self.assertEqual(len(exps), 1)
351+
self.assertTrue(len(exps[0].logprobs) == 0)
331352
response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(
332353
model=model_id, messages=messages, n=2
333354
)
@@ -400,7 +421,28 @@ async def test_api_async(self):
400421
)
401422
exps = self.model_wrapper.extract_experience_from_history()
402423
self.assertEqual(len(exps), 4)
424+
for exp in exps:
425+
self.assertTrue(len(exp.tokens) > 0)
426+
self.assertTrue(len(exp.logprobs) > 0)
427+
self.assertTrue(exp.prompt_length + len(exp.logprobs) == len(exp.tokens))
403428
self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
429+
response = await openai_client.chat.completions.create(
430+
model=model_id,
431+
messages=messages,
432+
)
433+
exps = self.model_wrapper.extract_experience_from_history()
434+
self.assertEqual(len(exps), 1)
435+
self.assertTrue(len(exps[0].tokens) > 0)
436+
self.assertTrue(len(exps[0].logprobs) > 0)
437+
self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
438+
response = await openai_client.chat.completions.create(
439+
model=model_id,
440+
messages=messages,
441+
logprobs=False,
442+
)
443+
exps = self.model_wrapper.extract_experience_from_history()
444+
self.assertEqual(len(exps), 1)
445+
self.assertTrue(len(exps[0].logprobs) == 0)
404446
response = (
405447
await self.model_wrapper_no_history.get_openai_async_client().chat.completions.create(
406448
model=model_id, messages=messages, n=2

trinity/common/models/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import socket
55
from abc import ABC, abstractmethod
6+
from functools import partial
67
from typing import Any, List, Optional, Sequence, Tuple, Union
78

89
import httpx
@@ -276,7 +277,7 @@ def get_openai_client(self) -> openai.OpenAI:
276277
)
277278
if self.enable_history:
278279
# add a decorator to the openai client to record history
279-
ori_create = self.openai_client.chat.completions.create
280+
ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)
280281

281282
def record_chat_completions(*args, **kwargs):
282283
response = ori_create(*args, **kwargs)
@@ -306,7 +307,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
306307
)
307308
if self.enable_history:
308309
# add a decorator to the openai client to record history
309-
ori_create = self.openai_async_client.chat.completions.create
310+
ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)
310311

311312
async def record_chat_completions(*args, **kwargs):
312313
response = await ori_create(*args, **kwargs)

trinity/common/workflows/agentscope_workflow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def __init__(
4545

4646
self.chat_model: TrinityChatModel = TrinityChatModel(
4747
model.get_openai_async_client(),
48+
generate_kwargs={
49+
"temperature": self.task.rollout_args.temperature,
50+
"max_tokens": self.task.rollout_args.max_tokens or 4096,
51+
"top_logprobs": self.task.rollout_args.logprobs,
52+
},
4853
)
4954

5055
def construct_experiences(

0 commit comments

Comments (0)