55import os
66import random
77import shutil
8- import unittest
98from datetime import datetime
109
1110import httpx
12- import openai
1311import ray
1412
1513from tests .tools import (
2725from trinity .common .config import ExperienceBufferConfig , InferenceModelConfig
2826from trinity .common .constants import StorageType
2927from trinity .explorer .explorer import Explorer
28+ from trinity .explorer .proxy .client import TrinityClient
3029from trinity .manager .state_manager import StateManager
3130
3231
@@ -158,8 +157,9 @@ def run_serve(config):
158157 run_stage (config )
159158
160159
161- def run_agent (base_url , model_path : str ):
162- client = openai .Client (base_url = base_url , api_key = "testkey" )
160+ def run_agent (proxy_url , model_path : str ):
161+ proxy_client = TrinityClient (proxy_url = proxy_url )
162+ openai_client = proxy_client .get_openai_client ()
163163 contents = [
164164 "Hello, how are you?" ,
165165 "What is the capital of China?" ,
@@ -172,10 +172,11 @@ def run_agent(base_url, model_path: str):
172172 "What is the best way to learn programming?" ,
173173 "Describe the process of photosynthesis." ,
174174 ]
175- response = client .chat .completions .create (
175+ response = openai_client .chat .completions .create (
176176 model = model_path ,
177177 messages = [{"role" : "user" , "content" : random .choice (contents )}],
178178 )
179+ proxy_client .feedback (reward = 2.0 , msg_ids = [response .id ])
179180 return response .choices [0 ].message .content
180181
181182
@@ -191,7 +192,7 @@ def setUp(self):
191192 self .config .explorer .rollout_model .engine_num = 4
192193 self .config .explorer .rollout_model .enable_openai_api = True
193194 self .config .checkpoint_root_dir = get_checkpoint_path ()
194- self .config .explorer .api_port = 8010
195+ self .config .explorer .proxy_port = 8010
195196 self .config .explorer .service_status_check_interval = 30
196197 self .config .buffer .trainer_input .experience_buffer = ExperienceBufferConfig (
197198 name = "experience_buffer" ,
@@ -201,7 +202,6 @@ def setUp(self):
201202 if multiprocessing .get_start_method (allow_none = True ) != "spawn" :
202203 multiprocessing .set_start_method ("spawn" , force = True )
203204
204- @unittest .skip ("Require improvement for agent mode" )
205205 async def test_serve (self ): # noqa: C901
206206 serve_process = multiprocessing .Process (target = run_serve , args = (self .config ,))
207207 serve_process .start ()
@@ -238,7 +238,7 @@ async def test_serve(self): # noqa: C901
238238 apps = []
239239 for i in range (task_num ):
240240 app_process = multiprocessing .Process (
241- target = run_agent , args = (server_url + "/v1" , self .config .model .model_path )
241+ target = run_agent , args = (server_url , self .config .model .model_path )
242242 )
243243 apps .append (app_process )
244244 app_process .start ()
@@ -248,22 +248,20 @@ async def test_serve(self): # noqa: C901
248248 self .assertFalse (app .is_alive ())
249249
250250 finish_step = None
251-
251+ proxy_client = TrinityClient ( proxy_url = server_url )
252252 for i in range (20 ):
253- async with httpx .AsyncClient () as client :
254- response = await client .get (f"{ server_url } /metrics" )
255- self .assertEqual (response .status_code , 200 )
256- metrics = response .json ()
257- metrics_keys = list (metrics .keys ())
258- self .assertIn ("explore_step_num" , metrics_keys )
259- self .assertIn ("rollout/total_experience_count" , metrics_keys )
260- self .assertIn ("rollout/model_0/total_request_count" , metrics_keys )
261- self .assertIn ("rollout/model_3/model_version" , metrics_keys )
262- if not finish_step and metrics ["rollout/total_experience_count" ] == task_num :
263- finish_step = metrics ["explore_step_num" ]
264- if finish_step and metrics ["explore_step_num" ] >= finish_step + 1 :
265- # wait for one more step to ensure all data are written to buffer
266- break
253+ metrics = await proxy_client .get_metrics_async ()
254+ metrics_keys = list (metrics .keys ())
255+ self .assertIn ("explore_step_num" , metrics_keys )
256+ self .assertIn ("rollout/total_experience_count" , metrics_keys )
257+ self .assertIn ("rollout/model_0/total_request_count" , metrics_keys )
258+ self .assertIn ("rollout/model_3/model_version" , metrics_keys )
259+ if not finish_step and metrics ["rollout/total_experience_count" ] == task_num :
260+ finish_step = metrics ["explore_step_num" ]
261+ await proxy_client .commit_async ()
262+ if finish_step and metrics ["explore_step_num" ] >= finish_step + 1 :
263+ # wait for one more step to ensure all data are written to buffer
264+ break
267265 await asyncio .sleep (3 )
268266
269267 serve_process .terminate ()
@@ -277,6 +275,9 @@ async def test_serve(self): # noqa: C901
277275 exps = await buffer_reader .read_async (batch_size = 10 )
278276 for exp in exps :
279277 self .assertTrue (len (exp .tokens ) > 0 )
278+ self .assertTrue (len (exp .logprobs ) > 0 )
279+ self .assertTrue (exp .prompt_length > 0 )
280+ self .assertTrue (exp .reward == 2.0 )
280281 self .assertEqual (len (exps ), task_num )
281282
282283 def tearDown (self ):
0 commit comments