diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 6705efee7a..709abba170 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -327,7 +327,28 @@ async def test_api(self):
         )
         exps = self.model_wrapper.extract_experience_from_history()
         self.assertEqual(len(exps), 4)
+        for exp in exps:
+            self.assertTrue(len(exp.tokens) > 0)
+            self.assertTrue(len(exp.logprobs) > 0)
+            self.assertTrue(exp.prompt_length + len(exp.logprobs) == len(exp.tokens))
         self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
+        response = openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].tokens) > 0)
+        self.assertTrue(len(exps[0].logprobs) > 0)
+        self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
+        response = openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            logprobs=False,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].logprobs) == 0)
         response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(
             model=model_id, messages=messages, n=2
         )
@@ -400,7 +421,28 @@ async def test_api_async(self):
         )
         exps = self.model_wrapper.extract_experience_from_history()
         self.assertEqual(len(exps), 4)
+        for exp in exps:
+            self.assertTrue(len(exp.tokens) > 0)
+            self.assertTrue(len(exp.logprobs) > 0)
+            self.assertTrue(exp.prompt_length + len(exp.logprobs) == len(exp.tokens))
         self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
+        response = await openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].tokens) > 0)
+        self.assertTrue(len(exps[0].logprobs) > 0)
+        self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
+        response = await openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            logprobs=False,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].logprobs) == 0)
         response = (
             await self.model_wrapper_no_history.get_openai_async_client().chat.completions.create(
                 model=model_id, messages=messages, n=2
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index b003065dd4..2a709f2fa8 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -3,6 +3,7 @@
 import asyncio
 import socket
 from abc import ABC, abstractmethod
+from functools import partial
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import httpx
@@ -276,7 +277,7 @@ def get_openai_client(self) -> openai.OpenAI:
             )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = self.openai_client.chat.completions.create
+            ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)
 
             def record_chat_completions(*args, **kwargs):
                 response = ori_create(*args, **kwargs)
@@ -306,7 +307,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
             )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = self.openai_async_client.chat.completions.create
+            ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)
 
             async def record_chat_completions(*args, **kwargs):
                 response = await ori_create(*args, **kwargs)
diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py
index 359afa5b65..77a5d5ed7f 100644
--- a/trinity/common/workflows/agentscope_workflow.py
+++ b/trinity/common/workflows/agentscope_workflow.py
@@ -45,6 +45,11 @@ def __init__(
 
         self.chat_model: TrinityChatModel = TrinityChatModel(
             model.get_openai_async_client(),
+            generate_kwargs={
+                "temperature": self.task.rollout_args.temperature,
+                "max_tokens": self.task.rollout_args.max_tokens or 4096,
+                "top_logprobs": self.task.rollout_args.logprobs,
+            },
         )
 
     def construct_experiences(
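
Note on the `partial(..., logprobs=True)` change: `functools.partial` supplies
keyword arguments only as overridable defaults, so a caller who explicitly
passes `logprobs=False` still disables logprob recording. That is the behavior
the new `logprobs=False` test cases above assert (empty `exps[0].logprobs`),
while the default path records one logprob per generated token
(`prompt_length + len(logprobs) == len(tokens)`). A minimal standalone sketch
of this override semantics, using a hypothetical stand-in `create` function
rather than the real OpenAI client:

    from functools import partial

    def create(model, messages, logprobs=None, **kwargs):
        # Stand-in for openai_client.chat.completions.create: just echo
        # the effective logprobs flag it was called with.
        return {"logprobs": logprobs}

    # Same pattern as the patch: bake in logprobs=True as a default.
    ori_create = partial(create, logprobs=True)

    print(ori_create(model="m", messages=[]))                  # {'logprobs': True}
    print(ori_create(model="m", messages=[], logprobs=False))  # {'logprobs': False}

Because call-time keywords take precedence over a partial's stored keywords,
no special-casing is needed in record_chat_completions itself.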