Skip to content

Commit 224f2f3

Browse files
committed
.
1 parent 457a62a commit 224f2f3

File tree

4 files changed

+92
-28
lines changed

4 files changed

+92
-28
lines changed

areal/experimental/openai/proxy/server.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,12 @@ def __init__(
318318
self.host_ip = gethostip()
319319
self._localhost = "0.0.0.0"
320320
self.server_config = uvicorn.Config(
321-
self.app, host=self._localhost, port=self.port, log_level=uvicorn_log_level
321+
self.app,
322+
host=self._localhost,
323+
port=self.port,
324+
log_level=uvicorn_log_level,
325+
timeout_keep_alive=300,
326+
workers=4,
322327
)
323328
self.server = uvicorn.Server(self.server_config)
324329
self.thread = threading.Thread(target=self.server.run, daemon=True)
@@ -378,15 +383,12 @@ async def fetch_next_session(self) -> str:
378383
except Empty:
379384
await asyncio.sleep(0.1)
380385

381-
async def wait_for_session(
382-
self, session_id: str, discount: float = 1.0, style: str = "individual"
383-
) -> SessionData:
386+
async def wait_for_session(self, session_id: str) -> SessionData:
384387
if session_id not in self.session_cache:
385388
raise KeyError(f"Session {session_id} not found")
386389
# Wait for session to be completed using event
387390
await self.session_cache[session_id].wait_for_finish()
388-
session = self.session_cache.pop(session_id)
389-
return session.export_interactions(discount=discount, style=style)
391+
return self.session_cache.pop(session_id)
390392

391393
def set_reward(self, session_id: str, completion_id: str, reward: float):
392394
"""Set reward for a specific completion/response by its ID."""

areal/experimental/openai/proxy/workflow.py

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import atexit
5+
import os
6+
import threading
37
import traceback
8+
from concurrent.futures import ProcessPoolExecutor
49
from typing import TYPE_CHECKING
510

611
from areal.api.workflow_api import AgentWorkflow, RolloutWorkflow
712
from areal.core import workflow_context
8-
from areal.utils import logging
13+
from areal.utils import logging, stats_tracker
14+
from areal.utils.perf_tracer import session_context, trace_session
915

1016
from .client_session import OpenAIProxyClientSession
1117
from .server import OpenAIProxyServer
@@ -16,6 +22,46 @@
1622
logger = logging.getLogger("OpenAIProxyWorkflow")
1723

1824

25+
# Lazy-initialized process pool for running agent tasks out-of-process
26+
_executor: ProcessPoolExecutor | None = None
27+
_executor_lock = threading.Lock()
28+
29+
30+
def _get_executor() -> ProcessPoolExecutor:
31+
"""Get or create the shared process pool executor."""
32+
global _executor
33+
if _executor is None:
34+
with _executor_lock:
35+
if _executor is None:
36+
_executor = ProcessPoolExecutor(max_workers=4)
37+
# Register cleanup on process exit
38+
atexit.register(_shutdown_executor)
39+
return _executor
40+
41+
42+
def _shutdown_executor() -> None:
43+
"""Shutdown the shared process pool executor if it exists.
44+
45+
Called via atexit at process exit, when no other threads should be
46+
accessing the executor.
47+
"""
48+
global _executor
49+
if _executor is not None:
50+
_executor.shutdown(wait=False)
51+
_executor = None
52+
53+
54+
def _wrap_run(agent, data, extra_envs):
55+
for key, value in extra_envs.items():
56+
os.environ[key] = value
57+
58+
try:
59+
return asyncio.run(agent.run(None, data))
60+
except Exception:
61+
logger.error(f"Agent task failed: {traceback.format_exc()}")
62+
raise
63+
64+
1965
class OpenAIProxyWorkflow(RolloutWorkflow):
2066
def __init__(
2167
self,
@@ -32,6 +78,20 @@ def __init__(
3278
self.discount = discount
3379
self.export_style = export_style
3480

81+
@trace_session("run_agent")
82+
async def _run_agent(self, base_url: str, data: dict):
83+
extra_envs = {
84+
"OPENAI_BASE_URL": base_url,
85+
}
86+
executor = _get_executor()
87+
fut = executor.submit(_wrap_run, self.agent, data, extra_envs)
88+
try:
89+
return await asyncio.wrap_future(fut)
90+
except Exception:
91+
logger.error(f"Agent task failed: {traceback.format_exc()}")
92+
raise
93+
94+
@session_context()
3595
async def arun_episode(self, engine: TRolloutEngine, data):
3696
# Ensure that we own the same engine instance
3797
task_id = workflow_context.get().task_id
@@ -48,11 +108,8 @@ async def arun_episode(self, engine: TRolloutEngine, data):
48108
async with OpenAIProxyClientSession(
49109
base_url=self.proxy_server.public_addr, task_id=str(task_id)
50110
) as session:
51-
try:
52-
rewards = await self.agent.run(session.session_url, data)
53-
except Exception:
54-
logger.error(f"Agent task failed: {traceback.format_exc()}")
55-
raise
111+
rewards = await self._run_agent(session.session_url, data)
112+
56113
session_id = session.session_id
57114
if isinstance(rewards, dict):
58115
for completion_id, reward in rewards.items():
@@ -65,6 +122,14 @@ async def arun_episode(self, engine: TRolloutEngine, data):
65122
# Pop a session id from the server queue and ignore it.
66123
_ = await self.proxy_server.fetch_next_session()
67124

68-
return await self.proxy_server.wait_for_session(
69-
session_id, discount=self.discount, style=self.export_style
125+
session_data = await self.proxy_server.wait_for_session(session_id)
126+
last_id = session_data.completions.last_interaction_id
127+
interactions = session_data.completions.export_interactions(
128+
reward_discount=self.discount, style=self.export_style
70129
)
130+
131+
# Record the last reward in wandb/tensorboard
132+
last_reward = interactions[last_id].reward
133+
stats_tracker.get(workflow_context.stat_scope()).scalar(reward=last_reward)
134+
135+
return interactions

areal/tests/experimental/openai/test_proxy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ async def test_session_lifecycle(proxy_server):
122122
assert resp.status == 200
123123

124124
# 3. after end session, can fetch results with `wait_for_session`
125-
interactions = await proxy_server.wait_for_session(session_id)
125+
session_data = await proxy_server.wait_for_session(session_id)
126+
interactions = session_data.completions.export_interactions()
126127
assert len(interactions) >= 1
127128

128129

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
1+
from math_verify import parse, verify
12
from openai import AsyncOpenAI
23
from openai.types.chat import ChatCompletion
34

45
from areal.api.workflow_api import AgentWorkflow
5-
from areal.core import workflow_context
6-
from areal.utils import stats_tracker
76

87

98
class GSM8kAgent(AgentWorkflow):
109
def __init__(self, **kwargs):
1110
self.kwargs = kwargs
1211

1312
async def run(self, base_url: str, data: dict):
14-
async with AsyncOpenAI(base_url=base_url) as client:
13+
# custom_timeout = httpx.Timeout(30.0, read=600.0)
14+
# async with AsyncOpenAI(base_url=base_url, max_retries=0,
15+
# timeout=custom_timeout) as client:
16+
async with AsyncOpenAI(max_retries=0) as client:
1517
comp: ChatCompletion = await client.chat.completions.create(
1618
messages=data["messages"], model="default", **self.kwargs
1719
)
1820

19-
# compute reward with areal's existing implementation
20-
# Use the following wrapper to suppress the annoying warning of math-verify
21-
from areal.api.reward_api import AsyncRewardWrapper
22-
from areal.reward.gsm8k import gsm8k_reward_fn
23-
24-
reward = await AsyncRewardWrapper(gsm8k_reward_fn)(
25-
None, comp.choices[0].message.content, None, None, answer=data["answer"]
26-
)
27-
stats_tracker.get(workflow_context.stat_scope()).scalar(reward=reward)
28-
return reward
21+
ans = parse(comp.choices[0].message.content)
22+
gold = parse(data["answer"])
23+
reward = verify(ans, gold)
24+
return float(reward)

0 commit comments

Comments
 (0)