From 33f7fddc933d348a1fb670e65419d350e2d0a4c3 Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 28 Oct 2025 16:37:16 +0800
Subject: [PATCH 1/3] fix openai client logprobs

---
 tests/common/vllm_test.py                   | 26 +++++++++++++++++++
 trinity/common/models/model.py              |  5 ++--
 .../common/workflows/agentscope_workflow.py |  5 ++++
 3 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 6705efee7a..67434cd7f0 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -327,7 +327,20 @@ async def test_api(self):
         )
         exps = self.model_wrapper.extract_experience_from_history()
         self.assertEqual(len(exps), 4)
+        for exp in exps:
+            self.assertTrue(len(exp.tokens) > 0)
+            self.assertTrue(len(exp.logprobs) > 0)
+            self.assertTrue(exp.prompt_length + len(exp.logprobs) == len(exp.tokens))
         self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
+        response = openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].tokens) > 0)
+        self.assertTrue(len(exps[0].logprobs) > 0)
+        self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
         response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(
             model=model_id, messages=messages, n=2
         )
@@ -400,7 +413,20 @@ async def test_api_async(self):
         )
         exps = self.model_wrapper.extract_experience_from_history()
         self.assertEqual(len(exps), 4)
+        for exp in exps:
+            self.assertTrue(len(exp.tokens) > 0)
+            self.assertTrue(len(exp.logprobs) > 0)
+            self.assertTrue(exp.prompt_length + len(exp.logprobs) == len(exp.tokens))
         self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
+        response = await openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].tokens) > 0)
+        self.assertTrue(len(exps[0].logprobs) > 0)
+        self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
         response = (
             await self.model_wrapper_no_history.get_openai_async_client().chat.completions.create(
                 model=model_id, messages=messages, n=2
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index b003065dd4..2562284703 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -2,6 +2,7 @@
 """Base Model Class"""
 import asyncio
 import socket
+from functools import partial
 from abc import ABC, abstractmethod
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
@@ -276,7 +277,7 @@ def get_openai_client(self) -> openai.OpenAI:
         )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = self.openai_client.chat.completions.create
+            ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)
 
             def record_chat_completions(*args, **kwargs):
                 response = ori_create(*args, **kwargs)
@@ -306,7 +307,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
         )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = self.openai_async_client.chat.completions.create
+            ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)
 
             async def record_chat_completions(*args, **kwargs):
                 response = await ori_create(*args, **kwargs)
diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py
index 359afa5b65..c49cbaab49 100644
--- a/trinity/common/workflows/agentscope_workflow.py
+++ b/trinity/common/workflows/agentscope_workflow.py
@@ -45,6 +45,11 @@ def __init__(
 
         self.chat_model: TrinityChatModel = TrinityChatModel(
             model.get_openai_async_client(),
+            generate_kwargs={
+                "temperature": self.task.rollout_args.temperature,
+                "max_tokens": self.task.rollout_args.max_tokens or 4096,
+                "top_logprobs": self.task.rollout_args.logprobs,
+            }
         )
 
     def construct_experiences(

From deac9700d6422643c1652dac38e650937da00e79 Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 28 Oct 2025 16:40:53 +0800
Subject: [PATCH 2/3] fix pre-commit

---
 trinity/common/models/model.py                  | 2 +-
 trinity/common/workflows/agentscope_workflow.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 2562284703..2a709f2fa8 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -2,8 +2,8 @@
 """Base Model Class"""
 import asyncio
 import socket
-from functools import partial
 from abc import ABC, abstractmethod
+from functools import partial
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import httpx
diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py
index c49cbaab49..77a5d5ed7f 100644
--- a/trinity/common/workflows/agentscope_workflow.py
+++ b/trinity/common/workflows/agentscope_workflow.py
@@ -49,7 +49,7 @@ def __init__(
                 "temperature": self.task.rollout_args.temperature,
                 "max_tokens": self.task.rollout_args.max_tokens or 4096,
                 "top_logprobs": self.task.rollout_args.logprobs,
-            }
+            },
         )
 
     def construct_experiences(

From b9fc13bae306a67fb4eed2a2686348273d5f9691 Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 28 Oct 2025 16:49:51 +0800
Subject: [PATCH 3/3] add more tests

---
 tests/common/vllm_test.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 67434cd7f0..709abba170 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -341,6 +341,14 @@ async def test_api(self):
         self.assertTrue(len(exps[0].tokens) > 0)
         self.assertTrue(len(exps[0].logprobs) > 0)
         self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
+        response = openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            logprobs=False,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].logprobs) == 0)
         response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(
             model=model_id, messages=messages, n=2
         )
@@ -427,6 +435,14 @@ async def test_api_async(self):
         self.assertTrue(len(exps[0].tokens) > 0)
         self.assertTrue(len(exps[0].logprobs) > 0)
         self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
+        response = await openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            logprobs=False,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].logprobs) == 0)
         response = (
             await self.model_wrapper_no_history.get_openai_async_client().chat.completions.create(
                 model=model_id, messages=messages, n=2