
Commit 87b1dba

楚财daihao authored and committed

PullRequest: 931 add train worker

Merge branch chucai.dzq/train-worker-adapt-open-source of [email protected]:inclusionAI/AReaL.git into asystem/gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/931
Reviewed-by: 峯回 <[email protected]>

* add worker
1 parent a42edb2 commit 87b1dba

File tree

2 files changed: +1002 −16 lines changed


areal/extension/asystem/remote_hybrid_inference_worker.py

Lines changed: 15 additions & 16 deletions
@@ -2,10 +2,11 @@
 import json
 import threading
 import time
+from collections.abc import Callable
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from queue import Empty, Full, Queue
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional
 
 import aiohttp
 import requests
@@ -14,20 +15,19 @@
 from tensordict import TensorDict
 from torchdata.stateful_dataloader import StatefulDataLoader
 
-from areal.api.cli_args import RemoteHybridInferenceConfig
 from areal.api.engine_api import InferenceEngine
 from areal.api.io_struct import (
     ModelRequest,
     ModelResponse,
     RolloutStat,
     WeightUpdateMeta,
 )
-
-from areal.utils.data import concat_padded_tensors, cycle_dataloader
+from areal.extension.asystem.api.cli_args import RemoteHybridInferenceConfig
 from areal.extension.asystem.util import wait_future_ordered
+from areal.utils import logging, seeding
+from areal.utils.data import concat_padded_tensors, cycle_dataloader
 from areal.utils.errors import EngineError, FrameworkError
 from areal.utils.http import arequest_with_retry, get_default_connector
-from realhf.base import logging, seeding
 
 if TYPE_CHECKING:
     from areal.api.workflow_api import RolloutWorkflow
@@ -156,7 +156,7 @@ def initialize(self, initialize_cfg: RemoteHypidInferenceInitConfig):
         self.input_queue = Queue(maxsize=self.qsize)
         self.output_queue = Queue(maxsize=self.qsize)
 
-        self.rollout_tasks: Dict[str, asyncio.Task] = {}
+        self.rollout_tasks: dict[str, asyncio.Task] = {}
         self.executor = ProcessPoolExecutor(max_workers=1)
         self.rollout_thread = threading.Thread(target=self._rollout_thread)
         self.rollout_thread.start()
@@ -463,7 +463,7 @@ def update_single_server(addr):
             )
         elif meta.type == "nccl" or meta.type == "astate":
             load_timestamp = time.time_ns()
-            logger.info(f"Begin update weights.")
+            logger.info("Begin update weights.")
 
             def update_single_server(addr):
                 try:
@@ -532,7 +532,7 @@ def update_weights_from_disk(self, addr, path: str):
 
     def submit(
         self,
-        data: Union[List[Dict[str, Any]], Dict[str, Any]],
+        data: list[dict[str, Any]] | dict[str, Any],
         workflow: "RolloutWorkflow",
     ) -> None:
         try:
@@ -548,7 +548,7 @@ def submit(
             )
 
     def submit_batch(
-        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
+        self, data: list[dict[str, Any]], workflow: "RolloutWorkflow"
     ) -> None:
         try:
             self.input_queue.put_nowait(data, workflow)
@@ -561,11 +561,11 @@ def submit_batch(
 
     def rollout_batch(
         self,
-        data: List[Dict[str, Any]],
+        data: list[dict[str, Any]],
         workflow: Optional["RolloutWorkflow"] = None,
-        workflow_builder: Optional[Callable] = None,
+        workflow_builder: Callable | None = None,
         should_accept: Callable | None = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         try:
             self.input_queue.put_nowait(data, workflow)
         except Full:
@@ -645,8 +645,7 @@ def wait(
             raise FrameworkError(
                 "FrameworkError",
                 "InferenceWorkError",
-                f"Timed out waiting for {count} rollouts, "
-                f"only received {accepted}.",
+                f"Timed out waiting for {count} rollouts, only received {accepted}.",
             )
         with self.lock:
             results, self.result_cache = (
@@ -658,8 +657,8 @@ def wait(
         return padded
 
     def rollout(  # only dp head accept this request
-        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow", *args, **kwargs
-    ) -> Dict[str, Any]:
+        self, data: list[dict[str, Any]], workflow: "RolloutWorkflow", *args, **kwargs
+    ) -> dict[str, Any]:
         """Submit a batch of requests to the inference engine and wait for the results."""
         if self.config.batch_requests is True:
             self.submit_batch(data, workflow)
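Aside from relocating the `RemoteHybridInferenceConfig`, `logging`, and `seeding` imports to their new module paths, most of the in-file edits swap `typing.Dict`/`List`/`Optional`/`Union` for the built-in generics and `|` unions of PEP 585/604, which need Python 3.9+/3.10+ at runtime. A minimal, self-contained sketch of the equivalence; `submit_stub` and `pick_workflow` are hypothetical illustrations, not functions from this commit:

```python
from typing import Any, Optional

# Style removed by this commit (typing-module generics):
#   def submit(data: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None: ...
# Style introduced by this commit (built-in generics and `|` unions):
def submit_stub(data: list[dict[str, Any]] | dict[str, Any]) -> None:
    """Hypothetical stand-in; only normalizes and checks the annotated shape."""
    items = data if isinstance(data, list) else [data]
    for item in items:
        assert isinstance(item, dict)

# Optional["RolloutWorkflow"] and `Callable | None` coexist in the new signatures;
# the two spellings are equivalent ways to mark a parameter as optional.
def pick_workflow(workflow: Optional[str] = None, builder_name: str | None = None) -> str:
    return workflow or builder_name or "default"

if __name__ == "__main__":
    submit_stub({"prompt": "hello"})                   # single dict
    submit_stub([{"prompt": "a"}, {"prompt": "b"}])    # list of dicts
    print(pick_workflow(builder_name="my_workflow"))   # -> "my_workflow"
```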
