66from concurrent .futures import ProcessPoolExecutor , ThreadPoolExecutor
77from dataclasses import dataclass
88from queue import Empty , Full , Queue
9- from typing import TYPE_CHECKING , Any , Optional
9+ from typing import Any , Optional
1010
1111import aiohttp
1212import requests
2222 RolloutStat ,
2323 WeightUpdateMeta ,
2424)
25+ from areal .api .workflow_api import RolloutWorkflow
2526from areal .extension .asystem .api .cli_args import RemoteHybridInferenceConfig
2627from areal .extension .asystem .util import wait_future_ordered
2728from areal .utils import logging , seeding
2829from areal .utils .data import concat_padded_tensors , cycle_dataloader
2930from areal .utils .errors import EngineError , FrameworkError
3031from areal .utils .http import arequest_with_retry , get_default_connector
3132
32- if TYPE_CHECKING :
33- from areal .api .workflow_api import RolloutWorkflow
3433logger = logging .getLogger (__name__ )
3534
3635ROLLOUT_POLL_WAIT_TIME = 0.05
@@ -236,9 +235,7 @@ async def _rollout_thread_async(self):
236235 # logger.info(f"Get data from puller: {data}")
237236 task = asyncio .create_task (
238237 (
239- workflow .arun_episodes (self , data )
240- if isinstance (data , list )
241- else workflow .arun_episode (self , data )
238+ workflow .arun_episode (self , data )
242239 ),
243240 name = str (rid ),
244241 )
@@ -345,6 +342,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
345342 start_time = time .perf_counter ()
346343 accumulated_output_tokens = []
347344 accumulated_output_logprobs = []
345+ accumulated_versions = []
348346
349347 # Deal with rollout interruption
350348 stop_reason = ""
@@ -385,6 +383,9 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
385383 # Update accumulated outputs
386384 accumulated_output_tokens .extend (output_tokens )
387385 accumulated_output_logprobs .extend (output_logprobs )
386+ accumulated_versions .extend (
387+ [self .get_version ()] * len (output_logprobs )
388+ )
388389
389390 # Check if generation is complete
390391 finish_reason = meta_info ["finish_reason" ]
@@ -399,7 +400,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
399400 input_tokens = req .input_ids ,
400401 output_tokens = accumulated_output_tokens ,
401402 output_logprobs = accumulated_output_logprobs ,
402- output_version = self . get_version () ,
403+ output_versions = accumulated_versions ,
403404 stop_reason = stop_reason ,
404405 latency = latency ,
405406 ttft = latency , # Simplified for non-streaming
@@ -532,14 +533,13 @@ def update_weights_from_disk(self, addr, path: str):
532533
533534 def submit (
534535 self ,
535- data : list [dict [str , Any ]] | dict [str , Any ],
536- workflow : "RolloutWorkflow" ,
536+ data : dict [str , Any ],
537+ workflow : RolloutWorkflow | None = None ,
538+ workflow_builder : Callable | None = None ,
539+ should_accept : Callable | None = None ,
537540 ) -> None :
538541 try :
539- if not isinstance (data , list ):
540- data = [data ]
541- for d in data :
542- self .input_queue .put_nowait ((d , workflow ))
542+ self .input_queue .put_nowait ((data , workflow ))
543543 except Full :
544544 raise FrameworkError (
545545 "FrameworkError" ,
@@ -548,7 +548,7 @@ def submit(
548548 )
549549
550550 def submit_batch (
551- self , data : list [dict [str , Any ]], workflow : " RolloutWorkflow"
551+ self , data : list [dict [str , Any ]], workflow : RolloutWorkflow
552552 ) -> None :
553553 try :
554554 self .input_queue .put_nowait (data , workflow )
@@ -701,3 +701,11 @@ def notify_event(self, event: str, global_step: int) -> None:
701701 except Exception as e :
702702 raise EngineError ("InferenceEngineError" , "NotifyEventError" , e )
703703 return None
704+
705+ def wait_quiet (
706+ self , count : int , timeout : float | None = None , max_retries : int = 1 ,
707+ ) -> dict [str , Any ] | None :
708+ try :
709+ return self .wait (count , timeout = timeout )
710+ except TimeoutError :
711+ return "NO_RESULT"
0 commit comments