11from __future__ import annotations
22
3+ import asyncio
4+ import atexit
5+ import os
6+ import threading
37import traceback
8+ from concurrent .futures import ProcessPoolExecutor
49from typing import TYPE_CHECKING
510
611from areal .api .workflow_api import AgentWorkflow , RolloutWorkflow
712from areal .core import workflow_context
8- from areal .utils import logging
13+ from areal .utils import logging , stats_tracker
14+ from areal .utils .perf_tracer import session_context , trace_session
915
1016from .client_session import OpenAIProxyClientSession
1117from .server import OpenAIProxyServer
1622logger = logging .getLogger ("OpenAIProxyWorkflow" )
1723
1824
# Lazily created process pool shared by every workflow in this process.
# It stays ``None`` until the first call to ``_get_executor()``.
_executor: ProcessPoolExecutor | None = None
# Serializes creation of ``_executor`` so concurrent first callers build
# only one pool.
_executor_lock = threading.Lock()
28+
29+
def _get_executor() -> ProcessPoolExecutor:
    """Return the shared process pool, creating it on first use.

    Creation uses double-checked locking on ``_executor_lock``: the
    unlocked read keeps the common path cheap, and the second read under
    the lock ensures only one pool is ever built when several threads
    race on the first call.

    Returns:
        The module-wide ``ProcessPoolExecutor`` (``max_workers=4``).
    """
    global _executor
    # Work on a local binding so the value returned is never ``None``,
    # even if ``_shutdown_executor`` resets the global concurrently.
    executor = _executor
    if executor is None:
        with _executor_lock:
            executor = _executor
            if executor is None:
                executor = ProcessPoolExecutor(max_workers=4)
                _executor = executor
                # Register cleanup on process exit.
                atexit.register(_shutdown_executor)
    return executor
40+
41+
def _shutdown_executor() -> None:
    """Shut down the shared process pool executor if it exists.

    Registered with ``atexit`` by ``_get_executor``. It does not take
    ``_executor_lock``; it is meant to run at process exit, when no other
    thread should still be using the executor.
    """
    global _executor
    if _executor is None:
        return
    # ``wait=False``: do not block here until outstanding work finishes.
    _executor.shutdown(wait=False)
    _executor = None
52+
53+
def _wrap_run(agent, data, extra_envs):
    """Run ``agent.run`` to completion on a fresh event loop.

    This is the callable that ``_run_agent`` submits to the shared
    executor. The agent is called with ``None`` as its first positional
    argument; per-call settings such as the base URL are passed through
    environment variables instead.

    Args:
        agent: Object exposing an async ``run(first_arg, data)`` method.
        data: Payload forwarded unchanged to ``agent.run``.
        extra_envs: Mapping of environment variable names to string
            values, applied to ``os.environ`` before the agent runs.

    Returns:
        Whatever ``agent.run`` returns.

    Raises:
        Exception: Any exception raised by the agent is logged and
            re-raised unchanged.
    """
    # The variables are written to this process's environment and are not
    # restored afterwards, so they stay set for later calls in the same
    # process unless overwritten.
    os.environ.update(extra_envs)

    try:
        return asyncio.run(agent.run(None, data))
    except Exception:
        logger.error(f"Agent task failed: {traceback.format_exc()}")
        raise
63+
64+
1965class OpenAIProxyWorkflow (RolloutWorkflow ):
2066 def __init__ (
2167 self ,
@@ -32,6 +78,20 @@ def __init__(
3278 self .discount = discount
3379 self .export_style = export_style
3480
@trace_session("run_agent")
async def _run_agent(self, base_url: str, data: dict):
    """Run ``self.agent`` on the shared process pool and await its result.

    ``base_url`` is passed to the agent through the ``OPENAI_BASE_URL``
    environment variable, which ``_wrap_run`` sets before calling the
    agent.

    Because the work is submitted to a ``ProcessPoolExecutor``,
    ``self.agent``, ``data`` and the return value must all be picklable.

    Args:
        base_url: URL exported to the agent as ``OPENAI_BASE_URL``.
        data: Payload forwarded to the agent's ``run`` method.

    Returns:
        The value returned by the agent's ``run`` coroutine.

    Raises:
        Exception: Any failure from the submitted task is logged here
            and re-raised.
    """
    extra_envs = {"OPENAI_BASE_URL": base_url}
    future = _get_executor().submit(_wrap_run, self.agent, data, extra_envs)
    try:
        # Bridge the concurrent.futures.Future into the running event loop
        # so awaiting it does not block other coroutines.
        return await asyncio.wrap_future(future)
    except Exception:
        logger.error(f"Agent task failed: {traceback.format_exc()}")
        raise
93+
94+ @session_context ()
3595 async def arun_episode (self , engine : TRolloutEngine , data ):
3696 # Ensure that we own the same engine instance
3797 task_id = workflow_context .get ().task_id
@@ -48,11 +108,8 @@ async def arun_episode(self, engine: TRolloutEngine, data):
48108 async with OpenAIProxyClientSession (
49109 base_url = self .proxy_server .public_addr , task_id = str (task_id )
50110 ) as session :
51- try :
52- rewards = await self .agent .run (session .session_url , data )
53- except Exception :
54- logger .error (f"Agent task failed: { traceback .format_exc ()} " )
55- raise
111+ rewards = await self ._run_agent (session .session_url , data )
112+
56113 session_id = session .session_id
57114 if isinstance (rewards , dict ):
58115 for completion_id , reward in rewards .items ():
@@ -65,6 +122,14 @@ async def arun_episode(self, engine: TRolloutEngine, data):
65122 # Pop a session id from the server queue and ignore it.
66123 _ = await self .proxy_server .fetch_next_session ()
67124
68- return await self .proxy_server .wait_for_session (
69- session_id , discount = self .discount , style = self .export_style
125+ session_data = await self .proxy_server .wait_for_session (session_id )
126+ last_id = session_data .completions .last_interaction_id
127+ interactions = session_data .completions .export_interactions (
128+ reward_discount = self .discount , style = self .export_style
70129 )
130+
131+ # Record the last reward in wandb/tensorboard
132+ last_reward = interactions [last_id ].reward
133+ stats_tracker .get (workflow_context .stat_scope ()).scalar (reward = last_reward )
134+
135+ return interactions
0 commit comments