42 changes: 42 additions & 0 deletions tests/common/vllm_test.py
@@ -327,7 +327,28 @@ async def test_api(self):
         )
         exps = self.model_wrapper.extract_experience_from_history()
         self.assertEqual(len(exps), 4)
+        for exp in exps:
+            self.assertTrue(len(exp.tokens) > 0)
+            self.assertTrue(len(exp.logprobs) > 0)
+            self.assertTrue(exp.prompt_length + len(exp.logprobs) == len(exp.tokens))
         self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
+        response = openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].tokens) > 0)
+        self.assertTrue(len(exps[0].logprobs) > 0)
+        self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
+        response = openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            logprobs=False,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].logprobs) == 0)
         response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(
             model=model_id, messages=messages, n=2
         )
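
The new assertions pin down an invariant of each recorded experience: tokens spans the full sequence (prompt plus completion), while the equality check implies logprobs has one entry per completion token only, so prompt_length + len(logprobs) == len(tokens). A minimal standalone sketch of that relationship, with hypothetical values chosen purely for illustration:

# Hypothetical token IDs and logprobs illustrating the asserted invariant.
prompt_tokens = [101, 2023, 2003]            # 3 prompt tokens
completion_tokens = [1037, 3231, 102]        # 3 sampled completion tokens
tokens = prompt_tokens + completion_tokens   # full sequence, length 6
logprobs = [-0.1, -0.8, -0.3]                # one logprob per completion token
prompt_length = len(prompt_tokens)
assert prompt_length + len(logprobs) == len(tokens)
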
@@ -400,7 +421,28 @@ async def test_api_async(self):
         )
         exps = self.model_wrapper.extract_experience_from_history()
         self.assertEqual(len(exps), 4)
+        for exp in exps:
+            self.assertTrue(len(exp.tokens) > 0)
+            self.assertTrue(len(exp.logprobs) > 0)
+            self.assertTrue(exp.prompt_length + len(exp.logprobs) == len(exp.tokens))
         self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
+        response = await openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].tokens) > 0)
+        self.assertTrue(len(exps[0].logprobs) > 0)
+        self.assertTrue(exps[0].prompt_length + len(exps[0].logprobs) == len(exps[0].tokens))
+        response = await openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            logprobs=False,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertTrue(len(exps[0].logprobs) == 0)
         response = (
             await self.model_wrapper_no_history.get_openai_async_client().chat.completions.create(
                 model=model_id, messages=messages, n=2
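
The logprobs=False case in both tests works because of standard functools.partial semantics: a keyword supplied at call time overrides one bound into the partial, which is what lets a caller opt out of the logprobs=True default introduced in trinity/common/models/model.py below. A standalone sketch of that behavior:

from functools import partial

def create(**kwargs):
    return kwargs

wrapped = partial(create, logprobs=True)
print(wrapped())                # {'logprobs': True}: default bound by partial
print(wrapped(logprobs=False))  # {'logprobs': False}: call-time keyword wins
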
5 changes: 3 additions & 2 deletions trinity/common/models/model.py
@@ -3,6 +3,7 @@
 import asyncio
 import socket
 from abc import ABC, abstractmethod
+from functools import partial
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import httpx
@@ -276,7 +277,7 @@ def get_openai_client(self) -> openai.OpenAI:
         )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = self.openai_client.chat.completions.create
+            ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)
 
             def record_chat_completions(*args, **kwargs):
                 response = ori_create(*args, **kwargs)
@@ -306,7 +307,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
         )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = self.openai_async_client.chat.completions.create
+            ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)
 
             async def record_chat_completions(*args, **kwargs):
                 response = await ori_create(*args, **kwargs)
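
Both hunks are cut off before showing where record_chat_completions is installed on the client. For orientation only, a self-contained sketch of the shape such a wrapper typically takes; install_recording_wrapper, the history list, and the patch-back assignment are assumptions for illustration, not code from this diff:

from functools import partial

def install_recording_wrapper(openai_client, history):
    # Hypothetical helper mirroring the truncated wrapper above.
    ori_create = partial(openai_client.chat.completions.create, logprobs=True)

    def record_chat_completions(*args, **kwargs):
        response = ori_create(*args, **kwargs)
        history.append(response)  # assumed: stash raw responses for later extraction
        return response

    # assumed: patch the client so every call goes through the recorder
    openai_client.chat.completions.create = record_chat_completions
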
5 changes: 5 additions & 0 deletions trinity/common/workflows/agentscope_workflow.py
@@ -45,6 +45,11 @@ def __init__(

         self.chat_model: TrinityChatModel = TrinityChatModel(
             model.get_openai_async_client(),
+            generate_kwargs={
+                "temperature": self.task.rollout_args.temperature,
+                "max_tokens": self.task.rollout_args.max_tokens or 4096,
+                "top_logprobs": self.task.rollout_args.logprobs,
+            },
         )
 
     def construct_experiences(
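
One subtlety in the new max_tokens line: the or fallback replaces any falsy value, so both None and 0 resolve to 4096. A standalone sketch of that behavior:

def resolve_max_tokens(max_tokens):
    # Mirrors the "max_tokens or 4096" fallback in the diff above.
    return max_tokens or 4096

print(resolve_max_tokens(None))  # 4096
print(resolve_max_tokens(0))     # 4096: 0 is falsy, so it is replaced too
print(resolve_max_tokens(512))   # 512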