diff --git a/areal/api/controller_api.py b/areal/api/controller_api.py index 1df593161..78edbd7a7 100644 --- a/areal/api/controller_api.py +++ b/areal/api/controller_api.py @@ -315,7 +315,7 @@ def set_version(self, version: int): """ raise NotImplementedError() - def get_version(self) -> int: + def get_version(self) -> List[int]: """Get the current weight version in the training engine. Returns @@ -359,7 +359,7 @@ def train_batch( input_: DistributedBatch, loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor], loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor], - ) -> Dict[str, float]: + ) -> List[Dict[str, float]]: """Update the model with a batch of data and a loss function. Note @@ -382,7 +382,7 @@ def train_batch( Returns ------- - Dict[str, float] + List[Dict[str, float]] Scalar statistics after training, e.g., the current learning rate, gradient norm, etc. """ @@ -394,7 +394,7 @@ def eval_batch( input_: DistributedBatch, loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor], loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor], - ) -> torch.Tensor | None: + ) -> List[torch.Tensor]: """Evaluate the model using the forward pass and loss function. Note @@ -458,7 +458,6 @@ def forward( """ raise NotImplementedError() - class RolloutController(abc.ABC): """A centralized controller that manages multiple distributed InferenceEngine workers for rollout generation. @@ -508,21 +507,6 @@ def destroy(self): """Destroy the engine and release GPU memory for the local inference engine.""" raise NotImplementedError() - async def agenerate(self, req: ModelRequest) -> ModelResponse: - """Asynchronously generate a response for the given request. 
- - Parameters - ---------- - req : ModelRequest - The model request containing input data and generation parameters - - Returns - ------- - ModelResponse - The generated response from the model - """ - raise NotImplementedError() - def update_weights(self, meta: WeightUpdateMeta) -> Future: """Update weights in the inference engine in a non-blocking manner. @@ -571,7 +555,7 @@ def get_version(self) -> int: def submit( self, - data: Dict[str, Any], + data: DistributedBatch, workflow: Optional["RolloutWorkflow"] = None, workflow_builder: Optional[Callable] = None, should_accept: Callable | None = None, @@ -623,7 +607,7 @@ def wait(self, count: int, timeout: float | None = None) -> DistributedBatch: def rollout_batch( self, - data: List[Dict[str, Any]], + data: DistributedBatch, workflow: Optional["RolloutWorkflow"] = None, workflow_builder: Optional[Callable] = None, should_accept: Callable | None = None, @@ -652,7 +636,7 @@ def rollout_batch( def prepare_batch( self, - dataloader: StatefulDataLoader, + dataloader: DistributedBatch, workflow: Optional["RolloutWorkflow"] = None, workflow_builder: Optional[Callable] = None, should_accept: Callable | None = None, @@ -688,31 +672,4 @@ def pause(self): def resume(self): """Resume request submission for async rollout.""" - raise NotImplementedError() - - def register_callback_to_all_worker( - self, method: str, callback: Callable, **kwargs - ): - """Register a callback function for the specified method across all workers. - - Partial rollout API. After successful registration, the controller will poll - and call the specified method in a background thread. When the return value - is obtained, it will be used as a parameter to call the `callback` function. 
- - Parameters - ---------- - method : str - The name of the method to register the callback for - callback : Callable - The callback function to be called with the method's return value - **kwargs - Additional keyword arguments for the callback registration - """ - raise NotImplementedError() - - def abort_all_requests(self) -> None: - """Abort all ongoing requests in the inference engine. - - Partial rollout API for canceling all queued and in-progress requests. - """ - raise NotImplementedError() + raise NotImplementedError() \ No newline at end of file diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 021c0a2ea..e1ef480e5 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -25,6 +25,8 @@ class Scheduling: cpu: int gpu: int mem: int + port_count: int + cmd: str | None = None nodelist: str | None = None exclude: str | None = None partition: str | None = None @@ -138,7 +140,7 @@ def parallelism_group(self) -> dist.ProcessGroup: """ raise NotImplementedError() - def get_scheduling_config(self) -> Scheduling: + def get_scheduling_config(self) -> List[Scheduling]: """Get the scheduling configuration for the engine. This includes configuration such as container image, CPU/GPU/memory size. @@ -553,3 +555,15 @@ def pause(self): def resume(self): """Resume request submission for async rollout.""" raise NotImplementedError() + + def get_scheduling_config(self) -> List[Scheduling]: + """Get the scheduling configuration for the engine. + + This includes configuration such as container image, CPU/GPU/memory size. 
+ + Returns + ------- + Scheduling + The scheduling configuration for the engine + """ + raise NotImplementedError() diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index f7e9fb941..e79c9b22c 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -1,47 +1,40 @@ import abc from dataclasses import dataclass, field -from typing import Dict, List +from typing import List, Literal + +from areal.api.engine_api import Scheduling @dataclass class Worker: id: str ip: str - ports: List[str] = field(default_factory=list) - - -@dataclass -class ContainerSpec: - cpu: int = 0 - gpu: int = 0 - mem: int = 0 - container_image: str = "" - cmd: str = "" - env_vars: Dict[str, str] = field(default_factory=dict) - port_count: int = 2 + serve_port: str + extra_ports: List[str] = field(default_factory=list) @dataclass class ScheduleStrategy: - type: str = "" - uid: str = "" + type: Literal["colocation", "separation", ""] = "" + target: str = "" @dataclass -class SchedulingConfig: +class Job: replicas: int = 0 - specs: List[ContainerSpec] = field(default_factory=list) + tasks: List[Scheduling] = field(default_factory=list) schedule_strategy: ScheduleStrategy | None = None role: str = "" class Scheduler(abc.ABC): - def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str: + def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> None: """ - Start workers, return job id + Start workers """ + raise NotImplementedError() - def get_workers(self, worker_key, timeout=None) -> List[Worker]: + def get_workers(self, role: str, timeout=None) -> List[Worker]: """ Wait and return worker list, including scheduling results such as ip and engine ports (worker id, ip, ports) diff --git a/areal/api/workflow_api.py b/areal/api/workflow_api.py index 22d4facde..69a791353 100644 --- a/areal/api/workflow_api.py +++ b/areal/api/workflow_api.py @@ -330,14 +330,11 @@ async def _rollout_thread_async(self): try: while not 
self.exiting.is_set(): # Check capacity - capacity = self.get_capacity() + # capacity = self.get_capacity() + # self.logger.info(f"Current rollout capacity: {capacity}") # Create new rollout task self.lock.acquire() - while ( - capacity > 0 - and not self.paused.is_set() - and self.input_queue.qsize() > 0 - ): + while not self.paused.is_set() and self.input_queue.qsize() > 0: x = self.input_queue.get_nowait() x: _RolloutTaskInput self.logger.debug(f"Get data from puller: {x.data}") @@ -357,7 +354,7 @@ async def _rollout_thread_async(self): f"running: {self.rollout_stat.running}, " f"accepted: {self.rollout_stat.accepted}." ) - capacity -= 1 + # capacity -= 1 rid += 1 tasks = [x.task for x in rollout_tasks.values()] self.lock.release() @@ -524,7 +521,7 @@ def rollout_batch( def prepare_batch( self, - dataloader: StatefulDataLoader, + dataloader: StatefulDataLoader | List[Dict[str, Any]], workflow: "RolloutWorkflow" | None = None, workflow_builder: Callable | None = None, should_accept: Callable | None = None, @@ -533,28 +530,62 @@ def prepare_batch( See :meth:`~areal.api.engine_api.InferenceEngine.prepare_batch` for detailed documentation. 
""" - if not hasattr(self, "data_generator"): - self.data_generator = cycle_dataloader(dataloader) - assert dataloader.batch_size is not None - while True: - # Submit at least two batches to allow maximum overlap - if ( - self.get_capacity() + dataloader.batch_size > 0 - and self.input_queue.qsize() + dataloader.batch_size - < self.input_queue.maxsize - ): - data = next(self.data_generator) - for item in data: + if isinstance(dataloader, StatefulDataLoader): + # 处理StatefulDataLoader类型 - 保持原有逻辑不变 + if not hasattr(self, "data_generator"): + self.data_generator = cycle_dataloader(dataloader) + assert dataloader.batch_size is not None + batch_size = dataloader.batch_size + + while True: + # Submit at least two batches to allow maximum overlap + if ( + self.get_capacity() + batch_size > 0 + and self.input_queue.qsize() + batch_size + < self.input_queue.maxsize + ): + data = next(self.data_generator) + for item in data: + self.submit( + item, + workflow=workflow, + workflow_builder=workflow_builder, + should_accept=should_accept, + ) + try: + return self.wait(batch_size, timeout=1) + except TimeoutError: + pass + else: + self.data_list_index = 0 + + # 对于List类型,使用固定的batch_size=1 + batch_size = 1 + + while True: + # Submit at least two batches to allow maximum overlap + if ( + self.get_capacity() + batch_size > 0 + and self.input_queue.qsize() + batch_size + < self.input_queue.maxsize + ): + # 从List中获取数据,支持循环访问 + if self.data_list_index >= len(dataloader): + self.data_list_index = 0 # 循环访问 + + item = dataloader[self.data_list_index] + self.data_list_index += 1 + self.submit( item, workflow=workflow, workflow_builder=workflow_builder, should_accept=should_accept, ) - try: - return self.wait(dataloader.batch_size, timeout=1) - except TimeoutError: - pass + try: + return self.wait(batch_size, timeout=1) + except TimeoutError: + pass def pause(self): """Pause request submission for async rollout. 
diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py new file mode 100644 index 000000000..daba0bab6 --- /dev/null +++ b/areal/controller/rollout_controller.py @@ -0,0 +1,133 @@ +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Any, Callable, Dict, List + +from tensordict import TensorDict, stack + +from areal.api.cli_args import InferenceEngineConfig +from areal.api.controller_api import RolloutController, DistributedBatch +from areal.api.engine_api import InferenceEngine +from areal.api.io_struct import AllocationMode, WeightUpdateMeta +from areal.api.workflow_api import RolloutWorkflow + +from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker +from areal.controller.utils import create_engine_with_retry, rpc_call +from areal.utils.data import concat_padded_tensors +from areal.utils import logging +from areal.utils.http import wait_future_ordered + +logger = logging.getLogger("DistributedRolloutController") + + +class DistributedRolloutController(RolloutController): + def __init__( + self, + inf_engine: InferenceEngine, + config: InferenceEngineConfig, + scheduler: Scheduler, + ): + super().__init__(inf_engine, config, scheduler) + self.role: str = "rollout" + self.alloc_mode: AllocationMode + self.enable_colocate_mode: bool + self.dp_world_size: int + self.dp_head_workers: List[Worker] + + def initialize( + self, + alloc_mode_str: str, + target: str, + ): + self.alloc_mode = AllocationMode.from_str(alloc_mode_str) + self.dp_world_size = self.alloc_mode.gen.world_size // self.alloc_mode.gen.dp_size + + job = Job( + replicas=self.alloc_mode.gen.world_size, + tasks=self.inf_engine.get_scheduling_config(), + schedule_strategy=ScheduleStrategy(type="colocation", target=target) if target else None, + role=self.role, + ) + logger.info(f"Start to create job: {job}") + self.scheduler.create_workers(job) + + workers = self.scheduler.get_workers(self.role, 
timeout=1800) + self.dp_head_workers = [worker for idx, worker in enumerate(workers) if idx % self.dp_world_size == 0] + assert len(self.dp_head_workers) == self.alloc_mode.gen.dp_size + + engine_addrs = [f"{w.ip}:{w.serve_port}" for w in self.dp_head_workers] + with ThreadPoolExecutor(max_workers=len(self.dp_head_workers)) as executor: + futures = [ + executor.submit( + partial( + create_engine_with_retry, + self.scheduler.create_engine, + worker.id, + self.inf_engine, + None, + engine_addrs, + self.dp_world_size, + ) + ) + for worker in self.dp_head_workers + ] + + wait_future_ordered(futures, exit_on_exception=True) + + def destroy(self): + self.scheduler.delete_workers() + + def __del__(self): + self.destroy() + + def update_weights(self, meta: WeightUpdateMeta) -> None: + """Update weights in the inference engine.""" + self.custom_function_call("update_weights", None, meta) + return None + + def prepare_batch(self, data: DistributedBatch, workflow: RolloutWorkflow) -> None: + """Asynchronously submit a request to the inference engine. 
Exits immediately.""" + batches = data.chunk(self.alloc_mode.gen.dp_size) + self.custom_function_call("prepare_batch", batches, workflow) + return None + + def rollout_batch( + self, + data: DistributedBatch, + workflow: RolloutWorkflow + ) -> DistributedBatch: + """Submit a batch of requests to the inference engine and wait for the results.""" + batches = data.chunk(self.alloc_mode.gen.dp_size) + results = self.custom_function_call("rollout_distributed_batch", batches, workflow) + assert len(results) > 0 + size = int(results[0]["input_ids"].shape[0]) + bs = size * len(results) + padded = concat_padded_tensors(results) + if isinstance(padded, dict): + padded = TensorDict(padded, batch_size=[bs]) + return DistributedBatch.concat(padded.to_dict()) + + def set_version(self, version: int) -> None: + self.custom_function_call("set_version", None, version) + return None + + def get_version(self) -> int: + results = self.custom_function_call("get_version", None) + return results[0] + + def pause(self): + self.custom_function_call("pause", None) + + def resume(self): + self.custom_function_call("resume", None) + + def submit(self, data: DistributedBatch): + batches = data.chunk(self.alloc_mode.gen.dp_size) + self.custom_function_call("submit", batches) + + def wait(self, counts: List[int], timeout: float | None = None)->DistributedBatch: + assert len(counts) == len(self.dp_head_workers) + results = self.custom_function_call("wait", counts, timeout) + return DistributedBatch.concat(results) + + def custom_function_call(self, method: str, batches, *args, **kwargs): + return rpc_call(self.scheduler, self.dp_head_workers, method, batches, args, kwargs) diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py new file mode 100644 index 000000000..99ed02dcb --- /dev/null +++ b/areal/controller/train_controller.py @@ -0,0 +1,193 @@ +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Any, Callable, 
Dict, List + +import torch + +from areal.api.alloc_mode import ParallelStrategy +from areal.api.cli_args import TrainEngineConfig +from areal.api.controller_api import DistributedBatch, TrainController +from areal.api.engine_api import TrainEngine +from areal.api.io_struct import ( + AllocationMode, + FinetuneSpec, + ParamSpec, + SaveLoadMeta, + WeightUpdateMeta, +) +from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker +from areal.controller.utils import create_engine_with_retry, rpc_call +from areal.utils import logging +from areal.utils.http import wait_future_ordered + +logger = logging.getLogger("DistributedTrainController") + + +class DistributedTrainController(TrainController): + def __init__( + self, train_engine: TrainEngine, config: TrainEngineConfig, scheduler: Scheduler + ): + super().__init__(train_engine, config, scheduler) + + self.role: str = "train" + self.group_size: int + self.alloc_mode: AllocationMode + self.workers: List[Worker] + self.engine_dp_ranks: List[int] + + def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): + assert self.workers is not None, "Workers are not created" + self.custom_function_call("create_process_group", parallel_strategy) + + def initialize( + self, + alloc_mode_str: str, + ft_spec: FinetuneSpec, + schedule_strategy: ScheduleStrategy, + group_size: int = 1, + ): + """Initialize environments for distributed training and load models.""" + self.alloc_mode = AllocationMode.from_str(alloc_mode_str) + self.ft_spec = ft_spec + self.group_size = group_size + + job = Job( + replicas=self.alloc_mode.train.world_size, + tasks=self.train_engine.get_scheduling_config(), + schedule_strategy=schedule_strategy, + role=self.role, + ) + logger.info(f"Start to create job: {job}") + self.scheduler.create_workers(job) + # after get workers, all rpc server is ready + self.workers = self.scheduler.get_workers(self.role, timeout=1800) + + logger.info(f"Start to create process group") + 
self.create_process_group(self.alloc_mode.train) + + logger.info(f"Start to initialize engine") + with ThreadPoolExecutor(max_workers=len(self.workers)) as executor: + futures = [ + executor.submit( + partial( + create_engine_with_retry, + self.scheduler.create_engine, + worker.id, + self.train_engine, + None, + self.ft_spec, + ) + ) + for worker in self.workers + ] + + wait_future_ordered(futures, exit_on_exception=True) + + logger.info(f"Start to get rank info from engine") + self.engine_dp_ranks = rpc_call( + self.scheduler, self.workers, "data_parallel_rank" + ) + logger.info(f"Initialize train engines succeeded!") + + def destroy(self): + self.scheduler.delete_workers() + + def train(self, mode: bool = True): + self.custom_function_call("train", mode) + + def upload_weights(self, meta: WeightUpdateMeta): + self.custom_function_call("upload_weights", meta) + + def get_param_specs( + self, weight_chunked_mem_mb: int = 1024 + ) -> List[List[ParamSpec]]: + ret: List[List[List[ParamSpec]]] = self.custom_function_call( + "get_param_specs", weight_chunked_mem_mb + ) + flattened = [inner for outer in ret for inner in outer] + return flattened + + def set_version(self, version: int): + return self.custom_function_call("set_version", version) + + def get_version(self) -> List[int]: + return self.custom_function_call("get_version") + + def save(self, meta: SaveLoadMeta): + self.custom_function_call("save", meta) + + def load(self, meta: SaveLoadMeta): + self.custom_function_call("load", meta) + + def step_lr_scheduler(self): + self.custom_function_call("step_lr_scheduler") + + def custom_function_call(self, method: str, *args, **kwargs): + return rpc_call(self.scheduler, self.workers, method, None, args, kwargs) + + def _align_batches_with_dp( + self, input_: DistributedBatch, rebalance=True + ) -> List[DistributedBatch]: + if rebalance: + inputs = input_.chunk_by_ffd(self.group_size, self.alloc_mode.train.dp_size) + else: + inputs = 
input_.chunk(self.alloc_mode.train.dp_size) + + batches = [] + for dp_rank in self.engine_dp_ranks: + batches.append(inputs[dp_rank]) + + return batches + + def train_batch( + self, + input_: DistributedBatch, + loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor], + ) -> List[Dict[str, float]]: + + batches = self._align_batches_with_dp(input_, True) + train_stats = rpc_call( + self.scheduler, + self.workers, + "train_batch", + batches, + loss_fn, + loss_weight_fn, + ) + + return train_stats + + def eval_batch( + self, + input_: DistributedBatch, + loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor], + ) -> List[torch.Tensor]: + + batches = self._align_batches_with_dp(input_, True) + eval_stats = rpc_call( + self.scheduler, self.workers, "eval_batch", batches, loss_fn, loss_weight_fn + ) + + return eval_stats + + def forward( + self, + input_: DistributedBatch, + output_seqlens: List[int] | None = None, + post_hook: Callable[[torch.Tensor, Dict[str, Any]], Any] | None = None, + aggregate_fn: Callable[[List[Any]], Any] = torch.cat, + ) -> List[Any]: + batches = self._align_batches_with_dp(input_, False) + forward_stats = rpc_call( + self.scheduler, + self.workers, + "forward", + batches, + output_seqlens, + post_hook, + aggregate_fn, + ) + + return forward_stats diff --git a/areal/controller/utils.py b/areal/controller/utils.py new file mode 100644 index 000000000..63b427f56 --- /dev/null +++ b/areal/controller/utils.py @@ -0,0 +1,95 @@ +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any, List, Optional + +from requests.exceptions import ConnectionError + +from areal.api.scheduler_api import Scheduler, Worker +from areal.utils import logging +from areal.utils.http import wait_future_ordered + +logger = logging.getLogger("ControllerUtil") + + +def create_engine_with_retry( + 
create_engine_func, max_retries=60, retry_delay=10, *args, **kwargs +): + """ + Attempts to create an engine with retry logic. + :param create_engine_func: Callable function for creating the engine. + :param max_retries: Maximum number of retries before giving up. + :param retry_delay: Seconds to wait between retries. + :param args: Positional arguments to pass to create_engine_func. + :param kwargs: Keyword arguments to pass to create_engine_func. + :return: Engine instance created by create_engine_func. + :raises RuntimeError: If maximum retries are reached and connection still fails. + """ + logger.info( + f"Create engine with retry: {max_retries}, {retry_delay}, {args}, {kwargs}" + ) + retries = 0 + while retries < max_retries: + try: + return create_engine_func(*args, **kwargs) + except ConnectionError as e: + logger.info( + f"Worker is not ready, exception: {e}, retrying in {retry_delay} seconds..." + ) + time.sleep(retry_delay) + retries += 1 + except Exception as e: + logger.error(f"Connection failed: {e}. unknown exception") + raise e + + raise RuntimeError("Failed to connect to remote service after maximum retries.") + + +def rpc_call( + scheduler: Scheduler, + workers: List[Worker], + method: str, + batches: Optional[List[Any]] = None, + *args, + **kwargs, +) -> List[Any]: + """ + Utility method: Perform concurrent RPC calls to multiple workers. + :param scheduler: Scheduler object with a call_engine(worker_id, method, *args, **kwargs) method. + :param workers: List of worker instances. Each worker must have an 'id' attribute. + :param method: Name of the method to invoke on each worker. + :param batches: Optional list of batches, each batch is passed to the corresponding worker. + If provided, its length must match the number of workers. + :param args: Positional arguments to pass to call_engine. + :param kwargs: Keyword arguments to pass to call_engine. + :return: List of results returned in the order of workers. 
+ :raises ValueError: If the batches parameter is provided but its length does not match the number of workers. + :raises RuntimeError: If any exception occurs during RPC execution. + """ + + if batches is not None and len(batches) != len(workers): + raise ValueError( + f"Batches length ({len(batches)}) must match workers count ({len(workers)})" + ) + logger.info(f"Start to rpc call, method: {method}") + + with ThreadPoolExecutor(max_workers=len(workers)) as executor: + futures = [] + for i, worker in enumerate(workers): + # 构建调用参数 + if batches is not None: + # 当有batch参数时:将batch作为第一位置参数 + worker_args = (batches[i],) + args + future = executor.submit( + scheduler.call_engine, worker.id, method, *worker_args, **kwargs + ) + else: + future = executor.submit( + scheduler.call_engine, worker.id, method, *args, **kwargs + ) + futures.append(future) + try: + results = wait_future_ordered(futures, exit_on_exception=True) + except Exception as e: + raise RuntimeError(f"{method} failed, error: {e}") + + return results diff --git a/areal/engine/base_hf_engine.py b/areal/engine/base_hf_engine.py index f6095a18a..0f25bce47 100644 --- a/areal/engine/base_hf_engine.py +++ b/areal/engine/base_hf_engine.py @@ -73,7 +73,7 @@ def __init__(self, config: TrainEngineConfig): ) self.is_vision_model = is_valid_vision_model(self.model_config.model_type) - self.world_size = int(os.environ["WORLD_SIZE"]) + self.world_size: int def set_version(self, version: int): self._version = version diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index e1a1a6b7f..45eed3159 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -121,9 +121,17 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None self.dp_head = int(self.world_mesh["sp_tp"].mesh[0].item()) self.dp_rank = dist.get_rank(self.dp_group) + self.world_size = int(os.environ["WORLD_SIZE"]) + self.logger.info(f"Data parallel head {self.dp_head} and rank {self.dp_rank}") - 
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None): + def initialize( + self, + addr: str | None, + ft_spec: FinetuneSpec | None, + parallel_strategy: ParallelStrategy | None = None, + ): + self.create_process_group(parallel_strategy) # Initialize distributed enviroments and load model. assert addr is None, "FSDPEngine does not support remote initialization." assert ft_spec is not None, "FSDPEngine requires FinetuneSpec to initialize." diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index 8d6329732..d2122a515 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -67,7 +67,7 @@ def calc_logprobs(logits, input_data): aggregate_fn=lambda xs: torch.cat(xs, dim=-1), ) - def compute_advantages(self, data: Dict[str, Any]) -> None: + def compute_advantages(self, data: Dict[str, Any]) -> Dict[str, Any]: bs = data["input_ids"].shape[0] max_seqlen = data["input_ids"].shape[1] batch_indices = torch.arange( @@ -162,6 +162,8 @@ def compute_advantages(self, data: Dict[str, Any]) -> None: # because we have rolled old_logp by -1 data["logprobs"] = old_logp + return data + def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: @@ -286,8 +288,8 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None: return self.actor.compute_logp(*args, **kwargs) @torch.no_grad() - def compute_advantages(self, *args, **kwargs) -> None: - self.actor.compute_advantages(*args, **kwargs) + def compute_advantages(self, *args, **kwargs): + return self.actor.compute_advantages(*args, **kwargs) def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]: return self.actor.ppo_update(*args, **kwargs) diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 9104e78e4..8ac075295 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -7,7 +7,7 @@ from concurrent.futures import Future, ProcessPoolExecutor 
from datetime import datetime from threading import Lock -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import aiohttp import requests @@ -16,7 +16,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from areal.api.cli_args import InferenceEngineConfig -from areal.api.engine_api import InferenceEngine +from areal.api.engine_api import InferenceEngine, Scheduling from areal.api.io_struct import ( ModelRequest, ModelResponse, @@ -45,11 +45,8 @@ def __init__(self, config: InferenceEngineConfig): self.distributed_weight_update_initialized = False self._version = 0 - self.lock = Lock() - self.workflow_executor = WorkflowExecutor( - config=config, - inference_engine=self, - ) + self.lock: Lock + self.workflow_executor: WorkflowExecutor def _wait_for_server(self, address): base_url = f"http://{address}" @@ -74,6 +71,11 @@ def initialize( addr: str | List[str] | None = None, train_data_parallel_size: int | None = None, ): + self.lock = Lock() + self.workflow_executor = WorkflowExecutor( + config=self.config, + inference_engine=self, + ) if engine_id is None: if dist.is_initialized(): engine_id = str(dist.get_rank()) @@ -368,17 +370,28 @@ def callback(fut): def submit( self, - data: Dict[str, Any], + data: Union[Dict[str, Any], List[Dict[str, Any]]], workflow: Optional[RolloutWorkflow] = None, workflow_builder: Optional[Callable] = None, should_accept: Callable | None = None, ) -> None: - return self.workflow_executor.submit( - data, - workflow=workflow, - workflow_builder=workflow_builder, - should_accept=should_accept, - ) + if isinstance(data, Dict): + return self.workflow_executor.submit( + data, + workflow=workflow, + workflow_builder=workflow_builder, + should_accept=should_accept, + ) + else: + for item in data: + self.workflow_executor.submit( + item, + workflow=workflow, + workflow_builder=workflow_builder, + should_accept=should_accept, + ) + return None + def wait(self, count: 
int, timeout: float | None = None) -> Dict[str, Any]: return self.workflow_executor.wait(count, timeout=timeout) @@ -399,7 +412,7 @@ def rollout_batch( def prepare_batch( self, - dataloader: StatefulDataLoader, + dataloader: Union[StatefulDataLoader, List[Dict[str, Any]]], workflow: Optional[RolloutWorkflow] = None, workflow_builder: Optional[Callable] = None, should_accept: Callable | None = None, @@ -419,6 +432,10 @@ def resume(self): """Resume request submission for async rollout.""" return self.workflow_executor.resume() + def get_scheduling_config(self) -> List[Scheduling]: + # Deployed via launcher/sglang_server.py; local_scheduler injects an ENGINE_PORTS environment variable that contains two ports. + raise NotImplementedError() + def update_weights_from_disk( experiment_name, diff --git a/areal/launcher/sglang_server.py b/areal/launcher/sglang_server.py index f88d0549e..6550a5c85 100644 --- a/areal/launcher/sglang_server.py +++ b/areal/launcher/sglang_server.py @@ -171,7 +171,16 @@ def run(self): server_local_idx * ports_per_server + 10000, (server_local_idx + 1) * ports_per_server + 10000, ) - server_port, dist_init_port = find_free_ports(2, port_range) + engine_ports = os.getenv("ENGINE_PORTS", "") + server_port, dist_init_port = 0, 0 + if engine_ports != "": + ports = engine_ports.split(",") + if len(ports) == 2: + server_port = int(ports[0]) + dist_init_port = int(ports[1]) + print(f"SGLang server get ports from env, engine_ports: {engine_ports}") + else: + server_port, dist_init_port = find_free_ports(2, port_range) if cross_nodes: n_nodes = n_nodes_per_server diff --git a/areal/reward/gsm8k_reward.py b/areal/reward/gsm8k_reward.py new file mode 100644 index 000000000..5a32cecd0 --- /dev/null +++ b/areal/reward/gsm8k_reward.py @@ -0,0 +1,5 @@ +from areal.reward.math_parser import process_results + + +def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): + return int(process_results(completions, answer)[0]) diff --git a/areal/scheduler/__init__.py 
b/areal/scheduler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py new file mode 100644 index 000000000..473089b8e --- /dev/null +++ b/areal/scheduler/local.py @@ -0,0 +1,408 @@ +import getpass +import os +import re +import signal as signal_module +import subprocess +import time +import uuid +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import psutil + +from areal.api.alloc_mode import AllocationMode, AllocationType +from areal.api.cli_args import ( + ClusterSpecConfig, + LauncherConfig, + RecoverConfig, + SGLangConfig, + to_structured_cfg, +) +from areal.api.scheduler_api import Scheduler, Worker +from areal.platforms import current_platform +from areal.scheduler.rpc.rpc_client import RPCClient +from areal.scheduler.rpc.rpc_server import build_rpc_server_start_command +from areal.utils import logging, name_resolve, names +from areal.utils.launcher import JobException, JobInfo, JobState, get_env_vars +from areal.utils.network import find_free_ports, gethostip +from areal.utils.recover import check_if_recover + +logger = logging.getLogger("LocalScheduler") +JOB_STATE_TO_PROCESS_STATUS = { + JobState.NOT_FOUND: [], + JobState.PENDING: [psutil.STATUS_PARKED], + JobState.RUNNING: [ + psutil.STATUS_RUNNING, + psutil.STATUS_SLEEPING, + psutil.STATUS_DISK_SLEEP, + psutil.STATUS_TRACING_STOP, + psutil.STATUS_WAKING, + psutil.STATUS_WAITING, + psutil.STATUS_LOCKED, + psutil.STATUS_IDLE, + ], + JobState.COMPLETED: [ + psutil.STATUS_DEAD, + psutil.STATUS_STOPPED, + psutil.STATUS_ZOMBIE, + ], + JobState.FAILED: [], + JobState.CANCELLED: [], +} +RECOVER_TIME_INTERVAL = 10 # seconds + +PROCESS_STATUS_TO_JOB_STATE = {} +for job_state, process_statuses in JOB_STATE_TO_PROCESS_STATUS.items(): + for process_status in process_statuses: + PROCESS_STATUS_TO_JOB_STATE[process_status] = job_state + + +def terminate_process_and_children(pid: int, signal: 
Optional[Union[str, int]] = None): + if signal is None: + signal = signal_module.SIGKILL + if isinstance(signal, str): + signal = getattr(signal_module, signal) + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + for child in children: + terminate_process_and_children(child.pid) + parent.send_signal(signal) + except psutil.NoSuchProcess: + pass + + +class LocalLauncher: + def __init__(self, experiment_name: str, trial_name: str, fileroot: str): + self.experiment_name = experiment_name + self.trial_name = trial_name + self.fileroot = fileroot + + self._jobs: Dict[str, subprocess.Popen] = {} + self._job_counter: Dict[str, int] = defaultdict(int) + self._job_states = {} + + self._gpu_counter = 0 + self._gpu_devices: List[str] = os.environ.get( + current_platform.device_control_env_var, + ",".join(map(str, range(current_platform.device_count()))), + ).split(",") + if len(self._gpu_devices) < 1: + raise RuntimeError( + f"Local mode can only run when there is at least one GPU. " + f"{current_platform.device_control_env_var} is currently" + f" set to: `{os.environ.get(current_platform.device_control_env_var, '')}`." 
+ ) + + @property + def run_name(self): + return f"{self.experiment_name}_{self.trial_name}" + + def log_path_of(self, job_name: str) -> str: + log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}" + os.makedirs(log_path, exist_ok=True) + return os.path.join(log_path, f"{job_name}.log") + + def __del__(self): + self.wait() + + def submit_array( + self, + job_name: str, + cmd: str | List[str], + count: int = 1, + gpu: int = 0, + env_vars: Optional[Dict] = None, + ): + if env_vars is None: + env_vars = {} + if not isinstance(cmd, list): + cmd = [cmd] * count + offset = self._job_counter[job_name] + for i in range(count): + if gpu > 0: + # Allocate GPUs in a round-robin manner + visible_devices = [] + for _ in range(gpu): + available_device_id = self._gpu_counter % len(self._gpu_devices) + self._gpu_counter += 1 + visible_devices.append(available_device_id) + env_vars[current_platform.device_control_env_var] = ",".join( + str(self._gpu_devices[j]) for j in visible_devices + ) + c = ( + " ".join(str(k) + "=" + str(v) for k, v in env_vars.items()) + + " stdbuf -oL " + + cmd[i] + ) + c = f"{c} 2>&1 | tee -a {self.log_path_of(job_name)}" + logger.info("Starting local process with command: %s", c) + process = subprocess.Popen(c, shell=isinstance(c, str)) + self._jobs[f"{job_name}/{offset + i}"] = process + self._job_counter[job_name] += 1 + + def submit( + self, + job_name: str, + cmd: str | List[str], + gpu: int = 0, + env_vars: Optional[Dict] = None, + ): + self.submit_array(job_name=job_name, cmd=cmd, gpu=gpu, env_vars=env_vars) + + def stop(self, job_name, signal=None): + assert any(k.startswith(job_name) for k in self._jobs) + keys = [k for k, p in self._jobs.items() if k.startswith(job_name)] + procs = [p for k, p in self._jobs.items() if k.startswith(job_name)] + logger.info( + f"Stopping local process with signal {signal if signal else 'SIGKILL'}, " + f"pid: {[p.pid for p in procs]}" + ) + for p in procs: + 
terminate_process_and_children(p.pid, signal=signal) + for p in procs: + p.wait() + for k, p in zip(keys, procs): + self._jobs.pop(k) + del p + + def stop_all(self, signal=None): + # signal argument is ignored in local stop_all + for name in self._job_counter: + self.stop(name, signal=signal) + + def find(self, job_name): + if job_name in self._jobs: + return JobInfo(name=job_name, state=JobState.RUNNING, host="localhost") + else: + return JobInfo(name=job_name, state=JobState.NOT_FOUND) + + def find_all(self, job_name_regex=".*"): + rs = [] + for name in self._jobs: + if re.fullmatch(job_name_regex, name): + rs.append(self.find(name)) + return rs + + def wait( + self, + timeout=None, + check_status: Tuple[JobState, ...] = ( + JobState.CANCELLED, + JobState.FAILED, + JobState.NOT_FOUND, + ), + remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,), + update=False, + ): + deadline = None if timeout is None else time.time() + timeout + logger.info( + "Waiting for %d local running processes, pids: %s", + len(self._jobs), + " ".join(str(job.pid) for job in self._jobs.values()), + ) + left = set(self._jobs.keys()) + num_jobs_left = len(left) + + while len(left) > 0: + to_remove = [] + if len(left) < num_jobs_left: + num_jobs_left = len(left) + logger.info(f"Waiting for {num_jobs_left} jobs.") + if deadline is not None and time.time() > deadline: + raise TimeoutError( + f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}" + ) + # update job states + for job_name in list(left): + job = self._jobs[job_name] + pid = job.pid + try: + process = psutil.Process(pid) + self._job_states[job_name] = PROCESS_STATUS_TO_JOB_STATE.get( + process.status(), JobState.NOT_FOUND + ) + except psutil.NoSuchProcess: + self._job_states[job_name] = JobState.NOT_FOUND + + for job_name in list(left): + state = self._job_states[job_name] + if state in check_status: + raise JobException( + run_name=self.run_name, + worker_type=job_name.split("/")[0], + host="local", + 
reason=state, + ) + if state in remove_status: + logger.info(f"Job {job_name} is {state}.(Removed)") + left.remove(job_name) + to_remove.append(job_name) + + if update: + for k in to_remove: + self._jobs.pop(k) + worker_type = k.split("/")[0] + assert worker_type in self._job_counter + self._job_counter[worker_type] -= 1 + if self._job_counter[worker_type] <= 0: + self._job_counter.pop(worker_type) + + time.sleep(2) + + +class LocalScheduler(Scheduler): + def __init__(self, config): + self.procs = [] # Store subprocess objects + self.engine_workers: Dict[str, List[str]] = defaultdict( + list + ) # role -> [worker_id] + self.rpc_client = RPCClient() + self.launcher = LocalLauncher( + config.experiment_name, config.trial_name, config.cluster.fileroot + ) + + def create_workers(self, worker_role, config, *args, **kwargs) -> None: + config.launcher = to_structured_cfg(config.launcher, LauncherConfig) + config.recover = to_structured_cfg(config.recover, RecoverConfig) + config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig) + is_recover_run = check_if_recover(config.recover, run_id=0) + + name_resolve.reconfigure(config.cluster.name_resolve) + name_resolve.clear_subtree( + names.trial_root( + experiment_name=config.experiment_name, trial_name=config.trial_name + ) + ) + alloc_mode = AllocationMode.from_str(config.allocation_mode) + logger.info( + f"experiment_name={config.experiment_name}, " + f"trial_name={config.trial_name}, fileroot={config.cluster.fileroot}, " + f"is_recover_run={is_recover_run}" + ) + + server_cmd = [] + server_addrs = [] + if worker_role == "rollout": + if alloc_mode.gen_backend == "sglang": + # launch sglang servers + base_seed = config.sglang.random_seed + config.sglang = to_structured_cfg(config.sglang, SGLangConfig) + # each sglang need 2 ports + ports = find_free_ports( + alloc_mode.gen.dp_size * 2, port_range=(10000, 50000) + ) + host_ip = gethostip() + host = "localhost" if not config.sglang.enable_metrics else host_ip + for 
i in range(alloc_mode.gen.dp_size): + config.sglang.random_seed = base_seed + i + cmd = SGLangConfig.build_cmd( + config.sglang, + host=host, + tp_size=alloc_mode.gen.tp_size, + base_gpu_id=0, + port=ports[i * 2], + dist_init_addr=f"localhost:{ports[i*2+1]}", + ) + server_cmd.append(cmd) + server_addrs.append(f"{host}:{ports[i * 2]}") + + # Launch inference servers. + self.launcher.submit_array( + job_name="llm_server", + cmd=server_cmd, + count=alloc_mode.gen.dp_size, + gpu=alloc_mode.gen.pp_size * alloc_mode.gen.tp_size, + env_vars=get_env_vars( + config.cluster.cluster_name, + config.launcher.inference_server_env_vars, + ), + ) + logger.info( + f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}" + ) + + # create rpc server workers + worker_ports = find_free_ports( + alloc_mode.gen.world_size, port_range=(10000, 50000) + ) # each sglang need 2 ports + for i in range(alloc_mode.gen.world_size): + cmd = build_rpc_server_start_command(worker_ports[i]) + + self.launcher.submit( + job_name="rollout_worker", + cmd=cmd, + gpu=0, + env_vars=dict( + **get_env_vars( + config.cluster.cluster_name, + # config.launcher.worker_env_vars, + ), + AREAL_LLM_SERVER_ADDRS=server_addrs[ + i % alloc_mode.gen.dp_size + ], + AREAL_RECOVER_RUN=str(int(is_recover_run)), + ), + ) + + logger.info( + f"RPC server for rollout worker launched at port: {worker_ports[i]}" + ) + + worker_id = f"rollout_{i}_{uuid.uuid4().hex[:8]}" + self.rpc_client.register(worker_id, "localhost", worker_ports[i]) + self.engine_workers.setdefault(worker_role, []).append(worker_id) + + else: + raise NotImplementedError(f"Unsupported allocation mode: {alloc_mode}") + elif worker_role == "actor": + if alloc_mode.type_ == AllocationType.DECOUPLED_EVAL: + gpu = 0 + nprocs = 1 + else: + gpu = nprocs = alloc_mode.train.world_size + + worker_ports = find_free_ports(alloc_mode.gen.world_size, (10000, 50000)) + + self.launcher.submit( + job_name="trainer", + cmd=f"torchrun --nnodes 1 
--nproc-per-node {nprocs} " + f"--master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} " + f"-m areal.scheduler.rpc.rpc_server --rpc_ports {','.join(map(str, worker_ports))}", + gpu=gpu, + env_vars=dict( + **get_env_vars( + config.cluster.cluster_name, + config.launcher.trainer_env_vars, + ), + # AREAL_LLM_SERVER_ADDRS=",".join(server_addrs), # not need? + AREAL_RECOVER_RUN=str(int(is_recover_run)), + ), + ) + + for i in range(alloc_mode.gen.world_size): + worker_id = f"actor_{i}_{uuid.uuid4().hex[:8]}" + self.rpc_client.register(worker_id, "localhost", worker_ports[i]) + self.engine_workers.setdefault(worker_role, []).append(worker_id) + else: + raise ValueError(f"Unknown worker role: {worker_role}") + + def get_workers(self, worker_role, timeout: float = 60.0) -> List[Worker]: + workers = [] + for worker_id in self.engine_workers.get(worker_role, []): + ip, port = self.rpc_client.get_info(worker_id) + worker = Worker(id=worker_id, ip=ip, ports=[str(port)]) + workers.append(worker) + return workers + + def delete_workers(self): + raise NotImplementedError("LocalScheduler does not support delete_workers") + + # Other methods remain the same + def create_engine(self, worker_id, engine_obj, *args, **kwargs): + # launch engine rpc server on the worker + self.rpc_client.create_engine(worker_id, engine_obj, *args, **kwargs) + + def call_engine(self, worker_id, method, *args, **kwargs): + ret = self.rpc_client.call_engine(worker_id, method, 3, *args, **kwargs) + return ret diff --git a/areal/scheduler/rpc/rpc_client.py b/areal/scheduler/rpc/rpc_client.py index 28f4b8082..b25c6b120 100644 --- a/areal/scheduler/rpc/rpc_client.py +++ b/areal/scheduler/rpc/rpc_client.py @@ -6,7 +6,6 @@ import cloudpickle import requests -from areal.api.cli_args import InferenceEngineConfig, TrainEngineConfig from areal.api.engine_api import InferenceEngine, TrainEngine from areal.utils import logging from areal.utils.http import response_ok, response_retryable @@ 
-22,16 +21,20 @@ def register(self, worker_id: str, ip: str, port: int) -> None: self._addrs[worker_id] = (ip, port) logger.info(f"Registered worker {worker_id} at {ip}:{port}") + def get_info(self, worker_id: str) -> tuple[str, int]: + return self._addrs[worker_id] + def create_engine( self, worker_id: str, engine_obj: Union[InferenceEngine, TrainEngine], - init_config: Union[InferenceEngineConfig, TrainEngineConfig], + *args, + **kwargs, ) -> None: ip, port = self._addrs[worker_id] url = f"http://{ip}:{port}/create_engine" logger.info(f"send create_engine to {worker_id} ({ip}:{port})") - payload = (engine_obj, init_config) + payload = (engine_obj, args, kwargs) serialized_data = cloudpickle.dumps(payload) serialized_obj = gzip.compress(serialized_data) resp = requests.post(url, data=serialized_obj) @@ -48,7 +51,7 @@ def create_engine( ) def call_engine( - self, worker_id: str, method: str, max_retries: int = 3, *args, **kwargs + self, worker_id: str, method: str, max_retries: int, *args, **kwargs ) -> Any: """ call the rpc server with method name and args, retry on failure diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index b2bc3d612..4ac8ef0f7 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -3,20 +3,46 @@ import os import traceback from http import HTTPStatus -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from typing import AnyStr +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, AnyStr, Dict, List import cloudpickle +import torch from tensordict import TensorDict from areal.api.controller_api import DistributedBatch +from areal.api.engine_api import InferenceEngine from areal.controller.batch import DistributedBatchMemory from areal.utils import logging logger = logging.getLogger("RPCServer") -def process_input_to_distributed_batch(*args, **kwargs): +def tensor_container_to_safe( + d: Dict[str, Any] | torch.Tensor | 
List[torch.Tensor], *args, **kwargs +): + """Apply `t.to(*args, **kwargs)` to all tensors in the dictionary. + Support nested dictionaries. + """ + new_dict = {} + if torch.is_tensor(d): + return d.to(*args, **kwargs) + elif isinstance(d, list): + return [tensor_container_to_safe(v, *args, **kwargs) for v in d] + elif isinstance(d, dict): + for key, value in d.items(): + if isinstance(value, dict) or isinstance(value, list): + new_dict[key] = tensor_container_to_safe(value, *args, **kwargs) + elif torch.is_tensor(value): + new_dict[key] = value.to(*args, **kwargs) + else: + new_dict[key] = value + return new_dict + else: + return d + + +def process_input_to_distributed_batch(to_device, *args, **kwargs): for i in range(len(args)): if isinstance(args[i], DistributedBatch): args = list(args) @@ -27,10 +53,14 @@ def process_input_to_distributed_batch(*args, **kwargs): if isinstance(kwargs[k], DistributedBatch): kwargs[k] = kwargs[k].get_data() + args = tuple(tensor_container_to_safe(list(args), to_device)) + kwargs = tensor_container_to_safe(kwargs, to_device) + return args, kwargs def process_output_to_distributed_batch(result): + result = tensor_container_to_safe(result, "cpu") if isinstance(result, dict): return DistributedBatchMemory.from_dict(result) elif isinstance(result, TensorDict): @@ -76,9 +106,9 @@ def do_POST(self): try: if self.path == "/create_engine": decompressed_data = gzip.decompress(data) - engine_obj, init_args = cloudpickle.loads(decompressed_data) + engine_obj, args, kwargs = cloudpickle.loads(decompressed_data) EngineRPCServer.engine = engine_obj - result = EngineRPCServer.engine.initialize(init_args) + result = EngineRPCServer.engine.initialize(*args, **kwargs) logger.info(f"Engine created and initialized on RPC server: {result}") self.send_response(HTTPStatus.OK) self.end_headers() @@ -93,8 +123,14 @@ def do_POST(self): action, args, kwargs = cloudpickle.loads(data) method = getattr(EngineRPCServer.engine, action) # NOTE: DO NOT print args 
here, args may be a very huge tensor - logger.info(f"RPC server calling engine method: {action}") - args, kwargs = process_input_to_distributed_batch(*args, **kwargs) + if isinstance(EngineRPCServer.engine, InferenceEngine): + device = "cpu" + else: # actor + device = EngineRPCServer.engine.device + + args, kwargs = process_input_to_distributed_batch( + device, *args, **kwargs + ) result = method(*args, **kwargs) result = process_output_to_distributed_batch(result) self.send_response(HTTPStatus.OK) @@ -113,36 +149,36 @@ def do_POST(self): def start_rpc_server(port): - server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer) + # NOTE: We must use HTTPServer rather than ThreadingHTTPServer here, since the rank and device info + # of pytorch is thread level, if use ThreadingHTTPServer, the device set by create_engine thread + # will not be seen by call_engine thread. + # server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer) + server = HTTPServer(("0.0.0.0", port), EngineRPCServer) server.serve_forever() -def get_serve_port(args): - port = args.port - port_str = os.environ.get("PORT_LIST", "").strip() - - # Check if PORT_LIST is set - if port_str: - # Split by comma and strip whitespace - ports = [p.strip() for p in port_str.split(",")] - # Use the first valid port from the list - if ports and ports[0]: - try: - return int(ports[0]) - except ValueError: - logger.warning( - f"Invalid port '{ports[0]}' in PORT_LIST. Falling back to --port argument." 
- ) - return port +def get_server_ports(ports_str: str) -> int: + ports = [p.strip() for p in ports_str.split(",")] + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", "0")) + if len(ports) < world_size: + raise ValueError( + f"Not enough ports for the world size {world_size}, got {ports_str}" + ) + return int(ports[rank]) + + +def build_rpc_server_start_command(port): + return f"python3 -m areal.scheduler.rpc.rpc_server --rpc_ports {port}" if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, required=False) + parser.add_argument("--rpc_ports", type=str, required=True) args, unknown = parser.parse_known_args() - port = get_serve_port(args) + port = get_server_ports(args.rpc_ports) logger.info(f"About to start RPC server on {port}") diff --git a/areal/scheduler/test_local.py b/areal/scheduler/test_local.py new file mode 100644 index 000000000..564038bc6 --- /dev/null +++ b/areal/scheduler/test_local.py @@ -0,0 +1,226 @@ +import os +import sys +import time +from concurrent.futures import ThreadPoolExecutor + +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import ( + GRPOConfig, + parse_cli_args, + to_structured_cfg, +) +from areal.api.io_struct import FinetuneSpec +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.scheduler.local import LocalScheduler +from areal.utils import name_resolve +from areal.utils.data import ( + cycle_dataloader, +) +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.stats_logger import StatsLogger +from areal.workflow.rlvr import RLVRWorkflow + +# init_config = {} + +create_workers_config, _ = parse_cli_args(sys.argv[1:]) + +from omegaconf import OmegaConf + +# config, _ = load_expr_config(sys.argv[1:]) +config = 
to_structured_cfg(create_workers_config, config_cls=GRPOConfig) +config = OmegaConf.to_object(config) +name_resolve.reconfigure(config.cluster.name_resolve) +config: GRPOConfig +# seeding.set_random_seed(config.seed, key=f"trainer{rank}") +allocation_mode = AllocationMode.from_str(config.allocation_mode) +parallel_strategy = allocation_mode.train + + +shcheduler = LocalScheduler(create_workers_config) +shcheduler.create_workers("rollout", create_workers_config) +shcheduler.create_workers("actor", create_workers_config) + +rollout_workers = shcheduler.get_workers("rollout", timeout=300) +actor_workers = shcheduler.get_workers("actor", timeout=300) + +print("[wht debug] rollout workers:", rollout_workers) +print("[wht debug] actor workers:", actor_workers) + +time.sleep(20) + + +rollout = RemoteSGLangEngine(config.rollout) +with ThreadPoolExecutor(max_workers=len(rollout_workers)) as executor: + + def create_engine_and_init(worker_id): + print(f"[wht debug] start create rollout engine and init {worker_id}") + shcheduler.create_engine( + worker_id, rollout, train_data_parallel_size=parallel_strategy.dp_size + ) + print(f"[wht debug] end create rollout engine and init {worker_id}") + + futures = [] + for i in range(len(rollout_workers)): + futures.append(executor.submit(create_engine_and_init, rollout_workers[i].id)) + + for future in futures: + future.result() + +ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=1024, # dummy value + train_batch_size=config.train_dataset.batch_size, +) + +actor = FSDPPPOActor(config=config.actor) +with ThreadPoolExecutor(max_workers=len(actor_workers)) as executor: + + def create_engine_and_init(worker_id): + print(f"[wht debug] start create actor engine and init {worker_id}") + shcheduler.create_engine( + worker_id, actor, None, ft_spec, parallel_strategy=parallel_strategy + ) + print(f"[wht debug] end create actor engine and init {worker_id}") + + futures = [] + for i in 
range(len(actor_workers)): + futures.append(executor.submit(create_engine_and_init, actor_workers[i].id)) + + for future in futures: + future.result() + +print("[wht debug] all engines created and initialized.") + + +tokenizer = load_hf_tokenizer(config.tokenizer_path) +train_dataset = get_custom_dataset( + path=config.train_dataset.path, + rank=0, + world_size=1, + split="train", + max_length=config.train_dataset.max_length, + type=config.train_dataset.type, + tokenizer=tokenizer, +) +train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size, + shuffle=config.train_dataset.shuffle, + num_workers=config.train_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.train_dataset.drop_last, +) +data_generator = cycle_dataloader(train_dataloader) +data = next(data_generator) + +print(f"[wht debug] get data batch: {data[0]}") + +from areal.reward.gsm8k_reward import gsm8k_reward_fn + +workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig, + tokenizer=tokenizer, + enable_thinking=False, + dump_dir=os.path.join(StatsLogger.get_log_path(config.stats_logger), "generated"), +) + +batch = None +with ThreadPoolExecutor(max_workers=len(rollout_workers)) as executor: + + def call_rollout(worker_id, data): + try: + batch = shcheduler.call_engine( + worker_id, + "rollout_batch", + data, + workflow=workflow, + should_accept=lambda sample: True, + ) + print(f"[wht debug] rollout {worker_id} done, got batch: {batch}") + return batch + except Exception as e: + print(f"[wht debug] rollout {worker_id} failed, error: {e}") + raise e + + futures = [] + for i in range(len(rollout_workers)): + futures.append(executor.submit(call_rollout, rollout_workers[i].id, data)) + for future in futures: + r = future.result() + print(f"[wht debug] rollout result: {r}") + batch = r + +print("[wht debug] all rollout done.") + +assert not config.actor.use_decoupled_loss and not config.actor.recompute_logprob + +with 
ThreadPoolExecutor(max_workers=len(actor_workers)) as executor: + + def call_compute_advantages(worker_id, data): + try: + batch = shcheduler.call_engine(worker_id, "compute_advantages", data) + print( + f"[wht debug] compute_advantages {worker_id} done, got batch: {batch}" + ) + return batch + except Exception as e: + print(f"[wht debug] compute_advantages {worker_id} failed, error: {e}") + raise e + + futures = [] + for i in range(len(actor_workers)): + futures.append( + executor.submit(call_compute_advantages, actor_workers[i].id, batch) + ) + for future in futures: + r = future.result() + print(f"[wht debug] compute_advantages result: {r}") + batch = r + +print("[wht debug] all compute_advantages done.") + +with ThreadPoolExecutor(max_workers=len(actor_workers)) as executor: + + def call_ppo_update(worker_id, data): + try: + batch = shcheduler.call_engine(worker_id, "ppo_update", data) + print(f"[wht debug] ppo_update {worker_id} done, got batch: {batch}") + return batch + except Exception as e: + print(f"[wht debug] ppo_update {worker_id} failed, error: {e}") + raise e + + futures = [] + for i in range(len(actor_workers)): + futures.append(executor.submit(call_ppo_update, actor_workers[i].id, batch)) + + for future in futures: + r = future.result() + print(f"[wht debug] ppo_update result: {r}") + +print("[wht debug] all ppo_update done.") + +with ThreadPoolExecutor(max_workers=len(actor_workers)) as executor: + + def call_step_lr_scheduler(worker_id): + try: + res = shcheduler.call_engine(worker_id, "step_lr_scheduler") + print(f"[wht debug] step_lr_scheduler {worker_id} done, got res: {res}") + return res + except Exception as e: + print(f"[wht debug] step_lr_scheduler {worker_id} failed, error: {e}") + raise e + + futures = [] + for i in range(len(actor_workers)): + futures.append(executor.submit(call_step_lr_scheduler, actor_workers[i].id)) + for future in futures: + r = future.result() + print(f"[wht debug] step_lr_scheduler result: {r}") + +print("[wht 
debug] all step_lr_scheduler done.") diff --git a/areal/tests/test_rpc.py b/areal/tests/test_rpc.py index 2f5ab493a..58590a407 100644 --- a/areal/tests/test_rpc.py +++ b/areal/tests/test_rpc.py @@ -16,7 +16,7 @@ from areal.scheduler.rpc.rpc_client import RPCClient from areal.scheduler.rpc.rpc_server import ( EngineRPCServer, - get_serve_port, + get_server_ports, process_input_to_distributed_batch, process_output_to_distributed_batch, start_rpc_server, @@ -175,61 +175,53 @@ def test_process_output_to_distributed_batch_other_types(): def test_get_serve_port_from_args(): """Test getting port from command line arguments""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = "8080" with patch.dict("os.environ", {}, clear=True): - port = get_serve_port(mock_args) + port = get_server_ports(mock_args.rpc_port) assert port == 8080 -def test_get_serve_port_from_env_single_port(): +def test_get_server_ports_default_from_multi_ports(): """Test getting single port from PORT_LIST environment variable""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = "8080,8081,8082,8083" - with patch.dict("os.environ", {"PORT_LIST": "9000"}): - port = get_serve_port(mock_args) - assert port == 9000 - - -def test_get_serve_port_from_env_multiple_ports(): - """Test getting first port from multiple ports in PORT_LIST environment variable""" - mock_args = Mock() - mock_args.port = 8080 - - with patch.dict("os.environ", {"PORT_LIST": "9000, 9001, 9002"}): - port = get_serve_port(mock_args) - assert port == 9000 + with patch.dict("os.environ", {}, clear=True): + port = get_server_ports(mock_args.rpc_port) + assert port == 8080 -def test_get_serve_port_invalid_env_port(): - """Test fallback when PORT_LIST contains invalid ports""" +def test_get_serve_port_from_multi_ports(): + """Test getting single port from PORT_LIST environment variable""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = "8080,8081,8082,8083" - with patch.dict("os.environ", 
{"PORT_LIST": "invalid_port, 9001"}): - port = get_serve_port(mock_args) + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "0"}): + port = get_server_ports(mock_args.rpc_port) assert port == 8080 + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "1"}): + port = get_server_ports(mock_args.rpc_port) + assert port == 8081 -def test_get_serve_port_empty_env(): - """Test fallback when PORT_LIST is empty""" - mock_args = Mock() - mock_args.port = 8080 + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "2"}): + port = get_server_ports(mock_args.rpc_port) + assert port == 8082 - with patch.dict("os.environ", {"PORT_LIST": ""}): - port = get_serve_port(mock_args) - assert port == 8080 + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "3"}): + port = get_server_ports(mock_args.rpc_port) + assert port == 8083 -def test_get_serve_port_whitespace_env(): - """Test fallback when PORT_LIST contains only whitespace""" +def test_get_serve_port_not_enough_ports(): + """Test error when not enough ports for WORLD_SIZE""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = "8080,8081" - with patch.dict("os.environ", {"PORT_LIST": " "}): - port = get_serve_port(mock_args) - assert port == 8080 + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "0"}): + with pytest.raises(ValueError, match="Not enough ports for the world size"): + get_server_ports(mock_args.rpc_port) # RPC client and server integration tests diff --git a/areal/utils/http.py b/areal/utils/http.py index 5dc88f1df..140e7e474 100644 --- a/areal/utils/http.py +++ b/areal/utils/http.py @@ -1,6 +1,10 @@ import asyncio +import os +import signal +import traceback +from concurrent.futures import Future, as_completed from http import HTTPStatus -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import aiohttp @@ -96,3 +100,31 @@ def response_ok(http_code: int) -> bool: def response_retryable(http_code: int) -> bool: return http_code == 
HTTPStatus.REQUEST_TIMEOUT + + +def wait_future_ordered( + futures: List[Future], exit_on_exception: bool = False +) -> List[Any]: + """ + Waits for a list of futures to complete and returns the results in the order the futures were submitted. + :param futures: List of Future objects to wait for. + :param exit_on_exception: If True, terminate the process upon an exception in any future. + If False, raise the exception. + :return: List of results in the same order as the input futures. + :raises Exception: If exit_on_exception is False and any future raises an exception. + """ + results = [None] * len(futures) + future_index_map = {future: i for i, future in enumerate(futures)} + for future in as_completed(futures): + index = future_index_map[future] + try: + results[index] = future.result() + except Exception as e: + logger.warning(f"Exception caught when waiting for future: {e}") + logger.warning(traceback.format_exc()) + if exit_on_exception: + logger.info("Exiting due to exception in future.") + os.kill(os.getpid(), signal.SIGTERM) + else: + raise e + return results diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index 34a6994ae..de3b26ff8 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -52,13 +52,15 @@ def __init__( self.enable_thinking = enable_thinking self.dump_dir = dump_dir self.rollout_stat_scope = rollout_stat_scope - self.async_reward_fn = AsyncRewardWrapper(reward_fn) + self.async_reward_fn = None self.get_input_ids_fn = get_input_ids_fn self.data_extract_prompt_fn = data_extract_prompt_fn if self.dump_dir is not None and not os.path.exists(self.dump_dir): os.makedirs(self.dump_dir, exist_ok=True) async def arun_episode(self, engine: InferenceEngine, data): + if self.async_reward_fn is None: + self.async_reward_fn = AsyncRewardWrapper(self.reward_fn) input_ids = self.get_input_ids_fn( self.data_extract_prompt_fn(data), self.tokenizer, self.enable_thinking ) diff --git a/examples/math/gsm8k_grpo_single_controller.yaml 
b/examples/math/gsm8k_grpo_single_controller.yaml new file mode 100644 index 000000000..6cf4387c0 --- /dev/null +++ b/examples/math/gsm8k_grpo_single_controller.yaml @@ -0,0 +1,153 @@ +experiment_name: gsm8k-grpo +trial_name: trial0 + +seed: 1 +total_train_epochs: 10 +tokenizer_path: ${actor.path} +async_training: true + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: sglang.d4p1t1+d4p1t1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-1.5B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + backend: fsdp + group_size: ${gconfig.n_samples} + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: false + use_decoupled_loss: false + behav_imp_weight_cap: 5.0 + dynamic_sampling: false + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: 
null + backend: fsdp + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +# datasets +train_dataset: + batch_size: 8 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +launcher: + inference_server_cpus_per_gpu: 4 + inference_server_mem_per_gpu: 32768 + trainer_cpus_per_gpu: 4 + trainer_mem_per_gpu: 32768