22import json
33import threading
44import time
5+ from collections .abc import Callable
56from concurrent .futures import ProcessPoolExecutor , ThreadPoolExecutor
67from dataclasses import dataclass
78from queue import Empty , Full , Queue
8- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
9+ from typing import TYPE_CHECKING , Any , Optional
910
1011import aiohttp
1112import requests
1415from tensordict import TensorDict
1516from torchdata .stateful_dataloader import StatefulDataLoader
1617
17- from areal .api .cli_args import RemoteHybridInferenceConfig
1818from areal .api .engine_api import InferenceEngine
1919from areal .api .io_struct import (
2020 ModelRequest ,
2121 ModelResponse ,
2222 RolloutStat ,
2323 WeightUpdateMeta ,
2424)
25-
26- from areal .utils .data import concat_padded_tensors , cycle_dataloader
25+ from areal .extension .asystem .api .cli_args import RemoteHybridInferenceConfig
2726from areal .extension .asystem .util import wait_future_ordered
27+ from areal .utils import logging , seeding
28+ from areal .utils .data import concat_padded_tensors , cycle_dataloader
2829from areal .utils .errors import EngineError , FrameworkError
2930from areal .utils .http import arequest_with_retry , get_default_connector
30- from realhf .base import logging , seeding
3131
3232if TYPE_CHECKING :
3333 from areal .api .workflow_api import RolloutWorkflow
@@ -156,7 +156,7 @@ def initialize(self, initialize_cfg: RemoteHypidInferenceInitConfig):
156156 self .input_queue = Queue (maxsize = self .qsize )
157157 self .output_queue = Queue (maxsize = self .qsize )
158158
159- self .rollout_tasks : Dict [str , asyncio .Task ] = {}
159+ self .rollout_tasks : dict [str , asyncio .Task ] = {}
160160 self .executor = ProcessPoolExecutor (max_workers = 1 )
161161 self .rollout_thread = threading .Thread (target = self ._rollout_thread )
162162 self .rollout_thread .start ()
@@ -463,7 +463,7 @@ def update_single_server(addr):
463463 )
464464 elif meta .type == "nccl" or meta .type == "astate" :
465465 load_timestamp = time .time_ns ()
466- logger .info (f "Begin update weights." )
466+ logger .info ("Begin update weights." )
467467
468468 def update_single_server (addr ):
469469 try :
@@ -532,7 +532,7 @@ def update_weights_from_disk(self, addr, path: str):
532532
533533 def submit (
534534 self ,
535- data : Union [ List [ Dict [ str , Any ]], Dict [str , Any ] ],
535+ data : list [ dict [ str , Any ]] | dict [str , Any ],
536536 workflow : "RolloutWorkflow" ,
537537 ) -> None :
538538 try :
@@ -548,7 +548,7 @@ def submit(
548548 )
549549
550550 def submit_batch (
551- self , data : List [ Dict [str , Any ]], workflow : "RolloutWorkflow"
551+ self , data : list [ dict [str , Any ]], workflow : "RolloutWorkflow"
552552 ) -> None :
553553 try :
554554 self .input_queue .put_nowait (data , workflow )
@@ -561,11 +561,11 @@ def submit_batch(
561561
562562 def rollout_batch (
563563 self ,
564- data : List [ Dict [str , Any ]],
564+ data : list [ dict [str , Any ]],
565565 workflow : Optional ["RolloutWorkflow" ] = None ,
566- workflow_builder : Optional [ Callable ] = None ,
566+ workflow_builder : Callable | None = None ,
567567 should_accept : Callable | None = None ,
568- ) -> Dict [str , Any ]:
568+ ) -> dict [str , Any ]:
569569 try :
570570 self .input_queue .put_nowait (data , workflow )
571571 except Full :
@@ -645,8 +645,7 @@ def wait(
645645 raise FrameworkError (
646646 "FrameworkError" ,
647647 "InferenceWorkError" ,
648- f"Timed out waiting for { count } rollouts, "
649- f"only received { accepted } ." ,
648+ f"Timed out waiting for { count } rollouts, only received { accepted } ." ,
650649 )
651650 with self .lock :
652651 results , self .result_cache = (
@@ -658,8 +657,8 @@ def wait(
658657 return padded
659658
660659 def rollout ( # only dp head accept this request
661- self , data : List [ Dict [str , Any ]], workflow : "RolloutWorkflow" , * args , ** kwargs
662- ) -> Dict [str , Any ]:
660+ self , data : list [ dict [str , Any ]], workflow : "RolloutWorkflow" , * args , ** kwargs
661+ ) -> dict [str , Any ]:
663662 """Submit a batch of requests to the inference engine and wait for the results."""
664663 if self .config .batch_requests is True :
665664 self .submit_batch (data , workflow )
0 commit comments