From edf7b5db2cc0a0daa39e6509bbcea8272f3b737f Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 22 Oct 2025 21:47:26 +0800 Subject: [PATCH 01/52] . --- .pre-commit-config.yaml | 6 +- areal/core/__init__.py | 9 + areal/core/local_inf_engine.py | 598 ++++++++++++++++++++++++ areal/engine/sglang_local.py | 321 +++++++++++++ areal/engine/vllm_local.py | 342 ++++++++++++++ areal/tests/test_local_sglang_engine.py | 174 +++++++ areal/tests/test_local_vllm_engine.py | 88 ++++ pyproject.toml | 6 + 8 files changed, 1541 insertions(+), 3 deletions(-) create mode 100644 areal/core/local_inf_engine.py create mode 100644 areal/engine/sglang_local.py create mode 100644 areal/engine/vllm_local.py create mode 100644 areal/tests/test_local_sglang_engine.py create mode 100644 areal/tests/test_local_vllm_engine.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 56ab76d9f..9c0712655 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,13 +39,13 @@ repos: # Ruff version. rev: v0.14.1 hooks: + - id: ruff-format # Run the formatter. + name: Run Formatter (Ruff) + types_or: [ python, pyi, jupyter ] - id: ruff # Run the linter. name: Run Linter Check (Ruff) types_or: [ python, pyi, jupyter ] args: [ --fix ] - - id: ruff-format # Run the formatter. - name: Run Formatter (Ruff) - types_or: [ python, pyi, jupyter ] # Clean notebook outputs and metadata - repo: https://github.com/kynan/nbstripout diff --git a/areal/core/__init__.py b/areal/core/__init__.py index 28fa01204..741b4d687 100644 --- a/areal/core/__init__.py +++ b/areal/core/__init__.py @@ -1,5 +1,11 @@ """Core components for AREAL.""" +from __future__ import annotations + +from .local_inf_engine import ( + LocalInfBackendProtocol, + LocalInfEngine, +) from .remote_inf_engine import ( RemoteInfBackendProtocol, RemoteInfEngine, @@ -10,7 +16,10 @@ check_trajectory_format, ) + __all__ = [ + "LocalInfBackendProtocol", + "LocalInfEngine", "RemoteInfBackendProtocol", "RemoteInfEngine", "StalenessManager", diff --git a/areal/core/local_inf_engine.py b/areal/core/local_inf_engine.py new file mode 100644 index 000000000..519d74cba --- /dev/null +++ b/areal/core/local_inf_engine.py @@ -0,0 +1,598 @@ +from __future__ import annotations + +import asyncio +import time +import uuid +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor +from threading import Lock +from typing import Any, Protocol + +import torch.distributed as dist +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.cli_args import InferenceEngineConfig +from areal.api.io_struct import ( + ModelRequest, + ModelResponse, + ParamSpec, + WeightUpdateMeta, +) +from areal.api.workflow_api import RolloutWorkflow +from areal.platforms import current_platform +from areal.utils import logging, name_resolve, names + +from .workflow_executor import WorkflowExecutor + + +class LocalInfBackendProtocol(Protocol): + """Protocol defining backend-specific operations for local inference engines. + + This protocol abstracts the differences between various local inference engines + (SGLang, vLLM, etc.) by defining a common interface for: + - Creating and managing local engine instances + - Performing async generation + - Handling weight updates (both disk and distributed) + - Managing engine lifecycle + + Implementations can raise NotImplementedError for unsupported features. + """ + + def create_engine(self, engine_args: dict[str, Any]) -> Any: + """Create a local inference engine instance. + + Parameters + ---------- + engine_args : Dict[str, Any] + Arguments to pass to the engine constructor + + Returns + ------- + Any + The created engine instance + """ + ... + + async def async_generation(self, engine: Any, req: ModelRequest) -> ModelResponse: + """Perform async generation using the local engine. + + Parameters + ---------- + engine : Any + The engine instance + req : ModelRequest + The generation request containing input and parameters + + Returns + ------- + ModelResponse + The generated response with tokens, logprobs, and metadata + """ + ... + + def update_weight_disk(self, engine: Any, model_path: str) -> None: + """Update weights from disk synchronously. + + Parameters + ---------- + engine : Any + The engine instance + model_path : str + Path to the model weights on disk + """ + ... + + def update_weight_xccl( + self, + engine: Any, + meta: WeightUpdateMeta, + param_specs: list[ParamSpec], + ) -> None: + """Update weights from distributed memory via NCCL/XCCL synchronously. + + Parameters + ---------- + engine : Any + The engine instance + meta : WeightUpdateMeta + Metadata containing communication group info + param_specs : List[ParamSpec] + Specifications for parameters to be updated + """ + ... + + def init_update_weight_group( + self, engine: Any, meta: WeightUpdateMeta, rank_offset: int + ) -> None: + """Initialize weight update communication group synchronously. + + Parameters + ---------- + engine : Any + The engine instance + meta : WeightUpdateMeta + Metadata containing communication backend configuration + rank_offset : int + Rank offset for this engine in the communication group + """ + ... + + def destroy(self, engine: Any) -> None: + """Destroy the engine and release resources. + + Parameters + ---------- + engine : Any + The engine instance to destroy + """ + ... + + +class LocalInfEngine: + """ + Base implementation for local in-process inference engines. + + This class provides common functionality for running inference engines + within the same process. Backend-specific behaviors are delegated to + an injected LocalInfBackendProtocol implementation. + + Uses composition pattern - instantiate directly with a backend rather + than inheriting from this class. + + Parameters + ---------- + config : InferenceEngineConfig + Configuration for the inference engine + backend : LocalInfBackendProtocol + Backend implementation providing engine-specific behavior + """ + + def __init__(self, config: InferenceEngineConfig, backend: LocalInfBackendProtocol): + self.config = config + self.backend = backend + + self.engine = None + self.distributed_weight_update_initialized = False + self._version = 0 + + self.lock = Lock() + + self.workflow_executor: WorkflowExecutor + + def initialize( + self, + engine_id: str | None = None, + engine_args: dict[str, Any] | None = None, + train_data_parallel_size: int | None = None, + ): + """Initialize the engine by creating the local inference engine. + + Parameters + ---------- + engine_id : Optional[str] + Unique identifier for this engine instance + engine_args : Optional[Dict[str, Any]] + Arguments to pass to the backend engine constructor + train_data_parallel_size : int | None + Data parallel size of the training engine + """ + if engine_id is None: + if dist.is_initialized(): + engine_id = str(dist.get_rank()) + else: + engine_id = uuid.uuid4().hex + self.engine_id = engine_id + self.logger = logging.getLogger(f"[Local Inference Engine Rank {engine_id}]") + + # Create the local engine via backend + engine_args = engine_args or {} + self.logger.info(f"Creating local inference engine with args: {engine_args}") + self.engine = self.backend.create_engine(engine_args) + self.logger.info("Local inference engine created successfully!") + + # Initialize thread pool for non-blocking weight updates + self.executor = ThreadPoolExecutor(max_workers=1) + + # Initialize workflow executor + self.workflow_executor = WorkflowExecutor( + config=self.config, + inference_engine=self, + ) + self.workflow_executor.initialize( + logger=self.logger, train_data_parallel_size=train_data_parallel_size + ) + + def destroy(self): + """Destroy the engine and clean up resources.""" + self.workflow_executor.destroy() + if self.engine is not None: + self.backend.destroy(self.engine) + self.engine = None + self.executor.shutdown() + + def set_version(self, version: int): + """Set the current weight version.""" + with self.lock: + self._version = version + + def get_version(self) -> int: + """Get the current weight version.""" + with self.lock: + return self._version + + async def agenerate(self, req: ModelRequest) -> ModelResponse: + """Asynchronously generate a response for the given request. + + Parameters + ---------- + req : ModelRequest + The model request containing input data and generation parameters + + Returns + ------- + ModelResponse + The generated response from the model + """ + if self.engine is None: + raise RuntimeError( + "Local inference engine is not initialized, cannot generate." + ) + + # Create a shallow copy of the input request + # we are going to modify it in-place + req = req.copy() + + # Validate n_samples + gconfig = req.gconfig + if gconfig.n_samples != 1: + raise ValueError( + "Local inference engines do not support n_samples > 1. " + "Please call generate multiple times with n_samples = 1." + ) + + # Validate max_new_tokens + max_new_tokens = min( + gconfig.max_tokens - len(req.input_ids), gconfig.max_new_tokens + ) + if max_new_tokens <= 0: + raise RuntimeError( + f"max_new_tokens ({max_new_tokens}) is non-positive! " + f"max_tokens={gconfig.max_tokens}, prompt_len={len(req.input_ids)}, " + f"max_new_tokens={gconfig.max_new_tokens}." + ) + + # Update max_new_tokens in request + req.gconfig.max_new_tokens = max_new_tokens + + # Make request + start_time = time.perf_counter() + accumulated_output_tokens = [] + accumulated_output_logprobs = [] + accumulated_versions = [] + + # Loop until generation is complete + stop_reason = None + while ( + stop_reason not in ["stop", "tool_calls", "length"] + and len(accumulated_output_tokens) < gconfig.max_new_tokens + ): + # Handle rollout interruption + while self.workflow_executor.paused.is_set(): + await asyncio.sleep(0.5) + + # Call backend async_generation + response = await self.backend.async_generation(self.engine, req) + + # Extract result + output_tokens = response.output_tokens + output_logprobs = response.output_logprobs + stop_reason = response.stop_reason + + # Update accumulated outputs + accumulated_output_tokens.extend(output_tokens) + accumulated_output_logprobs.extend(output_logprobs) + accumulated_versions.extend([self.get_version()] * len(output_tokens)) + + # Update request for next iteration + req.input_ids += output_tokens + req.gconfig.max_new_tokens -= len(output_tokens) + assert req.gconfig.max_new_tokens >= 0, ( + req.gconfig.max_new_tokens, + len(output_tokens), + len(req.input_ids), + ) + + # Final abort handling + if stop_reason == "abort": + # If stop_reason is "abort", the only reason we exit the loop is + # len(accumulated_output_tokens) >= gconfig.max_new_tokens + # so the actual reason is length + stop_reason = "length" + + latency = time.perf_counter() - start_time + + response = ModelResponse( + input_tokens=req.input_ids[ + : len(req.input_ids) - len(accumulated_output_tokens) + ], + input_images=req.image_data, + output_tokens=accumulated_output_tokens, + output_logprobs=accumulated_output_logprobs, + output_versions=accumulated_versions, + stop_reason=stop_reason, + latency=latency, + ttft=latency, # Simplified for non-streaming + tokenizer=req.tokenizer, + processor=req.processor, + ) + return response + + def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: + """Initialize the weight update process group for distributed weight updates. + + Parameters + ---------- + meta : WeightUpdateMeta + Metadata containing information about the weight update + + Returns + ------- + Future[None] + A future object representing the asynchronous initialization operation + """ + assert meta.type == current_platform.communication_backend + assert not self.distributed_weight_update_initialized, ( + "Weight update group already initialized." + ) + + if self.engine is None: + raise RuntimeError( + "Local inference engine is not initialized, " + "cannot init weight update group." + ) + + # Compute rank offset for this engine + # For local engines, we assume single instance per process + rank_offset = 1 # Offset by 1 to leave rank 0 for the training engine + + fut = self.executor.submit( + self._init_weights_update_group_sync, meta, rank_offset + ) + + def callback(fut): + self.logger.info( + f"Initialized {current_platform.communication_backend.upper()} group " + f"for distributed weight update for {meta.nccl_group_name}." + ) + self.distributed_weight_update_initialized = True + + fut.add_done_callback(callback) + + return fut + + def _init_weights_update_group_sync(self, meta: WeightUpdateMeta, rank_offset: int): + """Synchronously initialize weight update group in thread pool.""" + self.backend.init_update_weight_group(self.engine, meta, rank_offset) + + def update_weights_from_distributed( + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] + ) -> Future[None]: + """Update weights in the inference engine from distributed memory. + + Parameters + ---------- + meta : WeightUpdateMeta + Metadata containing information about the weight update + param_specs : List[ParamSpec] + A list of parameter specifications for the weights to be updated + + Returns + ------- + Future[None] + A future object representing the asynchronous weight update operation + """ + assert meta.type == current_platform.communication_backend + + if self.engine is None: + raise RuntimeError( + "Local inference engine is not initialized, cannot update weights." + ) + + fut = self.executor.submit( + self._update_weights_from_distributed_sync, meta, param_specs + ) + + return fut + + def _update_weights_from_distributed_sync( + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] + ): + """Synchronously update weights from distributed memory in thread pool.""" + self.backend.update_weight_xccl(self.engine, meta, param_specs) + + def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: + """Update weights in the inference engine from disk. + + Parameters + ---------- + meta : WeightUpdateMeta + Metadata containing information about the weight update + + Returns + ------- + Future[None] + A future object representing the asynchronous weight update operation + """ + assert meta.type == "disk" + + if self.engine is None: + raise RuntimeError( + "Local inference engine is not initialized, cannot update weights." + ) + + tik = time.perf_counter() + + # Validate experiment and trial names + if self.config.experiment_name is None or self.config.trial_name is None: + raise RuntimeError( + "Experiment and trial names must be set for disk-based weight updates." + ) + + fut = self.executor.submit(self._update_weights_from_disk_sync, meta) + + def callback(fut): + respond_time = fut.result() + self.logger.info( + f"Loading weights from disk done in " + f"{(time.perf_counter() - tik):.2f}s. " + f"Respond time: {respond_time:.2f}s." + ) + + fut.add_done_callback(callback) + + return fut + + def _update_weights_from_disk_sync(self, meta: WeightUpdateMeta) -> float: + """Synchronously update weights from disk in thread pool.""" + # Wait for training engine to signal that weights are ready + update_name = names.update_weights_from_disk( + self.config.experiment_name, + self.config.trial_name, + meta.model_version, + ) + save_timestamp = float(name_resolve.wait(update_name, timeout=120)) + load_timestamp = time.time() + + self.logger.info( + f"Begin update weights from {meta.path}, " + f"responded in {(load_timestamp - save_timestamp) * 1000:.2f} ms" + ) + + # Update weights from disk via backend + self.backend.update_weight_disk(self.engine, str(meta.path)) + + self.logger.info( + f"Loading weights done in {(time.time() - load_timestamp) * 1000:.2f} ms" + ) + self.set_version(meta.model_version) + + return load_timestamp - save_timestamp + + def submit( + self, + data: dict[str, Any], + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ) -> None: + """Submit a request to the inference engine and return immediately. + + Parameters + ---------- + data : Dict[str, Any] + The input data for rollout + workflow : RolloutWorkflow, optional + The workflow instance to run + workflow_builder : Callable, optional + A builder to create a workflow instance + should_accept : Callable, optional + A function to decide whether to accept a trajectory + """ + return self.workflow_executor.submit( + data, + workflow=workflow, + workflow_builder=workflow_builder, + should_accept=should_accept, + ) + + def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: + """Wait for a specified number of requests to complete. + + Parameters + ---------- + count : int + The number of accepted trajectories to wait for + timeout : float, optional + Timeout in seconds + + Returns + ------- + Dict[str, Any] + A concatenated batch of trajectories + """ + return self.workflow_executor.wait(count, timeout=timeout) + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ) -> dict[str, Any]: + """Submit a batch of requests and wait for results. + + Parameters + ---------- + data : List[Dict[str, Any]] + A list of input data dictionaries for rollout + workflow : RolloutWorkflow, optional + The workflow instance to run + workflow_builder : Callable, optional + A builder to create a workflow instance + should_accept : Callable, optional + A function to decide whether to accept a trajectory + + Returns + ------- + Dict[str, Any] + A concatenated batch of trajectory results + """ + return self.workflow_executor.rollout_batch( + data=data, + workflow=workflow, + workflow_builder=workflow_builder, + should_accept=should_accept, + ) + + def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ): + """Asynchronously submit and wait until a full batch is ready. + + Parameters + ---------- + dataloader : StatefulDataLoader + The data loader to pull data from + workflow : RolloutWorkflow, optional + The workflow instance to run + workflow_builder : Callable, optional + A builder to create a workflow instance + should_accept : Callable, optional + A function to decide whether to accept a trajectory + + Returns + ------- + Dict[str, Any] + A full batch of trajectory results + """ + return self.workflow_executor.prepare_batch( + dataloader=dataloader, + workflow=workflow, + workflow_builder=workflow_builder, + should_accept=should_accept, + ) + + def pause(self): + """Pause request submission for async rollout. + + Used during evaluation to prevent data over generation. + """ + return self.workflow_executor.pause() + + def resume(self): + """Resume request submission for async rollout.""" + return self.workflow_executor.resume() diff --git a/areal/engine/sglang_local.py b/areal/engine/sglang_local.py new file mode 100644 index 000000000..fe3b81b0a --- /dev/null +++ b/areal/engine/sglang_local.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +import time +from collections.abc import Callable +from concurrent.futures import Future +from typing import Any + +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.cli_args import InferenceEngineConfig +from areal.api.engine_api import InferenceEngine +from areal.api.io_struct import ( + ModelRequest, + ModelResponse, + ParamSpec, + WeightUpdateMeta, +) +from areal.api.workflow_api import RolloutWorkflow +from areal.core.local_inf_engine import LocalInfEngine +from areal.platforms import current_platform + + +class SGLangLocalBackend: + """SGLang-specific backend implementation for local inference. + + This backend wraps SGLang's native Engine API for in-process inference. + """ + + def create_engine(self, engine_args: dict[str, Any]) -> Any: + """Create a local SGLang engine instance. + + Parameters + ---------- + engine_args : Dict[str, Any] + Arguments to pass to sglang.Engine constructor + + Returns + ------- + Any + The created SGLang Engine instance + """ + import sglang as sgl + + engine = sgl.Engine(**engine_args) + return engine + + async def async_generation(self, engine: Any, req: ModelRequest) -> ModelResponse: + """Perform async generation using the local SGLang engine. + + Parameters + ---------- + engine : Any + The SGLang Engine instance + req : ModelRequest + The generation request containing input and parameters + + Returns + ------- + ModelResponse + The generated response with tokens, logprobs, and metadata + """ + # Prepare request payload + gconfig = req.gconfig + stop_token_ids = gconfig.stop_token_ids + + sampling_params = { + "top_p": gconfig.top_p, + "top_k": gconfig.top_k, + "max_new_tokens": gconfig.max_new_tokens, + "temperature": 0.0 if gconfig.greedy else gconfig.temperature, + "stop_token_ids": stop_token_ids, + "frequency_penalty": gconfig.frequency_penalty, + } + + if gconfig.stop: + sampling_params["stop"] = gconfig.stop + + # Make request + start_time = time.perf_counter() + + # Call SGLang's async_generate method + outputs = await engine.async_generate( + input_ids=req.input_ids, + sampling_params=sampling_params, + return_logprob=True, + ) + + # Parse response + meta_info = outputs["meta_info"] + finish_reason = meta_info["finish_reason"] + stop_reason = finish_reason["type"] + stop_message = finish_reason.get("message", "") + + # Handle early abort + if stop_reason == "abort" and stop_message.startswith("Abort before prefill"): + latency = time.perf_counter() - start_time + return ModelResponse( + input_tokens=req.input_ids, + input_images=req.image_data, + output_tokens=[], + output_logprobs=[], + output_versions=[], + stop_reason=stop_reason, + latency=latency, + ttft=latency, + tokenizer=req.tokenizer, + processor=req.processor, + ) + + # Extract output tokens and logprobs + output_tokens = [x[1] for x in meta_info["output_token_logprobs"]] + output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]] + + latency = time.perf_counter() - start_time + + return ModelResponse( + input_tokens=req.input_ids, + input_images=req.image_data, + output_tokens=output_tokens, + output_logprobs=output_logprobs, + output_versions=[], # Will be filled by LocalInfEngine + stop_reason=stop_reason, + latency=latency, + ttft=latency, + tokenizer=req.tokenizer, + processor=req.processor, + ) + + def update_weight_disk(self, engine: Any, model_path: str) -> None: + """Update weights from disk synchronously. + + Parameters + ---------- + engine : Any + The SGLang Engine instance + model_path : str + Path to the model weights on disk + """ + # Call SGLang's update_weights_from_disk method + engine.update_weights_from_disk(model_path=model_path) + + def update_weight_xccl( + self, + engine: Any, + meta: WeightUpdateMeta, + param_specs: list[ParamSpec], + ) -> None: + """Update weights from distributed memory via NCCL/XCCL synchronously. + + Parameters + ---------- + engine : Any + The SGLang Engine instance + meta : WeightUpdateMeta + Metadata containing communication group info + param_specs : List[ParamSpec] + Specifications for parameters to be updated + """ + # Call SGLang's update_weights_from_distributed method + engine.update_weights_from_distributed( + names=[pspec.name for pspec in param_specs], + dtypes=[pspec.dtype for pspec in param_specs], + shapes=[pspec.shape for pspec in param_specs], + group_name=meta.nccl_group_name, + ) + + def init_update_weight_group( + self, engine: Any, meta: WeightUpdateMeta, rank_offset: int + ) -> None: + """Initialize weight update communication group synchronously. + + Parameters + ---------- + engine : Any + The SGLang Engine instance + meta : WeightUpdateMeta + Metadata containing communication backend configuration + rank_offset : int + Rank offset for this engine in the communication group + """ + assert meta.alloc_mode is not None + if meta.alloc_mode.gen.pp_size != 1: + raise NotImplementedError( + "NCCL weight update with PP size > 1 is not implemented yet." + ) + + # Call SGLang's init_weights_update_group method + engine.init_weights_update_group( + master_address=meta.nccl_master_address, + master_port=str(meta.nccl_master_port), + rank_offset=rank_offset, + world_size=meta.alloc_mode.gen.world_size + 1, + backend=current_platform.communication_backend, + group_name=meta.nccl_group_name, + ) + + def destroy(self, engine: Any) -> None: + """Destroy the engine and release resources. + + Parameters + ---------- + engine : Any + The SGLang Engine instance to destroy + """ + # SGLang engines typically don't need explicit cleanup + # but we include this for consistency with the protocol + if hasattr(engine, "shutdown"): + engine.shutdown() + + +class LocalSGLangEngine(InferenceEngine): + """SGLang local inference engine. + + This class delegates all functionality to LocalInfEngine with + an SGLangLocalBackend implementation. It maintains the same public API. + + Parameters + ---------- + config : InferenceEngineConfig + Configuration for the inference engine + """ + + def __init__(self, config: InferenceEngineConfig): + self.config = config + # Pure composition - create internal engine with SGLang backend + self._engine = LocalInfEngine(config, SGLangLocalBackend()) + + def initialize( + self, + engine_id: str | None = None, + engine_args: dict[str, Any] | None = None, + train_data_parallel_size: int | None = None, + ): + """Initialize the engine by creating the local SGLang engine. + + Parameters + ---------- + engine_id : Optional[str] + Unique identifier for this engine instance + engine_args : Optional[Dict[str, Any]] + Arguments to pass to sglang.Engine constructor + train_data_parallel_size : int | None + Data parallel size of the training engine + """ + return self._engine.initialize(engine_id, engine_args, train_data_parallel_size) + + def destroy(self): + """Destroy the engine and clean up resources.""" + return self._engine.destroy() + + def set_version(self, version: int): + """Set the current weight version.""" + return self._engine.set_version(version) + + def get_version(self) -> int: + """Get the current weight version.""" + return self._engine.get_version() + + async def agenerate(self, req: ModelRequest) -> ModelResponse: + """Asynchronously generate a response for the given request.""" + return await self._engine.agenerate(req) + + def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: + """Initialize the weight update process group.""" + return self._engine.init_weights_update_group(meta) + + def update_weights_from_distributed( + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] + ) -> Future[None]: + """Update weights from distributed memory.""" + return self._engine.update_weights_from_distributed(meta, param_specs) + + def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: + """Update weights from disk.""" + return self._engine.update_weights_from_disk(meta) + + def submit( + self, + data: dict[str, Any], + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ) -> None: + """Submit a request to the inference engine.""" + return self._engine.submit(data, workflow, workflow_builder, should_accept) + + def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: + """Wait for a specified number of requests to complete.""" + return self._engine.wait(count, timeout) + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ) -> dict[str, Any]: + """Submit a batch of requests and wait for results.""" + return self._engine.rollout_batch( + data, workflow, workflow_builder, should_accept + ) + + def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ): + """Asynchronously submit and wait until a full batch is ready.""" + return self._engine.prepare_batch( + dataloader, workflow, workflow_builder, should_accept + ) + + def pause(self): + """Pause request submission for async rollout.""" + return self._engine.pause() + + def resume(self): + """Resume request submission for async rollout.""" + return self._engine.resume() diff --git a/areal/engine/vllm_local.py b/areal/engine/vllm_local.py new file mode 100644 index 000000000..4254fc17a --- /dev/null +++ b/areal/engine/vllm_local.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import time +import uuid +from collections.abc import Callable +from concurrent.futures import Future +from typing import Any + +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.cli_args import InferenceEngineConfig +from areal.api.engine_api import InferenceEngine +from areal.api.io_struct import ( + ModelRequest, + ModelResponse, + ParamSpec, + WeightUpdateMeta, +) +from areal.api.workflow_api import RolloutWorkflow +from areal.core.local_inf_engine import LocalInfEngine + + +class VLLMLocalBackend: + """vLLM-specific backend implementation for local inference. + + This backend wraps vLLM's native AsyncLLMEngine API for in-process inference. + """ + + def create_engine(self, engine_args: dict[str, Any]) -> Any: + """Create a local vLLM engine instance. + + Parameters + ---------- + engine_args : Dict[str, Any] + Arguments to pass to vLLM AsyncLLMEngine constructor + + Returns + ------- + Any + The created vLLM AsyncLLMEngine instance + """ + from vllm.engine.async_llm_engine import AsyncLLMEngine + + engine = AsyncLLMEngine.from_engine_args(**engine_args) + return engine + + async def async_generation(self, engine: Any, req: ModelRequest) -> ModelResponse: + """Perform async generation using the local vLLM engine. + + Parameters + ---------- + engine : Any + The vLLM AsyncLLMEngine instance + req : ModelRequest + The generation request containing input and parameters + + Returns + ------- + ModelResponse + The generated response with tokens, logprobs, and metadata + """ + from vllm import SamplingParams + + # Prepare request payload + gconfig = req.gconfig + stop_token_ids = gconfig.stop_token_ids + + sampling_params = SamplingParams( + top_p=gconfig.top_p, + top_k=gconfig.top_k, + max_tokens=gconfig.max_new_tokens, + temperature=0.0 if gconfig.greedy else gconfig.temperature, + stop_token_ids=stop_token_ids, + logprobs=0, # Request logprobs + ) + + # Make request + start_time = time.perf_counter() + + # Generate unique request ID + request_id = uuid.uuid4().hex + + # Call vLLM's generate method which returns an async generator + results_generator = engine.generate( + prompt=None, + sampling_params=sampling_params, + request_id=request_id, + prompt_token_ids=req.input_ids, + ) + + # Iterate through the generator to get the final result + final_output = None + async for request_output in results_generator: + final_output = request_output + + # Parse response + if final_output is None or len(final_output.outputs) == 0: + latency = time.perf_counter() - start_time + return ModelResponse( + input_tokens=req.input_ids, + input_images=req.image_data, + output_tokens=[], + output_logprobs=[], + output_versions=[], + stop_reason="abort", + latency=latency, + ttft=latency, + tokenizer=req.tokenizer, + processor=req.processor, + ) + + # Extract first completion output + completion_output = final_output.outputs[0] + stop_reason = completion_output.finish_reason + + # Extract output tokens from token_ids + output_tokens = completion_output.token_ids + + # Extract logprobs - vLLM returns logprobs as a list of dicts + output_logprobs = [] + if completion_output.logprobs: + for token_logprobs in completion_output.logprobs: + if token_logprobs: + # Get logprob for the actual selected token + # token_logprobs is a dict mapping token_id to Logprob object + # We need to find the logprob for the token that was selected + max_logprob = max(token_logprobs.values(), key=lambda x: x.logprob) + output_logprobs.append(max_logprob.logprob) + else: + output_logprobs.append(0.0) + + latency = time.perf_counter() - start_time + + return ModelResponse( + input_tokens=req.input_ids, + input_images=req.image_data, + output_tokens=output_tokens, + output_logprobs=output_logprobs, + output_versions=[], # Will be filled by LocalInfEngine + stop_reason=stop_reason, + latency=latency, + ttft=latency, + tokenizer=req.tokenizer, + processor=req.processor, + ) + + def update_weight_disk(self, engine: Any, model_path: str) -> None: + """Update weights from disk synchronously. + + Parameters + ---------- + engine : Any + The vLLM AsyncLLMEngine instance + model_path : str + Path to the model weights on disk + """ + # vLLM doesn't support updating weights from disk + # Typically requires creating a new engine + raise NotImplementedError( + "vLLM does not support updating weights from disk. " + "Please create a new engine instance with the new weights." + ) + + def update_weight_xccl( + self, + engine: Any, + meta: WeightUpdateMeta, + param_specs: list[ParamSpec], + ) -> None: + """Update weights from distributed memory via NCCL/XCCL synchronously. + + Parameters + ---------- + engine : Any + The vLLM AsyncLLMEngine instance + meta : WeightUpdateMeta + Metadata containing communication group info + param_specs : List[ParamSpec] + Specifications for parameters to be updated + """ + # vLLM doesn't support distributed weight updates in the same way + raise NotImplementedError( + "vLLM does not support distributed weight updates via NCCL/XCCL. " + "Please use disk-based updates or create a new engine instance." + ) + + def init_update_weight_group( + self, engine: Any, meta: WeightUpdateMeta, rank_offset: int + ) -> None: + """Initialize weight update communication group synchronously. + + Parameters + ---------- + engine : Any + The vLLM AsyncLLMEngine instance + meta : WeightUpdateMeta + Metadata containing communication backend configuration + rank_offset : int + Rank offset for this engine in the communication group + """ + # vLLM doesn't support initializing weight update groups + raise NotImplementedError( + "vLLM does not support weight update communication groups." + ) + + def destroy(self, engine: Any) -> None: + """Destroy the engine and release resources. + + Parameters + ---------- + engine : Any + The vLLM AsyncLLMEngine instance to destroy + """ + # vLLM engines typically don't need explicit cleanup + # but we include this for consistency with the protocol + if hasattr(engine, "shutdown"): + engine.shutdown() + + +class LocalvLLMEngine(InferenceEngine): + """vLLM local inference engine. + + This class delegates all functionality to LocalInfEngine with + a VLLMLocalBackend implementation. It maintains the same public API. + + Note: vLLM does not support weight updates, so update_weights_from_disk + and update_weights_from_distributed will raise NotImplementedError. + + Parameters + ---------- + config : InferenceEngineConfig + Configuration for the inference engine + """ + + def __init__(self, config: InferenceEngineConfig): + self.config = config + # Pure composition - create internal engine with vLLM backend + self._engine = LocalInfEngine(config, VLLMLocalBackend()) + + def initialize( + self, + engine_id: str | None = None, + engine_args: dict[str, Any] | None = None, + train_data_parallel_size: int | None = None, + ): + """Initialize the engine by creating the local vLLM engine. + + Parameters + ---------- + engine_id : Optional[str] + Unique identifier for this engine instance + engine_args : Optional[Dict[str, Any]] + Arguments to pass to vLLM AsyncLLMEngine constructor + train_data_parallel_size : int | None + Data parallel size of the training engine + """ + return self._engine.initialize(engine_id, engine_args, train_data_parallel_size) + + def destroy(self): + """Destroy the engine and clean up resources.""" + return self._engine.destroy() + + def set_version(self, version: int): + """Set the current weight version.""" + return self._engine.set_version(version) + + def get_version(self) -> int: + """Get the current weight version.""" + return self._engine.get_version() + + async def agenerate(self, req: ModelRequest) -> ModelResponse: + """Asynchronously generate a response for the given request.""" + return await self._engine.agenerate(req) + + def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: + """Initialize the weight update process group. + + Note: Not supported by vLLM. + """ + return self._engine.init_weights_update_group(meta) + + def update_weights_from_distributed( + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] + ) -> Future[None]: + """Update weights from distributed memory. + + Note: Not supported by vLLM. + """ + return self._engine.update_weights_from_distributed(meta, param_specs) + + def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: + """Update weights from disk. + + Note: Not supported by vLLM. + """ + return self._engine.update_weights_from_disk(meta) + + def submit( + self, + data: dict[str, Any], + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ) -> None: + """Submit a request to the inference engine.""" + return self._engine.submit(data, workflow, workflow_builder, should_accept) + + def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: + """Wait for a specified number of requests to complete.""" + return self._engine.wait(count, timeout) + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ) -> dict[str, Any]: + """Submit a batch of requests and wait for results.""" + return self._engine.rollout_batch( + data, workflow, workflow_builder, should_accept + ) + + def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, + should_accept: Callable | None = None, + ): + """Asynchronously submit and wait until a full batch is ready.""" + return self._engine.prepare_batch( + dataloader, workflow, workflow_builder, should_accept + ) + + def pause(self): + """Pause request submission for async rollout.""" + return self._engine.pause() + + def resume(self): + """Resume request submission for async rollout.""" + return self._engine.resume() diff --git a/areal/tests/test_local_sglang_engine.py b/areal/tests/test_local_sglang_engine.py new file mode 100644 index 000000000..25eedd256 --- /dev/null +++ b/areal/tests/test_local_sglang_engine.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import os +import time + +import pytest + +from areal.api.cli_args import GenerationHyperparameters, InferenceEngineConfig +from areal.api.io_struct import WeightUpdateMeta +from areal.utils.data import get_batch_size +from areal.utils.hf_utils import load_hf_tokenizer + + +EXPR_NAME = "test_local_sglang_engine" +TRIAL_NAME = "trial_0" +MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" +if not os.path.exists(MODEL_PATH): + MODEL_PATH = "Qwen/Qwen3-0.6B" + + +def _dummy_reward_fn(*args, **kwargs): + return 1.0 + + +@pytest.fixture(scope="module") +def engine_args(): + """Provide SGLang engine args for local inference.""" + return { + "model_path": MODEL_PATH, + "tp_size": 1, + "mem_fraction_static": 0.3, + "skip_tokenizer_init": True, + } + + +@pytest.mark.parametrize("n_samples", [1, 2, 4]) +def test_local_sglang_rollout(engine_args, n_samples): + from areal.engine.sglang_local import LocalSGLangEngine + from areal.workflow.rlvr import RLVRWorkflow + + config = InferenceEngineConfig( + experiment_name=EXPR_NAME, + trial_name=TRIAL_NAME, + max_concurrent_rollouts=2, + consumer_batch_size=2, + ) + engine = LocalSGLangEngine(config) + engine.initialize(engine_args=engine_args) + + gconfig = GenerationHyperparameters( + max_new_tokens=16, greedy=False, n_samples=n_samples + ) + tokenizer = load_hf_tokenizer(MODEL_PATH) + + workflow = RLVRWorkflow( + reward_fn=_dummy_reward_fn, + gconfig=gconfig, + tokenizer=tokenizer, + enable_thinking=False, + ) + + data = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + result = engine.rollout_batch([data] * 2, workflow=workflow) + assert isinstance(result, dict) + bs = get_batch_size(result) + assert bs == 2 * n_samples + engine.destroy() + + +@pytest.mark.parametrize("ofp", [1, 4, 16]) +@pytest.mark.parametrize("bs", [2, 4]) +@pytest.mark.parametrize("n_samples", [2, 1]) +def test_local_sglang_staleness_control(engine_args, bs, ofp, n_samples): + from areal.engine.sglang_local import LocalSGLangEngine + from areal.workflow.rlvr import RLVRWorkflow + + config = InferenceEngineConfig( + experiment_name=EXPR_NAME, + trial_name=TRIAL_NAME, + consumer_batch_size=bs, + max_head_offpolicyness=ofp, + ) + engine = LocalSGLangEngine(config) + engine.initialize(engine_args=engine_args) + + gconfig = GenerationHyperparameters( + max_new_tokens=2, greedy=False, n_samples=n_samples + ) + tokenizer = load_hf_tokenizer(MODEL_PATH) + + workflow = RLVRWorkflow( + reward_fn=_dummy_reward_fn, + gconfig=gconfig, + tokenizer=tokenizer, + enable_thinking=False, + ) + data = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + for _ in range(bs * 2): + engine.submit(data, workflow=workflow) + + # wait for some time + time.sleep(10) + assert engine._engine.workflow_executor.output_queue.qsize() == min( + bs * 2, bs * (ofp + 1) + ) + + # Update model version + engine.set_version(1) + print("Updated model version", flush=True) + + # submit again + for _ in range(bs * 2): + engine.submit(data, workflow=workflow) + # wait for some time + time.sleep(5) + assert engine._engine.workflow_executor.output_queue.qsize() == min( + bs * 4, bs * (ofp + 2) + ) + + # exit + engine.destroy() + + +def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, engine_args): + # setup FSDP engine + from areal.api.cli_args import OptimizerConfig, TrainEngineConfig + from areal.api.io_struct import FinetuneSpec + from areal.engine.fsdp_engine import FSDPEngine + + os.environ["WORLD_SIZE"] = "1" + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "7777" + + engine_config = TrainEngineConfig( + experiment_name=EXPR_NAME, + trial_name=TRIAL_NAME, + path=MODEL_PATH, + optimizer=OptimizerConfig(), + ) + train_engine = FSDPEngine(engine_config) + train_engine.create_process_group() + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) + train_engine.initialize(None, ft_spec) + train_engine.model_version = 100 + + # setup name resolve + import areal.utils.name_resolve as name_resolve + from areal.api.cli_args import NameResolveConfig + + nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") + name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) + name_resolve.reconfigure(name_resolve_config) + + # initialize SGLang local engine + from areal.engine.sglang_local import LocalSGLangEngine + + config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) + inf_engine = LocalSGLangEngine(config) + inf_engine.initialize(engine_args=engine_args) + inf_engine.set_version(100) + + # test update weights + path = tmp_path_factory.mktemp("upload_weights_from_disk") + update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) + train_engine.connect_engine(inf_engine, update_weight_meta) + train_engine.set_version(100) + train_engine.update_weights(update_weight_meta) + inf_engine.destroy() diff --git a/areal/tests/test_local_vllm_engine.py b/areal/tests/test_local_vllm_engine.py new file mode 100644 index 000000000..06ca2bf9b --- /dev/null +++ b/areal/tests/test_local_vllm_engine.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import os + +import pytest + +from areal.api.cli_args import GenerationHyperparameters, InferenceEngineConfig +from areal.utils.data import get_batch_size +from areal.utils.hf_utils import load_hf_tokenizer + + +EXPR_NAME = "test_local_vllm_engine" +TRIAL_NAME = "trial_0" +MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" +if not os.path.exists(MODEL_PATH): + MODEL_PATH = "Qwen/Qwen3-0.6B" + + +def _dummy_reward_fn(*args, **kwargs): + return 1.0 + + +@pytest.fixture(scope="module") +def engine_args(): + """Provide vLLM engine args for local inference.""" + from vllm import EngineArgs + + return EngineArgs( + model=MODEL_PATH, + tensor_parallel_size=1, + gpu_memory_utilization=0.3, + trust_remote_code=True, + ) + + +@pytest.mark.parametrize("n_samples", [1, 2, 4]) +def test_local_vllm_rollout(engine_args, n_samples): + from areal.engine.vllm_local import LocalvLLMEngine + from areal.workflow.rlvr import RLVRWorkflow + + config = InferenceEngineConfig( + experiment_name=EXPR_NAME, + trial_name=TRIAL_NAME, + max_concurrent_rollouts=2, + consumer_batch_size=2, + ) + engine = LocalvLLMEngine(config) + engine.initialize(engine_args=engine_args) + + gconfig = GenerationHyperparameters( + max_new_tokens=16, greedy=False, n_samples=n_samples + ) + tokenizer = load_hf_tokenizer(MODEL_PATH) + + workflow = RLVRWorkflow( + reward_fn=_dummy_reward_fn, + gconfig=gconfig, + tokenizer=tokenizer, + enable_thinking=False, + ) + + data = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + result = engine.rollout_batch([data] * 2, workflow=workflow) + assert isinstance(result, dict) + bs = get_batch_size(result) + assert bs == 2 * n_samples + engine.destroy() + + +def test_local_vllm_weight_update_not_supported(engine_args): + """Test that weight updates correctly raise NotImplementedError for vLLM.""" + from areal.api.io_struct import WeightUpdateMeta + from areal.engine.vllm_local import LocalvLLMEngine + + config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) + engine = LocalvLLMEngine(config) + engine.initialize(engine_args=engine_args) + + # Test that disk weight update is not supported + update_weight_meta = WeightUpdateMeta(type="disk", path="/tmp/fake_path") + + with pytest.raises(NotImplementedError, match="vLLM does not support"): + fut = engine.update_weights_from_disk(update_weight_meta) + fut.result() # Wait for the future to complete and raise the exception + + engine.destroy() diff --git a/pyproject.toml b/pyproject.toml index 443e6cb38..402b00a0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,7 +199,13 @@ select = [ ignore = [] [tool.ruff.lint.isort] +from-first = false known-first-party = ["areal"] +force-sort-within-sections = false +split-on-trailing-comma = true +combine-as-imports = false +force-wrap-aliases = false +lines-after-imports = 2 # The following tools are remained as legacy. # We will use `ruff` instead in future development From 337e71a57363c55e41ebd9187d3bb659fe3e98dd Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 22 Oct 2025 21:48:51 +0800 Subject: [PATCH 02/52] . --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 402b00a0d..5bf1d81f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -205,7 +205,6 @@ force-sort-within-sections = false split-on-trailing-comma = true combine-as-imports = false force-wrap-aliases = false -lines-after-imports = 2 # The following tools are remained as legacy. # We will use `ruff` instead in future development From a9dad5af73a3f8575d44e6154d8bdad218dd6c41 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Fri, 24 Oct 2025 16:32:30 +0800 Subject: [PATCH 03/52] minor fix import --- .pre-commit-config.yaml | 3 --- areal/core/__init__.py | 3 --- areal/core/local_inf_engine.py | 2 -- areal/engine/sglang_local.py | 2 -- areal/engine/vllm_local.py | 2 -- areal/tests/test_local_sglang_engine.py | 3 --- areal/tests/test_local_vllm_engine.py | 3 --- 7 files changed, 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index be6dfe9e1..32574d02a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,9 +39,6 @@ repos: # Ruff version. rev: v0.14.1 hooks: - - id: ruff-format # Run the formatter. - name: Run Formatter (Ruff) - types_or: [ python, pyi, jupyter ] - id: ruff # Run the linter. name: Run Linter Check (Ruff) types_or: [ python, pyi, jupyter ] diff --git a/areal/core/__init__.py b/areal/core/__init__.py index 741b4d687..ddc08c8cf 100644 --- a/areal/core/__init__.py +++ b/areal/core/__init__.py @@ -1,7 +1,5 @@ """Core components for AREAL.""" -from __future__ import annotations - from .local_inf_engine import ( LocalInfBackendProtocol, LocalInfEngine, @@ -16,7 +14,6 @@ check_trajectory_format, ) - __all__ = [ "LocalInfBackendProtocol", "LocalInfEngine", diff --git a/areal/core/local_inf_engine.py b/areal/core/local_inf_engine.py index 519d74cba..5119c2778 100644 --- a/areal/core/local_inf_engine.py +++ b/areal/core/local_inf_engine.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import asyncio import time import uuid diff --git a/areal/engine/sglang_local.py b/areal/engine/sglang_local.py index fe3b81b0a..0b28608be 100644 --- a/areal/engine/sglang_local.py +++ b/areal/engine/sglang_local.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import time from collections.abc import Callable from concurrent.futures import Future diff --git a/areal/engine/vllm_local.py b/areal/engine/vllm_local.py index 4254fc17a..e8f7a3c8e 100644 --- a/areal/engine/vllm_local.py +++ b/areal/engine/vllm_local.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import time import uuid from collections.abc import Callable diff --git a/areal/tests/test_local_sglang_engine.py b/areal/tests/test_local_sglang_engine.py index 25eedd256..e71837fd7 100644 --- a/areal/tests/test_local_sglang_engine.py +++ b/areal/tests/test_local_sglang_engine.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os import time @@ -10,7 +8,6 @@ from areal.utils.data import get_batch_size from areal.utils.hf_utils import load_hf_tokenizer - EXPR_NAME = "test_local_sglang_engine" TRIAL_NAME = "trial_0" MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" diff --git a/areal/tests/test_local_vllm_engine.py b/areal/tests/test_local_vllm_engine.py index 06ca2bf9b..c8a05a1fa 100644 --- a/areal/tests/test_local_vllm_engine.py +++ b/areal/tests/test_local_vllm_engine.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os import pytest @@ -8,7 +6,6 @@ from areal.utils.data import get_batch_size from areal.utils.hf_utils import load_hf_tokenizer - EXPR_NAME = "test_local_vllm_engine" TRIAL_NAME = "trial_0" MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" From f660e5b48a60b52059e1de09f16ebbf808543d9c Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Mon, 27 Oct 2025 17:07:03 +0800 Subject: [PATCH 04/52] merge inferece engine tests --- areal/tests/test_inference_engines.py | 368 ++++++++++++++++++++++++ areal/tests/test_local_sglang_engine.py | 171 ----------- areal/tests/test_local_vllm_engine.py | 85 ------ areal/tests/test_sglang_engine.py | 226 --------------- areal/tests/test_vllm_engine.py | 236 --------------- 5 files changed, 368 insertions(+), 718 deletions(-) create mode 100644 areal/tests/test_inference_engines.py delete mode 100644 areal/tests/test_local_sglang_engine.py delete mode 100644 areal/tests/test_local_vllm_engine.py delete mode 100644 areal/tests/test_sglang_engine.py delete mode 100644 areal/tests/test_vllm_engine.py diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py new file mode 100644 index 000000000..246acbfd2 --- /dev/null +++ b/areal/tests/test_inference_engines.py @@ -0,0 +1,368 @@ +"""Unified test suite for inference engines (vLLM and SGLang, both local and remote).""" + +import os +import subprocess +import sys +import time + +import pytest +import requests + +from areal.api.cli_args import ( + GenerationHyperparameters, + InferenceEngineConfig, + SGLangConfig, + vLLMConfig, +) +from areal.api.io_struct import WeightUpdateMeta +from areal.utils import network +from areal.utils.data import get_batch_size +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.pkg_version import is_available + +MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" +if not os.path.exists(MODEL_PATH): + MODEL_PATH = "Qwen/Qwen3-0.6B" + +# set a large timeout since we may need to download the model from hub +RUN_SERVER_TIMEOUT = 180 + +IS_VLLM_INSTALLED = is_available("vllm") + + +def check_server_health(base_url): + """Check if the server is healthy and ready to accept requests.""" + try: + response = requests.get(f"{base_url}/health", timeout=30) + return response.status_code == 200 + except requests.exceptions.RequestException: + return False + + +def _dummy_reward_fn(*args, **kwargs): + """Dummy reward function for testing.""" + return 1.0 + + +@pytest.fixture( + params=[ + ("vllm", "remote"), + ("vllm", "local"), + ("sglang", "remote"), + ("sglang", "local"), + ], + ids=["vllm-remote", "vllm-local", "sglang-remote", "sglang-local"], +) +def inference_engine(request): + """Unified fixture that provides any inference engine (vLLM/SGLang, local/remote). + + This fixture: + 1. Launches the appropriate server (for remote) or prepares engine args (for local) + 2. Yields engine metadata for test initialization + 3. Cleans up resources after all tests complete + """ + backend, mode = request.param + + # Skip if vLLM is not installed + if backend == "vllm" and not IS_VLLM_INSTALLED: + pytest.skip("vLLM is not installed") + + from areal.utils import seeding + + expr_name = f"test_{mode}_{backend}_engine" + trial_name = "trial_0" + + seeding.set_random_seed(1, expr_name) + + # Initialize engine based on backend and mode + if mode == "remote": + # Launch server + port, dist_port = network.find_free_ports(2) + host = network.gethostip() + + if backend == "vllm": + from areal.engine.vllm_remote import RemotevLLMEngine + + cmd = vLLMConfig.build_cmd( + vllm_config=vLLMConfig( + skip_tokenizer_init=False, + model=MODEL_PATH, + gpu_memory_utilization=0.1, + ), + host=host, + port=port, + tp_size=1, + dist_init_addr=f"{host}:{dist_port}", + ) + engine_class = RemotevLLMEngine + else: # sglang + from areal.engine.sglang_remote import RemoteSGLangEngine + + cmd = SGLangConfig.build_cmd( + sglang_config=SGLangConfig( + skip_tokenizer_init=True, + model_path=MODEL_PATH, + mem_fraction_static=0.3, + ), + host=host, + port=port, + tp_size=1, + base_gpu_id=0, + dist_init_addr=f"{host}:{dist_port}", + ) + engine_class = RemoteSGLangEngine + + # Launch process + cmd = cmd.replace("\\\n", " ").replace("\\", " ") + process = subprocess.Popen( + cmd.split(), + text=True, + stdout=sys.stdout, + stderr=sys.stdout, + ) + base_url = f"http://{host}:{port}" + tik = time.time() + while time.time() - tik < RUN_SERVER_TIMEOUT: + if check_server_health(base_url): + break + time.sleep(1) + if time.time() - tik > RUN_SERVER_TIMEOUT: + process.terminate() + raise RuntimeError(f"{backend.upper()} server launch failed") + + # Set environment for remote engine + os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host}:{port}" + + yield { + "engine_class": engine_class, + "backend": backend, + "mode": mode, + "expr_name": expr_name, + "trial_name": trial_name, + "host": host, + "port": port, + } + + # Cleanup + process.terminate() + + else: # local + if backend == "vllm": + from vllm import EngineArgs + + from areal.engine.vllm_local import LocalvLLMEngine + + engine_args = EngineArgs( + model=MODEL_PATH, + tensor_parallel_size=1, + gpu_memory_utilization=0.3, + trust_remote_code=True, + ) + engine_class = LocalvLLMEngine + else: # sglang + from areal.engine.sglang_local import LocalSGLangEngine + + engine_args = { + "model_path": MODEL_PATH, + "tp_size": 1, + "mem_fraction_static": 0.3, + "skip_tokenizer_init": True, + } + engine_class = LocalSGLangEngine + + yield { + "engine_class": engine_class, + "backend": backend, + "mode": mode, + "expr_name": expr_name, + "trial_name": trial_name, + "engine_args": engine_args, + } + + +# ============================================================================ +# Unified Tests +# ============================================================================ + + +@pytest.mark.parametrize("n_samples", [1, 2, 4]) +def test_rollout(inference_engine, n_samples): + """Test engine rollout with different sample sizes.""" + from areal.workflow.rlvr import RLVRWorkflow + + config = InferenceEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + max_concurrent_rollouts=2, + consumer_batch_size=2, + ) + + engine = inference_engine["engine_class"](config) + + # Initialize based on mode + if inference_engine["mode"] == "remote": + engine.initialize() + else: # local + engine.initialize(engine_args=inference_engine["engine_args"]) + + gconfig = GenerationHyperparameters( + max_new_tokens=16, greedy=False, n_samples=n_samples + ) + tokenizer = load_hf_tokenizer(MODEL_PATH) + + workflow = RLVRWorkflow( + reward_fn=_dummy_reward_fn, + gconfig=gconfig, + tokenizer=tokenizer, + enable_thinking=False, + ) + + data = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + result = engine.rollout_batch([data] * 2, workflow=workflow) + assert isinstance(result, dict) + bs = get_batch_size(result) + assert bs == 2 * n_samples + engine.destroy() + + +@pytest.mark.parametrize("ofp", [0, 1, 4, 16]) +@pytest.mark.parametrize("bs", [2, 4]) +@pytest.mark.parametrize("n_samples", [2, 1]) +def test_staleness_control(inference_engine, bs, ofp, n_samples): + """Test engine staleness control mechanism.""" + from areal.workflow.rlvr import RLVRWorkflow + + # Skip certain parameter combinations based on backend + if inference_engine["backend"] == "sglang" and bs == 4: + pytest.skip("SGLang only tests with bs=2") + if inference_engine["backend"] == "sglang" and ofp == 16: + pytest.skip("SGLang doesn't test with ofp=16") + if inference_engine["backend"] == "vllm" and ofp == 0: + pytest.skip("vLLM doesn't test with ofp=0") + + config = InferenceEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + consumer_batch_size=bs, + max_head_offpolicyness=ofp, + enable_rollout_tracing=( + inference_engine["backend"] == "sglang" + and inference_engine["mode"] == "remote" + ), + ) + + engine = inference_engine["engine_class"](config) + + # Initialize based on mode + if inference_engine["mode"] == "remote": + engine.initialize() + else: # local + engine.initialize(engine_args=inference_engine["engine_args"]) + + gconfig = GenerationHyperparameters( + max_new_tokens=2, greedy=False, n_samples=n_samples + ) + tokenizer = load_hf_tokenizer(MODEL_PATH) + + workflow = RLVRWorkflow( + reward_fn=_dummy_reward_fn, + gconfig=gconfig, + tokenizer=tokenizer, + enable_thinking=False, + ) + data = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + for _ in range(bs * 2): + engine.submit(data, workflow=workflow) + + if ofp < 1: + # Due to controlled offpolicyness, not all requests are committed + with pytest.raises(TimeoutError): + engine.wait(count=bs * 2, timeout=10) + else: + result = engine.wait(count=bs * 2, timeout=10) + assert result["attention_mask"].shape[0] == bs * 2 * n_samples + + # Update model version + engine.set_version(1) + print("Updated model version", flush=True) + + # submit again + for _ in range(bs * 2): + engine.submit(data, workflow=workflow) + + if ofp < 2: + # Due to controlled offpolicyness, not all requests are committed + with pytest.raises(TimeoutError): + engine.wait(count=bs * 4, timeout=5) + else: + # 2 * bs samples haved been retrived above + results = engine.wait(count=bs * 2, timeout=5) + assert results["attention_mask"].shape[0] == bs * 2 * n_samples + + # exit + engine.destroy() + + +def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, inference_engine): + """Test disk-based weight updates from FSDP engine to inference engine.""" + # Skip weight update test for local vLLM (not supported) + if inference_engine["backend"] == "vllm" and inference_engine["mode"] == "local": + pytest.skip("Local vLLM doesn't support weight updates") + + # setup FSDP engine + from areal.api.cli_args import OptimizerConfig, TrainEngineConfig + from areal.api.io_struct import FinetuneSpec + from areal.engine.fsdp_engine import FSDPEngine + + os.environ["WORLD_SIZE"] = "1" + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "7777" + + engine_config = TrainEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + path=MODEL_PATH, + optimizer=OptimizerConfig(), + ) + train_engine = FSDPEngine(engine_config) + train_engine.create_process_group() + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) + train_engine.initialize(None, ft_spec) + train_engine.model_version = 100 + + # setup name resolve + import areal.utils.name_resolve as name_resolve + from areal.api.cli_args import NameResolveConfig + + nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") + name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) + name_resolve.reconfigure(name_resolve_config) + + # initialize inference engine + config = InferenceEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + ) + inf_engine = inference_engine["engine_class"](config) + + # Initialize based on mode + if inference_engine["mode"] == "remote": + inf_engine.initialize() + else: # local + inf_engine.initialize(engine_args=inference_engine["engine_args"]) + + inf_engine.set_version(100) + + # test update weights + path = tmp_path_factory.mktemp("update_weights_from_disk") + update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) + train_engine.connect_engine(inf_engine, update_weight_meta) + train_engine.set_version(100) + train_engine.update_weights(update_weight_meta) + inf_engine.destroy() diff --git a/areal/tests/test_local_sglang_engine.py b/areal/tests/test_local_sglang_engine.py deleted file mode 100644 index e71837fd7..000000000 --- a/areal/tests/test_local_sglang_engine.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import time - -import pytest - -from areal.api.cli_args import GenerationHyperparameters, InferenceEngineConfig -from areal.api.io_struct import WeightUpdateMeta -from areal.utils.data import get_batch_size -from areal.utils.hf_utils import load_hf_tokenizer - -EXPR_NAME = "test_local_sglang_engine" -TRIAL_NAME = "trial_0" -MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" -if not os.path.exists(MODEL_PATH): - MODEL_PATH = "Qwen/Qwen3-0.6B" - - -def _dummy_reward_fn(*args, **kwargs): - return 1.0 - - -@pytest.fixture(scope="module") -def engine_args(): - """Provide SGLang engine args for local inference.""" - return { - "model_path": MODEL_PATH, - "tp_size": 1, - "mem_fraction_static": 0.3, - "skip_tokenizer_init": True, - } - - -@pytest.mark.parametrize("n_samples", [1, 2, 4]) -def test_local_sglang_rollout(engine_args, n_samples): - from areal.engine.sglang_local import LocalSGLangEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - max_concurrent_rollouts=2, - consumer_batch_size=2, - ) - engine = LocalSGLangEngine(config) - engine.initialize(engine_args=engine_args) - - gconfig = GenerationHyperparameters( - max_new_tokens=16, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - result = engine.rollout_batch([data] * 2, workflow=workflow) - assert isinstance(result, dict) - bs = get_batch_size(result) - assert bs == 2 * n_samples - engine.destroy() - - -@pytest.mark.parametrize("ofp", [1, 4, 16]) -@pytest.mark.parametrize("bs", [2, 4]) -@pytest.mark.parametrize("n_samples", [2, 1]) -def test_local_sglang_staleness_control(engine_args, bs, ofp, n_samples): - from areal.engine.sglang_local import LocalSGLangEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - consumer_batch_size=bs, - max_head_offpolicyness=ofp, - ) - engine = LocalSGLangEngine(config) - engine.initialize(engine_args=engine_args) - - gconfig = GenerationHyperparameters( - max_new_tokens=2, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - # wait for some time - time.sleep(10) - assert engine._engine.workflow_executor.output_queue.qsize() == min( - bs * 2, bs * (ofp + 1) - ) - - # Update model version - engine.set_version(1) - print("Updated model version", flush=True) - - # submit again - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - # wait for some time - time.sleep(5) - assert engine._engine.workflow_executor.output_queue.qsize() == min( - bs * 4, bs * (ofp + 2) - ) - - # exit - engine.destroy() - - -def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, engine_args): - # setup FSDP engine - from areal.api.cli_args import OptimizerConfig, TrainEngineConfig - from areal.api.io_struct import FinetuneSpec - from areal.engine.fsdp_engine import FSDPEngine - - os.environ["WORLD_SIZE"] = "1" - os.environ["RANK"] = "0" - os.environ["LOCAL_RANK"] = "0" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "7777" - - engine_config = TrainEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - path=MODEL_PATH, - optimizer=OptimizerConfig(), - ) - train_engine = FSDPEngine(engine_config) - train_engine.create_process_group() - ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) - train_engine.initialize(None, ft_spec) - train_engine.model_version = 100 - - # setup name resolve - import areal.utils.name_resolve as name_resolve - from areal.api.cli_args import NameResolveConfig - - nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") - name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) - name_resolve.reconfigure(name_resolve_config) - - # initialize SGLang local engine - from areal.engine.sglang_local import LocalSGLangEngine - - config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) - inf_engine = LocalSGLangEngine(config) - inf_engine.initialize(engine_args=engine_args) - inf_engine.set_version(100) - - # test update weights - path = tmp_path_factory.mktemp("upload_weights_from_disk") - update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) - train_engine.connect_engine(inf_engine, update_weight_meta) - train_engine.set_version(100) - train_engine.update_weights(update_weight_meta) - inf_engine.destroy() diff --git a/areal/tests/test_local_vllm_engine.py b/areal/tests/test_local_vllm_engine.py deleted file mode 100644 index c8a05a1fa..000000000 --- a/areal/tests/test_local_vllm_engine.py +++ /dev/null @@ -1,85 +0,0 @@ -import os - -import pytest - -from areal.api.cli_args import GenerationHyperparameters, InferenceEngineConfig -from areal.utils.data import get_batch_size -from areal.utils.hf_utils import load_hf_tokenizer - -EXPR_NAME = "test_local_vllm_engine" -TRIAL_NAME = "trial_0" -MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" -if not os.path.exists(MODEL_PATH): - MODEL_PATH = "Qwen/Qwen3-0.6B" - - -def _dummy_reward_fn(*args, **kwargs): - return 1.0 - - -@pytest.fixture(scope="module") -def engine_args(): - """Provide vLLM engine args for local inference.""" - from vllm import EngineArgs - - return EngineArgs( - model=MODEL_PATH, - tensor_parallel_size=1, - gpu_memory_utilization=0.3, - trust_remote_code=True, - ) - - -@pytest.mark.parametrize("n_samples", [1, 2, 4]) -def test_local_vllm_rollout(engine_args, n_samples): - from areal.engine.vllm_local import LocalvLLMEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - max_concurrent_rollouts=2, - consumer_batch_size=2, - ) - engine = LocalvLLMEngine(config) - engine.initialize(engine_args=engine_args) - - gconfig = GenerationHyperparameters( - max_new_tokens=16, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - result = engine.rollout_batch([data] * 2, workflow=workflow) - assert isinstance(result, dict) - bs = get_batch_size(result) - assert bs == 2 * n_samples - engine.destroy() - - -def test_local_vllm_weight_update_not_supported(engine_args): - """Test that weight updates correctly raise NotImplementedError for vLLM.""" - from areal.api.io_struct import WeightUpdateMeta - from areal.engine.vllm_local import LocalvLLMEngine - - config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) - engine = LocalvLLMEngine(config) - engine.initialize(engine_args=engine_args) - - # Test that disk weight update is not supported - update_weight_meta = WeightUpdateMeta(type="disk", path="/tmp/fake_path") - - with pytest.raises(NotImplementedError, match="vLLM does not support"): - fut = engine.update_weights_from_disk(update_weight_meta) - fut.result() # Wait for the future to complete and raise the exception - - engine.destroy() diff --git a/areal/tests/test_sglang_engine.py b/areal/tests/test_sglang_engine.py deleted file mode 100644 index a3d1b5b02..000000000 --- a/areal/tests/test_sglang_engine.py +++ /dev/null @@ -1,226 +0,0 @@ -import os -import subprocess -import sys -import time - -import pytest -import requests - -from areal.api.cli_args import ( - GenerationHyperparameters, - InferenceEngineConfig, - SGLangConfig, -) -from areal.api.io_struct import WeightUpdateMeta -from areal.utils import network -from areal.utils.data import get_batch_size -from areal.utils.hf_utils import load_hf_tokenizer - -EXPR_NAME = "test_sglang_engine" -TRIAL_NAME = "trial_0" -MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" -if not os.path.exists(MODEL_PATH): - MODEL_PATH = "Qwen/Qwen3-0.6B" -PORT, DIST_PORT = network.find_free_ports(2) -HOST = network.gethostip() -# set a large timeout since we may need to download the model from hub -RUN_SERVER_TIMEOUT = 180 - - -def check_server_health(base_url): - try: - response = requests.get(f"{base_url}/health", timeout=30) - return response.status_code == 200 - except requests.exceptions.RequestException: - return False - - -@pytest.fixture(scope="module") -def sglang_server(): - from areal.utils import seeding - - seeding.set_random_seed(1, EXPR_NAME) - cmd = SGLangConfig.build_cmd( - sglang_config=SGLangConfig( - skip_tokenizer_init=True, - model_path=MODEL_PATH, - mem_fraction_static=0.3, - ), - host=HOST, - port=PORT, - tp_size=1, - base_gpu_id=0, - dist_init_addr=f"{HOST}:{DIST_PORT}", - ) - # Launch process - cmd = cmd.replace("\\\n", " ").replace("\\", " ") - process = subprocess.Popen( - cmd.split(), - text=True, - stdout=sys.stdout, - stderr=sys.stdout, - ) - base_url = f"http://{HOST}:{PORT}" - tik = time.time() - while time.time() - tik < RUN_SERVER_TIMEOUT: - if check_server_health(base_url): - break - time.sleep(1) - if time.time() - tik > RUN_SERVER_TIMEOUT: - raise RuntimeError("server launch failed") - yield - process.terminate() - - -def _dummy_reward_fn(*args, **kwargs): - return 1.0 - - -@pytest.mark.parametrize("n_samples", [1, 2, 4]) -def test_remote_sglang_rollout(sglang_server, n_samples): - from areal.engine.sglang_remote import RemoteSGLangEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - max_concurrent_rollouts=2, - consumer_batch_size=2, - ) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - engine = RemoteSGLangEngine(config) - engine.initialize() - - gconfig = GenerationHyperparameters( - max_new_tokens=16, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - result = engine.rollout_batch([data] * 2, workflow=workflow) - assert isinstance(result, dict) - bs = get_batch_size(result) - assert bs == 2 * n_samples - engine.destroy() - - -@pytest.mark.parametrize("ofp", [0, 1, 4]) -@pytest.mark.parametrize("bs", [2]) -@pytest.mark.parametrize("n_samples", [2, 1]) -def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples): - from areal.engine.sglang_remote import RemoteSGLangEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - consumer_batch_size=bs, - max_head_offpolicyness=ofp, - enable_rollout_tracing=True, - ) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - engine = RemoteSGLangEngine(config) - engine.initialize() - - gconfig = GenerationHyperparameters( - max_new_tokens=2, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - if ofp < 1: - # Due to controlled offpolicyness, not all requests are committed - with pytest.raises(TimeoutError): - engine.wait(count=bs * 2, timeout=10) - else: - result = engine.wait(count=bs * 2, timeout=10) - assert result["attention_mask"].shape[0] == bs * 2 * n_samples - - # Update model version - engine.set_version(1) - print("Updated model version", flush=True) - - # submit again - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - if ofp < 2: - # Due to controlled offpolicyness, not all requests are committed - with pytest.raises(TimeoutError): - engine.wait(count=bs * 4, timeout=5) - else: - # 2 * bs samples haved been retrived above - results = engine.wait(count=bs * 2, timeout=5) - assert results["attention_mask"].shape[0] == bs * 2 * n_samples - - # exit - engine.destroy() - - -def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server): - # setup FSDP engine - from areal.api.cli_args import OptimizerConfig, TrainEngineConfig - from areal.api.io_struct import FinetuneSpec - from areal.engine.fsdp_engine import FSDPEngine - - os.environ["WORLD_SIZE"] = "1" - os.environ["RANK"] = "0" - os.environ["LOCAL_RANK"] = "0" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "7777" - - engine_config = TrainEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - path=MODEL_PATH, - optimizer=OptimizerConfig(), - ) - engine = FSDPEngine(engine_config) - engine.create_process_group() - ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) - engine.initialize(None, ft_spec) - engine.model_version = 100 - - # setup name resolve - import areal.utils.name_resolve as name_resolve - from areal.api.cli_args import NameResolveConfig - - nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") - name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) - name_resolve.reconfigure(name_resolve_config) - # initialize SGLang remote engine - from areal.api.cli_args import InferenceEngineConfig - from areal.engine.sglang_remote import RemoteSGLangEngine - - config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - inf_engine = RemoteSGLangEngine(config) - inf_engine.initialize() - inf_engine.set_version(100) - # test update weights - path = tmp_path_factory.mktemp("upload_weights_from_disk") - update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) - engine.connect_engine(inf_engine, update_weight_meta) - engine.set_version(100) - engine.update_weights(update_weight_meta) - inf_engine.destroy() diff --git a/areal/tests/test_vllm_engine.py b/areal/tests/test_vllm_engine.py deleted file mode 100644 index 2f2ab425c..000000000 --- a/areal/tests/test_vllm_engine.py +++ /dev/null @@ -1,236 +0,0 @@ -import os -import subprocess -import sys -import time - -import pytest -import requests - -from areal.api.cli_args import ( - GenerationHyperparameters, - InferenceEngineConfig, - vLLMConfig, -) -from areal.api.io_struct import WeightUpdateMeta -from areal.utils import network -from areal.utils.data import get_batch_size -from areal.utils.hf_utils import load_hf_tokenizer -from areal.utils.pkg_version import is_available - -EXPR_NAME = "test_vllm_engine" -TRIAL_NAME = "trial_0" -MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" -if not os.path.exists(MODEL_PATH): - MODEL_PATH = "Qwen/Qwen3-0.6B" -PORT, DIST_PORT = network.find_free_ports(2) -HOST = network.gethostip() -# set a large timeout since we may need to download the model from hub -RUN_SERVER_TIMEOUT = 180 - -IS_VLLM_INSTALLED = is_available("vllm") - - -def check_server_health(base_url): - try: - response = requests.get(f"{base_url}/health", timeout=30) - return response.status_code == 200 - except requests.exceptions.RequestException: - return False - - -@pytest.fixture(scope="module") -def vllm_server(): - from areal.utils import seeding - - seeding.set_random_seed(1, EXPR_NAME) - cmd = vLLMConfig.build_cmd( - vllm_config=vLLMConfig( - skip_tokenizer_init=False, - model=MODEL_PATH, - gpu_memory_utilization=0.1, - ), - host=HOST, - port=PORT, - tp_size=1, - dist_init_addr=f"{HOST}:{DIST_PORT}", - ) - # Launch process - cmd = cmd.replace("\\\n", " ").replace("\\", " ") - process = subprocess.Popen( - cmd.split(), - text=True, - stdout=sys.stdout, - stderr=sys.stdout, - ) - base_url = f"http://{HOST}:{PORT}" - tik = time.time() - while time.time() - tik < RUN_SERVER_TIMEOUT: - if check_server_health(base_url): - break - time.sleep(1) - if time.time() - tik > RUN_SERVER_TIMEOUT: - raise RuntimeError("server launch failed") - yield - process.terminate() - - -def _dummy_reward_fn(*args, **kwargs): - return 1.0 - - -@pytest.mark.skipif( - not IS_VLLM_INSTALLED, reason="Skip the test because vllm is not installed." -) -@pytest.mark.parametrize("n_samples", [1, 2, 4]) -def test_remote_vllm_rollout(vllm_server, n_samples): - from areal.engine.vllm_remote import RemotevLLMEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - max_concurrent_rollouts=2, - consumer_batch_size=2, - ) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - engine = RemotevLLMEngine(config) - engine.initialize() - - gconfig = GenerationHyperparameters( - max_new_tokens=16, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - result = engine.rollout_batch([data] * 2, workflow=workflow) - assert isinstance(result, dict) - bs = get_batch_size(result) - assert bs == 2 * n_samples - engine.destroy() - - -@pytest.mark.skipif( - not IS_VLLM_INSTALLED, reason="Skip the test because vllm is not installed." -) -@pytest.mark.parametrize("ofp", [1, 4, 16]) -@pytest.mark.parametrize("bs", [2, 4]) -@pytest.mark.parametrize("n_samples", [2, 1]) -def test_remote_vllm_staleness_control(vllm_server, bs, ofp, n_samples): - from areal.engine.vllm_remote import RemotevLLMEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - consumer_batch_size=bs, - max_head_offpolicyness=ofp, - ) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - engine = RemotevLLMEngine(config) - engine.initialize() - - gconfig = GenerationHyperparameters( - max_new_tokens=2, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - if ofp < 1: - # Due to controlled offpolicyness, not all requests are committed - with pytest.raises(TimeoutError): - engine.wait(count=bs * 2, timeout=10) - else: - result = engine.wait(count=bs * 2, timeout=10) - assert result["attention_mask"].shape[0] == bs * 2 * n_samples - - # Update model version - engine.set_version(1) - print("Updated model version", flush=True) - - # submit again - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - if ofp < 2: - # Due to controlled offpolicyness, not all requests are committed - with pytest.raises(TimeoutError): - engine.wait(count=bs * 4, timeout=5) - else: - # 2 * bs samples haved been retrived above - results = engine.wait(count=bs * 2, timeout=5) - assert results["attention_mask"].shape[0] == bs * 2 * n_samples - - # exit - engine.destroy() - - -@pytest.mark.skipif( - not IS_VLLM_INSTALLED, reason="Skip the test because vllm is not installed." -) -def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, vllm_server): - # setup FSDP engine - from areal.api.cli_args import OptimizerConfig, TrainEngineConfig - from areal.api.io_struct import FinetuneSpec - from areal.engine.fsdp_engine import FSDPEngine - - os.environ["WORLD_SIZE"] = "1" - os.environ["RANK"] = "0" - os.environ["LOCAL_RANK"] = "0" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "7777" - - engine_config = TrainEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - path=MODEL_PATH, - optimizer=OptimizerConfig(), - ) - engine = FSDPEngine(engine_config) - engine.create_process_group() - ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) - engine.initialize(None, ft_spec) - engine.model_version = 100 - - # setup name resolve - import areal.utils.name_resolve as name_resolve - from areal.api.cli_args import NameResolveConfig - - nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") - name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) - name_resolve.reconfigure(name_resolve_config) - # initialize vLLM remote engine - from areal.api.cli_args import InferenceEngineConfig - from areal.engine.vllm_remote import RemotevLLMEngine - - config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - inf_engine = RemotevLLMEngine(config) - inf_engine.initialize() - inf_engine.set_version(100) - # test update weights - path = tmp_path_factory.mktemp("areal_update_weights") - update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) - engine.connect_engine(inf_engine, update_weight_meta) - engine.set_version(100) - engine.update_weights(update_weight_meta) - inf_engine.destroy() From 78b489ddf67ac22ab1ddc52bab82bdcaeb6644dc Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Mon, 27 Oct 2025 17:08:23 +0800 Subject: [PATCH 05/52] update --- areal/tests/test_inference_engines.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index 246acbfd2..d16031084 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -234,14 +234,6 @@ def test_staleness_control(inference_engine, bs, ofp, n_samples): """Test engine staleness control mechanism.""" from areal.workflow.rlvr import RLVRWorkflow - # Skip certain parameter combinations based on backend - if inference_engine["backend"] == "sglang" and bs == 4: - pytest.skip("SGLang only tests with bs=2") - if inference_engine["backend"] == "sglang" and ofp == 16: - pytest.skip("SGLang doesn't test with ofp=16") - if inference_engine["backend"] == "vllm" and ofp == 0: - pytest.skip("vLLM doesn't test with ofp=0") - config = InferenceEngineConfig( experiment_name=inference_engine["expr_name"], trial_name=inference_engine["trial_name"], @@ -309,9 +301,6 @@ def test_staleness_control(inference_engine, bs, ofp, n_samples): def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, inference_engine): """Test disk-based weight updates from FSDP engine to inference engine.""" - # Skip weight update test for local vLLM (not supported) - if inference_engine["backend"] == "vllm" and inference_engine["mode"] == "local": - pytest.skip("Local vLLM doesn't support weight updates") # setup FSDP engine from areal.api.cli_args import OptimizerConfig, TrainEngineConfig From 722afadd53c448d810820acba9c414ad4640e118 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Tue, 28 Oct 2025 19:27:23 +0800 Subject: [PATCH 06/52] fix --- areal/api/cli_args.py | 96 +++++++++--------- areal/api/engine_api.py | 64 +++++++----- areal/core/local_inf_engine.py | 63 +++++++++--- areal/core/remote_inf_engine.py | 18 +++- areal/engine/sglang_local.py | 13 ++- areal/engine/sglang_remote.py | 41 +++++--- areal/engine/vllm_local.py | 93 ++++++++++++------ areal/engine/vllm_remote.py | 41 +++++--- areal/tests/test_inference_engines.py | 135 ++++++++++++-------------- docs/cli_reference.md | 6 +- 10 files changed, 344 insertions(+), 226 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index aecccaafa..457ba5655 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -3,14 +3,10 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Dict, List +from typing import Any import uvloop import yaml - -from areal.utils.pkg_version import is_version_less - -uvloop.install() from hydra import compose as hydra_compose from hydra import initialize as hydra_init from hydra.core.global_hydra import GlobalHydra @@ -18,6 +14,9 @@ from areal.platforms import current_platform from areal.utils import name_resolve, pkg_version +from areal.utils.pkg_version import is_version_less + +uvloop.install() @dataclass @@ -129,11 +128,11 @@ class GenerationHyperparameters: default=1.0, metadata={"help": "Sampling temperature. Higher values increase diversity."}, ) - stop_token_ids: List[int] = field( + stop_token_ids: list[int] = field( default_factory=list, metadata={"help": "Stop generation when encountering these token IDs."}, ) - stop: List[str] | None = field( + stop: list[str] | None = field( default=None, metadata={ "help": "One or multiple stop words. Generation will stop if one of these words is sampled." @@ -232,7 +231,7 @@ class OptimizerConfig: class FSDPWrapPolicy: """Policy configuration for FSDP model layer wrapping. None defaults to wrapping transformer decoder layers defined by transformers.""" - transformer_layer_cls_to_wrap: List[str] | None = field( + transformer_layer_cls_to_wrap: list[str] | None = field( default=None, metadata={"help": "A list of transformer layer names for FSDP to wrap."}, ) @@ -310,7 +309,7 @@ class MegatronEngineConfig: recompute_method: str | None = "uniform" recompute_num_layers: int | None = 1 distribute_saved_activations: bool | None = None - recompute_modules: List[str] | None = None + recompute_modules: list[str] | None = None @dataclass @@ -378,7 +377,7 @@ class TrainEngineConfig: ) lora_rank: int = field(default=32, metadata={"help": "lora rank"}) lora_alpha: int = field(default=16, metadata={"help": "lora alpha"}) - target_modules: List[str] = field( + target_modules: list[str] = field( default_factory=list, metadata={"help": "lora target_modules."}, ) @@ -500,7 +499,7 @@ class PPOActorConfig(TrainEngineConfig): default=False, metadata={"help": "Log statistics for agent trajectories"}, ) - log_agent_stats_keys: List[str] = field( + log_agent_stats_keys: list[str] = field( default_factory=lambda: [], metadata={"help": "Keys for logging agent trajectory statistics"}, ) @@ -574,7 +573,7 @@ def build_args( port, dist_init_addr: str | None = None, ): - args: Dict = conf_as_dict(vllm_config) + args: dict = conf_as_dict(vllm_config) args = dict( host=host, port=port, @@ -587,6 +586,21 @@ def build_args( ) return args + @staticmethod + def build_cmd_from_args(args: dict[str, Any]): + # convert to flags + flags = [] + for k, v in args.items(): + if v is None or v is False or v == "": + continue + if v is True: + flags.append(f"--{k.replace('_', '-')}") + elif isinstance(v, list): + flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}") + else: + flags.append(f"--{k.replace('_', '-')} {v}") + return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}" + @staticmethod def build_cmd( vllm_config: "vLLMConfig", @@ -602,18 +616,7 @@ def build_cmd( port=port, dist_init_addr=dist_init_addr, ) - # convert to flags - flags = [] - for k, v in args.items(): - if v is None or v is False or v == "": - continue - if v is True: - flags.append(f"--{k.replace('_','-')}") - elif isinstance(v, list): - flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}") - else: - flags.append(f"--{k.replace('_','-')} {v}") - return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}" + return vLLMConfig.build_cmd_from_args(args) @dataclass @@ -638,7 +641,7 @@ class SGLangConfig: enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: int | None = None - cuda_graph_bs: List[int] | None = None + cuda_graph_bs: list[int] | None = None torchao_config: str = "" enable_nan_detection: bool = False enable_p2p_check: bool = False @@ -667,8 +670,8 @@ class SGLangConfig: # lora enable_lora: bool | None = None max_lora_rank: int | None = None - lora_target_modules: List[str] | None = None - lora_paths: List[str] | None = None + lora_target_modules: list[str] | None = None + lora_paths: list[str] | None = None max_loaded_loras: int = 1 max_loras_per_batch: int = 1 lora_backend: str = "triton" @@ -711,6 +714,10 @@ def build_cmd( node_rank=node_rank, ) + return SGLangConfig.build_cmd_from_args(args) + + @staticmethod + def build_cmd_from_args(args: dict[str, Any]): # convert to flags flags = [] for k, v in args.items(): @@ -719,11 +726,11 @@ def build_cmd( if v is None or v is False or v == "": continue if v is True: - flags.append(f"--{k.replace('_','-')}") + flags.append(f"--{k.replace('_', '-')}") elif isinstance(v, list): - flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}") + flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}") else: - flags.append(f"--{k.replace('_','-')} {v}") + flags.append(f"--{k.replace('_', '-')} {v}") return f"python3 -m sglang.launch_server {' '.join(flags)}" @staticmethod @@ -737,21 +744,20 @@ def build_args( n_nodes: int = 1, node_rank: int = 0, ): - # Map "all-linear" to "all" - args: Dict = conf_as_dict(sglang_config) + args: dict = conf_as_dict(sglang_config) if sglang_config.enable_multithread_load or sglang_config.enable_fast_load: - assert pkg_version.is_version_equal( - "sglang", "0.5.2" - ), f"Customized model loading requires exact SGLang version 0.5.2" + assert pkg_version.is_version_equal("sglang", "0.5.2"), ( + "Customized model loading requires exact SGLang version 0.5.2" + ) model_loader_extra_config = dict( enable_multithread_load=sglang_config.enable_multithread_load, enable_fast_load=sglang_config.enable_fast_load, ) - args.pop("enable_multithread_load", None) - args.pop("enable_fast_load", None) args["model_loader_extra_config"] = json.dumps( model_loader_extra_config, separators=(",", ":") ) + args.pop("enable_multithread_load", None) + args.pop("enable_fast_load", None) # Map "all-linear" to "all" if "lora_target_modules" in args and args["lora_target_modules"]: args["lora_target_modules"] = [ @@ -915,8 +921,8 @@ class WandBConfig: job_type: str | None = None group: str | None = None notes: str | None = None - tags: List[str] | None = None - config: Dict | None = None + tags: list[str] | None = None + config: dict | None = None id_suffix: str | None = "train" @@ -926,7 +932,7 @@ class SwanlabConfig: project: str | None = None name: str | None = None - config: Dict | None = None + config: dict | None = None logdir: str | None = None mode: str | None = "disabled" api_key: str | None = os.getenv("SWANLAB_API_KEY", None) @@ -1023,7 +1029,7 @@ class SchedulerConfig: endpoint: str = field(default="http://localhost:8081") deploy_mode: str = field(default="separation") functioncall_service_domain: str = field(default="http://localhost:8080") - reward_functioncall_config: Dict = field(default_factory=dict) + reward_functioncall_config: dict = field(default_factory=dict) reward_model_path: str = field(default="") reward_model_service_url: str = field(default="http://localhost:30000/classify") @@ -1076,7 +1082,7 @@ class SlurmLauncherConfig: default="--mpi=pmi2 -K --chdir $PWD", metadata={"help": "Additional arguments to pass to the srun command."}, ) - additional_bash_cmds: List[str] | None = field( + additional_bash_cmds: list[str] | None = field( default=None, metadata={ "help": "Additional bash commands to setup the container before running " @@ -1244,7 +1250,7 @@ class PPOConfig(GRPOConfig): critic: PPOCriticConfig = field(default_factory=PPOCriticConfig) -def parse_cli_args(argv: List[str]): +def parse_cli_args(argv: list[str]): parser = argparse.ArgumentParser() parser.add_argument( "--config", help="Path to the main configuration file", required=True @@ -1277,7 +1283,7 @@ def to_structured_cfg(cfg, config_cls): return cfg -def load_expr_config(argv: List[str], config_cls): +def load_expr_config(argv: list[str], config_cls): cfg, config_file = parse_cli_args(argv) cfg = to_structured_cfg(cfg, config_cls=config_cls) cfg = OmegaConf.to_object(cfg) @@ -1305,7 +1311,7 @@ def save_config(cfg, log_dir): os.makedirs(log_dir, exist_ok=True) config_save_path = os.path.join(log_dir, "config.yaml") with open(config_save_path, "w") as f: - config_dict: Dict = asdict(cfg) + config_dict: dict = asdict(cfg) yaml.dump( config_dict, f, diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 5761b2565..bada2f299 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -1,7 +1,8 @@ import abc +from collections.abc import Callable from concurrent.futures import Future from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional import torch import torch.distributed as dist @@ -30,14 +31,16 @@ class Scheduling: partition: str | None = None container_image: str | None = None type: str | None = None - env_vars: Dict[str, str] = field(default_factory=dict) + env_vars: dict[str, str] = field(default_factory=dict) # time utils from "https://slurm.schedmd.com/sbatch.html" - time_limit: Optional[str] = None # see "--time" option for format - begin: Optional[str] = None # see "--begin" option for format - deadline: Optional[str] = None # see "--deadline" option for format + time_limit: str | None = None # see "--time" option for format + begin: str | None = None # see "--begin" option for format + deadline: str | None = None # see "--deadline" option for format class TrainEngine(abc.ABC): + def configure(self, config): + raise NotImplementedError() def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): """Initialize PyTorch distributed communication groups. @@ -49,6 +52,9 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None """ raise NotImplementedError() + def destroy_process_group(self): + raise NotImplementedError() + def initialize(self, *args, **kwargs): """Initialize environments for distributed training and load models. @@ -241,10 +247,10 @@ def step_lr_scheduler(self): def train_batch( self, - input_: Dict[str, Any], - loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor], - loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor], - ) -> Dict[str, float]: + input_: dict[str, Any], + loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + ) -> dict[str, float]: """Update the model with a batch of data and a loss function. Note @@ -276,9 +282,9 @@ def train_batch( @torch.no_grad() def eval_batch( self, - input_: Dict[str, Any], - loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor], - loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor], + input_: dict[str, Any], + loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> torch.Tensor | None: """Evaluate the model using the forward pass and loss function. @@ -311,10 +317,10 @@ def eval_batch( @torch.no_grad() def forward( self, - input_: Dict[str, Any], - output_seqlens: List[int] | None = None, - post_hook: Callable[[torch.Tensor, Dict[str, Any]], Any] | None = None, - aggregate_fn: Callable[[List[Any]], Any] = torch.cat, + input_: dict[str, Any], + output_seqlens: list[int] | None = None, + post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None, + aggregate_fn: Callable[[list[Any]], Any] = torch.cat, ) -> Any | None: """Run the forward pass or inference on the model. @@ -345,6 +351,14 @@ def forward( class InferenceEngine(abc.ABC): + def configure(self, config): + raise NotImplementedError() + + def create_engine(self, engine_args: dict[str, Any]): + raise NotImplementedError() + + def destroy_engine(self): + raise NotImplementedError() def initialize(self, *args, **kwargs): """Initialize environments and launch the background thread for asynchronous distributed inference. @@ -405,7 +419,7 @@ def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: raise NotImplementedError() def update_weights_from_distributed( - self, meta: WeightUpdateMeta, param_specs: List[ParamSpec] + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ) -> Future[None]: """Update weights in the inference engine in a non-blocking manner. @@ -460,9 +474,9 @@ def get_version(self) -> int: def submit( self, - data: Dict[str, Any], + data: dict[str, Any], workflow: Optional["RolloutWorkflow"] = None, - workflow_builder: Optional[Callable] = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, ) -> None: """Submit a request to the inference engine and return immediately. @@ -486,7 +500,7 @@ def submit( """ raise NotImplementedError() - def wait(self, count: int, timeout: float | None = None) -> Dict[str, Any]: + def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: """Wait for a specified number of requests to complete, with a timeout. Should be used together with preceding `submit`. @@ -512,11 +526,11 @@ def wait(self, count: int, timeout: float | None = None) -> Dict[str, Any]: def rollout_batch( self, - data: List[Dict[str, Any]], + data: list[dict[str, Any]], workflow: Optional["RolloutWorkflow"] = None, - workflow_builder: Optional[Callable] = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Submit a batch of requests to the inference engine and wait for the results. See `workflow_api.py` for concrete implementation. @@ -543,9 +557,9 @@ def prepare_batch( self, dataloader: StatefulDataLoader, workflow: Optional["RolloutWorkflow"] = None, - workflow_builder: Optional[Callable] = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Asynchronously submit and wait until a full batch is ready with controlled staleness. See `workflow_api.py` for concrete implementation. diff --git a/areal/core/local_inf_engine.py b/areal/core/local_inf_engine.py index 5119c2778..cb58334a4 100644 --- a/areal/core/local_inf_engine.py +++ b/areal/core/local_inf_engine.py @@ -125,6 +125,14 @@ def destroy(self, engine: Any) -> None: """ ... + def pause_generation(self) -> None: + """Pause generation.""" + ... + + def continue_generation(self) -> None: + """Continue generation.""" + ... + class LocalInfEngine: """ @@ -157,10 +165,22 @@ def __init__(self, config: InferenceEngineConfig, backend: LocalInfBackendProtoc self.workflow_executor: WorkflowExecutor + def configure(self, config): + self.config = config + + def create_engine(self, engine_args: dict[str, Any] | None = None): + # Create the local engine via backend + engine_args = engine_args or {} + self.engine = self.backend.create_engine(engine_args) + + def destroy_engine(self): + if self.engine is not None: + self.backend.destroy(self.engine) + self.engine = None + def initialize( self, engine_id: str | None = None, - engine_args: dict[str, Any] | None = None, train_data_parallel_size: int | None = None, ): """Initialize the engine by creating the local inference engine. @@ -182,12 +202,6 @@ def initialize( self.engine_id = engine_id self.logger = logging.getLogger(f"[Local Inference Engine Rank {engine_id}]") - # Create the local engine via backend - engine_args = engine_args or {} - self.logger.info(f"Creating local inference engine with args: {engine_args}") - self.engine = self.backend.create_engine(engine_args) - self.logger.info("Local inference engine created successfully!") - # Initialize thread pool for non-blocking weight updates self.executor = ThreadPoolExecutor(max_workers=1) @@ -202,11 +216,12 @@ def initialize( def destroy(self): """Destroy the engine and clean up resources.""" - self.workflow_executor.destroy() - if self.engine is not None: - self.backend.destroy(self.engine) - self.engine = None - self.executor.shutdown() + if getattr(self, "workflow_executor"): + self.workflow_executor.destroy() + self.workflow_executor = None + if getattr(self, "executor"): + self.executor.shutdown() + self.executor = None def set_version(self, version: int): """Set the current weight version.""" @@ -275,12 +290,11 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: and len(accumulated_output_tokens) < gconfig.max_new_tokens ): # Handle rollout interruption - while self.workflow_executor.paused.is_set(): + while self.workflow_executor.is_paused(): await asyncio.sleep(0.5) # Call backend async_generation response = await self.backend.async_generation(self.engine, req) - # Extract result output_tokens = response.output_tokens output_logprobs = response.output_logprobs @@ -456,7 +470,7 @@ def _update_weights_from_disk_sync(self, meta: WeightUpdateMeta) -> float: update_name = names.update_weights_from_disk( self.config.experiment_name, self.config.trial_name, - meta.model_version, + str(self.get_version()), ) save_timestamp = float(name_resolve.wait(update_name, timeout=120)) load_timestamp = time.time() @@ -472,7 +486,6 @@ def _update_weights_from_disk_sync(self, meta: WeightUpdateMeta) -> float: self.logger.info( f"Loading weights done in {(time.time() - load_timestamp) * 1000:.2f} ms" ) - self.set_version(meta.model_version) return load_timestamp - save_timestamp @@ -594,3 +607,21 @@ def pause(self): def resume(self): """Resume request submission for async rollout.""" return self.workflow_executor.resume() + + def pause_generation(self): + """Pause request submission for async rollout.""" + try: + self.backend.pause_generation() + except NotImplementedError: + self.logger.warning("Backend does not support pause operation") + + # The above http request may require some time to be scheduled and executed. + # The following line waits until all requests are indeed dropped. + time.sleep(self.config.pause_grace_period) + + def continue_generation(self): + """Resume request submission for async rollout.""" + try: + self.backend.continue_generation() + except NotImplementedError: + self.logger.warning("Backend does not support resume operation") diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py index 7eae12c59..44f82f52b 100644 --- a/areal/core/remote_inf_engine.py +++ b/areal/core/remote_inf_engine.py @@ -229,6 +229,16 @@ def __init__( self.workflow_executor: WorkflowExecutor + def configure(self, config): + self.config = config + + def create_engine(self, engine_args): + # remote inference engine does not need to create an engine + return + + def destroy_engine(self): + return + def _wait_for_server(self, address): """Wait for a server to become healthy.""" base_url = f"http://{address}" @@ -323,8 +333,12 @@ def initialize( def destroy(self): """Destroy the engine and clean up resources.""" - self.workflow_executor.destroy() - self.executor.shutdown() + if getattr(self, "workflow_executor"): + self.workflow_executor.destroy() + self.workflow_executor = None + if getattr(self, "executor"): + self.executor.shutdown() + self.executor = None def set_version(self, version): """Set the current weight version.""" diff --git a/areal/engine/sglang_local.py b/areal/engine/sglang_local.py index 0b28608be..ee08bbd08 100644 --- a/areal/engine/sglang_local.py +++ b/areal/engine/sglang_local.py @@ -223,10 +223,19 @@ def __init__(self, config: InferenceEngineConfig): # Pure composition - create internal engine with SGLang backend self._engine = LocalInfEngine(config, SGLangLocalBackend()) + def configure(self, config): + self.config = config + self._engine.configure(config) + + def create_engine(self, engine_args): + return self._engine.create_engine(engine_args) + + def destroy_engine(self): + self._engine.destroy_engine() + def initialize( self, engine_id: str | None = None, - engine_args: dict[str, Any] | None = None, train_data_parallel_size: int | None = None, ): """Initialize the engine by creating the local SGLang engine. @@ -240,7 +249,7 @@ def initialize( train_data_parallel_size : int | None Data parallel size of the training engine """ - return self._engine.initialize(engine_id, engine_args, train_data_parallel_size) + return self._engine.initialize(engine_id, train_data_parallel_size) def destroy(self): """Destroy the engine and clean up resources.""" diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 02c86b9bb..8f4f39d90 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from concurrent.futures import Future -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Optional from torchdata.stateful_dataloader import StatefulDataLoader @@ -56,7 +57,7 @@ def build_generation_request( return HttpRequest(endpoint="/generate", payload=payload) def parse_generation_response( - self, response: Dict[str, Any] + self, response: dict[str, Any] ) -> HttpGenerationResult: """Parse SGLang generation response.""" meta_info = response["meta_info"] @@ -119,7 +120,7 @@ def build_disk_weight_update_requests( ) def build_distributed_weight_update_requests( - self, meta: WeightUpdateMeta, param_specs: List[ParamSpec] + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ) -> WeightUpdateRequests: """Build SGLang distributed weight update requests.""" return WeightUpdateRequests( @@ -187,10 +188,20 @@ def __init__(self, config: InferenceEngineConfig): # Pure composition - create internal engine with SGLang backend self._engine = RemoteInfEngine(config, SGLangBackend()) + def configure(self, config): + self.config = config + self._engine.configure(config) + + def create_engine(self, *args, **kwargs): + return self._engine.create_engine(*args, **kwargs) + + def destroy_engine(self, *args, **kwargs): + return self._engine.destroy_engine(*args, **kwargs) + def initialize( self, - engine_id: Optional[str] = None, - addr: str | List[str] | None = None, + engine_id: str | None = None, + addr: str | list[str] | None = None, train_data_parallel_size: int | None = None, ): """Initialize the engine by discovering and connecting to servers.""" @@ -217,7 +228,7 @@ def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: return self._engine.init_weights_update_group(meta) def update_weights_from_distributed( - self, meta: WeightUpdateMeta, param_specs: List[ParamSpec] + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ) -> Future[None]: """Update weights from distributed memory.""" return self._engine.update_weights_from_distributed(meta, param_specs) @@ -228,25 +239,25 @@ def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: def submit( self, - data: Dict[str, Any], - workflow: Optional[RolloutWorkflow] = None, - workflow_builder: Optional[Callable] = None, + data: dict[str, Any], + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, ) -> None: """Submit a request to the inference engine.""" return self._engine.submit(data, workflow, workflow_builder, should_accept) - def wait(self, count: int, timeout: float | None = None) -> Dict[str, Any]: + def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: """Wait for a specified number of requests to complete.""" return self._engine.wait(count, timeout) def rollout_batch( self, - data: List[Dict[str, Any]], + data: list[dict[str, Any]], workflow: Optional["RolloutWorkflow"] = None, - workflow_builder: Optional[Callable] = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Submit a batch of requests and wait for results.""" return self._engine.rollout_batch( data, workflow, workflow_builder, should_accept @@ -255,8 +266,8 @@ def rollout_batch( def prepare_batch( self, dataloader: StatefulDataLoader, - workflow: Optional[RolloutWorkflow] = None, - workflow_builder: Optional[Callable] = None, + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, ): """Asynchronously submit and wait until a full batch is ready.""" diff --git a/areal/engine/vllm_local.py b/areal/engine/vllm_local.py index e8f7a3c8e..18908cd86 100644 --- a/areal/engine/vllm_local.py +++ b/areal/engine/vllm_local.py @@ -1,3 +1,4 @@ +import asyncio import time import uuid from collections.abc import Callable @@ -16,6 +17,7 @@ ) from areal.api.workflow_api import RolloutWorkflow from areal.core.local_inf_engine import LocalInfEngine +from areal.platforms import current_platform class VLLMLocalBackend: @@ -37,9 +39,13 @@ def create_engine(self, engine_args: dict[str, Any]) -> Any: Any The created vLLM AsyncLLMEngine instance """ - from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm import AsyncEngineArgs, AsyncLLMEngine - engine = AsyncLLMEngine.from_engine_args(**engine_args) + engine_args.pop("host", None) + engine_args.pop("port", None) + engine_args.pop("uvicorn_log_level", None) + + engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) return engine async def async_generation(self, engine: Any, req: ModelRequest) -> ModelResponse: @@ -79,20 +85,21 @@ async def async_generation(self, engine: Any, req: ModelRequest) -> ModelRespons request_id = uuid.uuid4().hex # Call vLLM's generate method which returns an async generator + from vllm.inputs.data import TokensPrompt + results_generator = engine.generate( - prompt=None, + prompt=TokensPrompt(prompt_token_ids=req.input_ids), sampling_params=sampling_params, request_id=request_id, - prompt_token_ids=req.input_ids, ) # Iterate through the generator to get the final result - final_output = None + final_output = None # RequestOutput async for request_output in results_generator: final_output = request_output # Parse response - if final_output is None or len(final_output.outputs) == 0: + if final_output is None: latency = time.perf_counter() - start_time return ModelResponse( input_tokens=req.input_ids, @@ -108,6 +115,7 @@ async def async_generation(self, engine: Any, req: ModelRequest) -> ModelRespons ) # Extract first completion output + assert len(final_output.outputs) == 1 completion_output = final_output.outputs[0] stop_reason = completion_output.finish_reason @@ -116,16 +124,8 @@ async def async_generation(self, engine: Any, req: ModelRequest) -> ModelRespons # Extract logprobs - vLLM returns logprobs as a list of dicts output_logprobs = [] - if completion_output.logprobs: - for token_logprobs in completion_output.logprobs: - if token_logprobs: - # Get logprob for the actual selected token - # token_logprobs is a dict mapping token_id to Logprob object - # We need to find the logprob for the token that was selected - max_logprob = max(token_logprobs.values(), key=lambda x: x.logprob) - output_logprobs.append(max_logprob.logprob) - else: - output_logprobs.append(0.0) + for token_logprobs, token_id in zip(completion_output.logprobs, output_tokens): + output_logprobs.append(token_logprobs[token_id].logprob) latency = time.perf_counter() - start_time @@ -152,12 +152,11 @@ def update_weight_disk(self, engine: Any, model_path: str) -> None: model_path : str Path to the model weights on disk """ - # vLLM doesn't support updating weights from disk - # Typically requires creating a new engine - raise NotImplementedError( - "vLLM does not support updating weights from disk. " - "Please create a new engine instance with the new weights." + loop = asyncio.new_event_loop() + loop.run_until_complete( + engine.collective_rpc("areal_injected_update_weight", model_path) ) + return None def update_weight_xccl( self, @@ -176,11 +175,20 @@ def update_weight_xccl( param_specs : List[ParamSpec] Specifications for parameters to be updated """ - # vLLM doesn't support distributed weight updates in the same way - raise NotImplementedError( - "vLLM does not support distributed weight updates via NCCL/XCCL. " - "Please use disk-based updates or create a new engine instance." + loop = asyncio.new_event_loop() + task = engine.collective_rpc( + "set_weight_meta", + args=( + [pspec.name for pspec in param_specs], + [pspec.dtype for pspec in param_specs], + [pspec.shape for pspec in param_specs], + ), + ) + loop.run_until_complete(task) + loop.run_until_complete( + engine.collective_rpc("areal_injected_update_weight_xccl") ) + return None def init_update_weight_group( self, engine: Any, meta: WeightUpdateMeta, rank_offset: int @@ -196,10 +204,20 @@ def init_update_weight_group( rank_offset : int Rank offset for this engine in the communication group """ - # vLLM doesn't support initializing weight update groups - raise NotImplementedError( - "vLLM does not support weight update communication groups." + task = engine.collective_rpc( + "init_update_weight_group", + args=( + meta.nccl_master_address, + str(meta.nccl_master_port), + rank_offset, + meta.alloc_mode.gen.world_size + 1, + current_platform.communication_backend, + meta.nccl_group_name, + ), ) + loop = asyncio.new_event_loop() + loop.run_until_complete(task) + return None def destroy(self, engine: Any) -> None: """Destroy the engine and release resources. @@ -214,6 +232,12 @@ def destroy(self, engine: Any) -> None: if hasattr(engine, "shutdown"): engine.shutdown() + def pause_generation(self): + raise NotImplementedError() + + def continue_generation(self): + raise NotImplementedError() + class LocalvLLMEngine(InferenceEngine): """vLLM local inference engine. @@ -235,10 +259,19 @@ def __init__(self, config: InferenceEngineConfig): # Pure composition - create internal engine with vLLM backend self._engine = LocalInfEngine(config, VLLMLocalBackend()) + def configure(self, config): + self.config = config + self._engine.configure(config) + + def create_engine(self, engine_args): + return self._engine.create_engine(engine_args) + + def destroy_engine(self): + self._engine.destroy_engine() + def initialize( self, engine_id: str | None = None, - engine_args: dict[str, Any] | None = None, train_data_parallel_size: int | None = None, ): """Initialize the engine by creating the local vLLM engine. @@ -252,7 +285,7 @@ def initialize( train_data_parallel_size : int | None Data parallel size of the training engine """ - return self._engine.initialize(engine_id, engine_args, train_data_parallel_size) + return self._engine.initialize(engine_id, train_data_parallel_size) def destroy(self): """Destroy the engine and clean up resources.""" diff --git a/areal/engine/vllm_remote.py b/areal/engine/vllm_remote.py index 64c3176f2..6f31249ae 100644 --- a/areal/engine/vllm_remote.py +++ b/areal/engine/vllm_remote.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from concurrent.futures import Future -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Optional from torchdata.stateful_dataloader import StatefulDataLoader @@ -47,7 +48,7 @@ def build_generation_request( return HttpRequest(endpoint="/v1/completions", payload=payload) def parse_generation_response( - self, response: Dict[str, Any] + self, response: dict[str, Any] ) -> HttpGenerationResult: """Parse vLLM generation response.""" meta_info = response["choices"][0] @@ -87,7 +88,7 @@ def build_disk_weight_update_requests( ) def build_distributed_weight_update_requests( - self, meta: WeightUpdateMeta, param_specs: List[ParamSpec] + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ) -> WeightUpdateRequests: """Build vLLM distributed weight update requests.""" # vLLM uses two-step process: set metadata, then update @@ -160,10 +161,20 @@ def __init__(self, config: InferenceEngineConfig): # Pure composition - create internal engine with vLLM backend self._engine = RemoteInfEngine(config, VLLMBackend()) + def configure(self, config): + self.config = config + self._engine.configure(config) + + def create_engine(self, *args, **kwargs): + return self._engine.create_engine(*args, **kwargs) + + def destroy_engine(self, *args, **kwargs): + return self._engine.destroy_engine(*args, **kwargs) + def initialize( self, - engine_id: Optional[str] = None, - addr: str | List[str] | None = None, + engine_id: str | None = None, + addr: str | list[str] | None = None, train_data_parallel_size: int | None = None, ): """Initialize the engine by discovering and connecting to servers.""" @@ -190,7 +201,7 @@ def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: return self._engine.init_weights_update_group(meta) def update_weights_from_distributed( - self, meta: WeightUpdateMeta, param_specs: List[ParamSpec] + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ) -> Future[None]: """Update weights from distributed memory.""" return self._engine.update_weights_from_distributed(meta, param_specs) @@ -201,25 +212,25 @@ def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: def submit( self, - data: Dict[str, Any], - workflow: Optional[RolloutWorkflow] = None, - workflow_builder: Optional[Callable] = None, + data: dict[str, Any], + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, ) -> None: """Submit a request to the inference engine.""" return self._engine.submit(data, workflow, workflow_builder, should_accept) - def wait(self, count: int, timeout: float | None = None) -> Dict[str, Any]: + def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: """Wait for a specified number of requests to complete.""" return self._engine.wait(count, timeout) def rollout_batch( self, - data: List[Dict[str, Any]], + data: list[dict[str, Any]], workflow: Optional["RolloutWorkflow"] = None, - workflow_builder: Optional[Callable] = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Submit a batch of requests and wait for results.""" return self._engine.rollout_batch( data, workflow, workflow_builder, should_accept @@ -228,8 +239,8 @@ def rollout_batch( def prepare_batch( self, dataloader: StatefulDataLoader, - workflow: Optional[RolloutWorkflow] = None, - workflow_builder: Optional[Callable] = None, + workflow: RolloutWorkflow | None = None, + workflow_builder: Callable | None = None, should_accept: Callable | None = None, ): """Asynchronously submit and wait until a full batch is ready.""" diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index d16031084..6248bc7c2 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -46,12 +46,18 @@ def _dummy_reward_fn(*args, **kwargs): @pytest.fixture( params=[ - ("vllm", "remote"), - ("vllm", "local"), + # ("vllm", "remote"), + # ("vllm", "local"), ("sglang", "remote"), - ("sglang", "local"), + # ("sglang", "local"), ], - ids=["vllm-remote", "vllm-local", "sglang-remote", "sglang-local"], + ids=[ + # "vllm-remote", + # "vllm-local", + "sglang-remote", + # "sglang-local", + ], + scope="module", ) def inference_engine(request): """Unified fixture that provides any inference engine (vLLM/SGLang, local/remote). @@ -74,42 +80,50 @@ def inference_engine(request): seeding.set_random_seed(1, expr_name) + port, dist_port = network.find_free_ports(2) + host = network.gethostip() + sglang_config = SGLangConfig( + skip_tokenizer_init=True, + model_path=MODEL_PATH, + mem_fraction_static=0.1, + ) + sglang_args = SGLangConfig.build_args( + sglang_config=sglang_config, + tp_size=1, + base_gpu_id=0, + host=host, + port=port, + dist_init_addr=f"{host}:{dist_port}", + ) + vllm_config = vLLMConfig( + skip_tokenizer_init=False, + model=MODEL_PATH, + gpu_memory_utilization=0.1, + ) + vllm_args = vLLMConfig.build_args( + vllm_config=vllm_config, + tp_size=1, + host=host, + port=port, + ) + config = InferenceEngineConfig( + experiment_name=expr_name, + trial_name=trial_name, + ) + # Initialize engine based on backend and mode if mode == "remote": # Launch server - port, dist_port = network.find_free_ports(2) - host = network.gethostip() if backend == "vllm": from areal.engine.vllm_remote import RemotevLLMEngine - cmd = vLLMConfig.build_cmd( - vllm_config=vLLMConfig( - skip_tokenizer_init=False, - model=MODEL_PATH, - gpu_memory_utilization=0.1, - ), - host=host, - port=port, - tp_size=1, - dist_init_addr=f"{host}:{dist_port}", - ) + cmd = vLLMConfig.build_cmd_from_args(vllm_args) engine_class = RemotevLLMEngine else: # sglang from areal.engine.sglang_remote import RemoteSGLangEngine - cmd = SGLangConfig.build_cmd( - sglang_config=SGLangConfig( - skip_tokenizer_init=True, - model_path=MODEL_PATH, - mem_fraction_static=0.3, - ), - host=host, - port=port, - tp_size=1, - base_gpu_id=0, - dist_init_addr=f"{host}:{dist_port}", - ) + cmd = SGLangConfig.build_cmd_from_args(sglang_args) engine_class = RemoteSGLangEngine # Launch process @@ -133,8 +147,10 @@ def inference_engine(request): # Set environment for remote engine os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host}:{port}" + engine = engine_class(config) + yield { - "engine_class": engine_class, + "engine": engine, "backend": backend, "mode": mode, "expr_name": expr_name, @@ -148,36 +164,27 @@ def inference_engine(request): else: # local if backend == "vllm": - from vllm import EngineArgs - from areal.engine.vllm_local import LocalvLLMEngine - engine_args = EngineArgs( - model=MODEL_PATH, - tensor_parallel_size=1, - gpu_memory_utilization=0.3, - trust_remote_code=True, - ) + engine_args = vllm_args engine_class = LocalvLLMEngine else: # sglang from areal.engine.sglang_local import LocalSGLangEngine - engine_args = { - "model_path": MODEL_PATH, - "tp_size": 1, - "mem_fraction_static": 0.3, - "skip_tokenizer_init": True, - } + engine_args = sglang_args engine_class = LocalSGLangEngine + engine = engine_class(config) + engine.create_engine(engine_args=engine_args) + yield { - "engine_class": engine_class, + "engine": engine, "backend": backend, "mode": mode, "expr_name": expr_name, "trial_name": trial_name, - "engine_args": engine_args, } + engine.destroy_engine() # ============================================================================ @@ -195,15 +202,12 @@ def test_rollout(inference_engine, n_samples): trial_name=inference_engine["trial_name"], max_concurrent_rollouts=2, consumer_batch_size=2, + enable_rollout_tracing=True, ) - engine = inference_engine["engine_class"](config) - - # Initialize based on mode - if inference_engine["mode"] == "remote": - engine.initialize() - else: # local - engine.initialize(engine_args=inference_engine["engine_args"]) + engine = inference_engine["engine"] + engine.configure(config) + engine.initialize() gconfig = GenerationHyperparameters( max_new_tokens=16, greedy=False, n_samples=n_samples @@ -245,13 +249,9 @@ def test_staleness_control(inference_engine, bs, ofp, n_samples): ), ) - engine = inference_engine["engine_class"](config) - - # Initialize based on mode - if inference_engine["mode"] == "remote": - engine.initialize() - else: # local - engine.initialize(engine_args=inference_engine["engine_args"]) + engine = inference_engine["engine"] + engine.configure(config) + engine.initialize() gconfig = GenerationHyperparameters( max_new_tokens=2, greedy=False, n_samples=n_samples @@ -295,7 +295,6 @@ def test_staleness_control(inference_engine, bs, ofp, n_samples): results = engine.wait(count=bs * 2, timeout=5) assert results["attention_mask"].shape[0] == bs * 2 * n_samples - # exit engine.destroy() @@ -334,18 +333,8 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, inference_engine name_resolve.reconfigure(name_resolve_config) # initialize inference engine - config = InferenceEngineConfig( - experiment_name=inference_engine["expr_name"], - trial_name=inference_engine["trial_name"], - ) - inf_engine = inference_engine["engine_class"](config) - - # Initialize based on mode - if inference_engine["mode"] == "remote": - inf_engine.initialize() - else: # local - inf_engine.initialize(engine_args=inference_engine["engine_args"]) - + inf_engine = inference_engine["engine"] + inf_engine.initialize() inf_engine.set_version(100) # test update weights diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 176103dc6..ac15e5ab5 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -712,7 +712,7 @@ Configuration for SwanLab experiment tracking and monitoring. | --------- | -------------- | ------------ | ----------- | | `project` | string \| None | `None` | - | | `name` | string \| None | `None` | - | -| `config` | `Dict` \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | | `logdir` | string \| None | `None` | - | | `mode` | string \| None | `"disabled"` | - | | `api_key` | string \| None | `None` | - | @@ -745,7 +745,7 @@ Configuration for Weights & Biases experiment tracking. | `group` | string \| None | `None` | - | | `notes` | string \| None | `None` | - | | `tags` | list of string \| None | `None` | - | -| `config` | `Dict` \| None | `None` | - | +| `config` | `dict` \| None | `None` | - | | `id_suffix` | string \| None | `"train"` | - | (section-distributed-data-parallel)= @@ -808,6 +808,6 @@ Configuration for worker scheduling. Used in the single-controller mode. Experim | `endpoint` | string | `"http://localhost:8081"` | - | | `deploy_mode` | string | `"separation"` | - | | `functioncall_service_domain` | string | `"http://localhost:8080"` | - | -| `reward_functioncall_config` | `Dict` | **Required** | - | +| `reward_functioncall_config` | `dict` | **Required** | - | | `reward_model_path` | string | `""` | - | | `reward_model_service_url` | string | `"http://localhost:30000/classify"` | - | From 7a2f6a9574d97fc3a9136428bdde937485e09cd5 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Tue, 28 Oct 2025 20:01:46 +0800 Subject: [PATCH 07/52] . --- areal/api/cli_args.py | 2 +- areal/core/async_task_runner.py | 17 +++++++++++++---- areal/core/workflow_executor.py | 6 ++++-- areal/tests/test_inference_engines.py | 1 - docs/cli_reference.md | 2 +- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 8379f0060..7a4de3cfb 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -805,7 +805,7 @@ class InferenceEngineConfig: ) queue_size: None | int = field( default=None, - metadata={"help": "Input/Output queue size for async rollout."}, + metadata={"help": "(Deprecated) Input/Output queue size for async rollout."}, ) consumer_batch_size: int = field( default=1, diff --git a/areal/core/async_task_runner.py b/areal/core/async_task_runner.py index f9cb2f63d..9c6d3238c 100644 --- a/areal/core/async_task_runner.py +++ b/areal/core/async_task_runner.py @@ -170,7 +170,7 @@ def __init__( self.max_queue_size = max_queue_size self.poll_wait_time = poll_wait_time self.poll_sleep_time = poll_sleep_time - self.enable_tracing = enable_tracing + self._enable_tracing = enable_tracing # Thread control self.exiting = threading.Event() @@ -188,13 +188,22 @@ def __init__( self.result_cache: list[_TimedResult[T]] = [] # Thread exception handling - self._thread_exception_lock = threading.Lock() + self._lock = threading.Lock() self._thread_exception: Exception | None = None # Will be set in initialize() self.logger = None self.thread: threading.Thread | None = None + def set_enable_tracing(self, enabled: bool): + with self._lock: + self._enable_tracing = enabled + + @property + def enable_tracing(self): + with self._lock: + return self._enable_tracing + def initialize(self, logger=None): """Initialize and start the background thread. @@ -231,7 +240,7 @@ def _check_thread_health(self): RuntimeError If the background thread has died due to an exception. """ - with self._thread_exception_lock: + with self._lock: if self._thread_exception is not None: raise RuntimeError( "AsyncTaskRunner thread has died due to an exception. " @@ -247,7 +256,7 @@ def _run_thread(self): uvloop.run(self._run_async_loop()) except Exception as e: # Store exception for thread-safe access - with self._thread_exception_lock: + with self._lock: self._thread_exception = e if self.logger: self.logger.error( diff --git a/areal/core/workflow_executor.py b/areal/core/workflow_executor.py index d658642a7..8de1e1192 100644 --- a/areal/core/workflow_executor.py +++ b/areal/core/workflow_executor.py @@ -215,6 +215,9 @@ class _RolloutTaskInput: should_accept: Callable | None = None +TASK_RUNNER_MAX_QSIZE = 4096 + + class WorkflowExecutor: """Executor for asynchronous workflow-based rollout generation. @@ -268,9 +271,8 @@ def __init__( self.staleness_manager = staleness_manager # Create the generic async task runner - qsize = config.queue_size or self.max_concurrent_rollouts * 16 self.runner = AsyncTaskRunner[dict[str, Any] | None]( - max_queue_size=qsize, + max_queue_size=TASK_RUNNER_MAX_QSIZE, enable_tracing=config.enable_rollout_tracing, ) diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index 6248bc7c2..649cdec7d 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -57,7 +57,6 @@ def _dummy_reward_fn(*args, **kwargs): "sglang-remote", # "sglang-local", ], - scope="module", ) def inference_engine(request): """Unified fixture that provides any inference engine (vLLM/SGLang, local/remote). diff --git a/docs/cli_reference.md b/docs/cli_reference.md index ac15e5ab5..585169c9a 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -457,7 +457,7 @@ Configuration for inference servers, including offpolicyness control. | `experiment_name` | string \| None | `None` | - | | `trial_name` | string \| None | `None` | - | | `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `queue_size` | integer \| None | `None` | (Deprecated) Input/Output queue size for async rollout. | | `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | | `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | | `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | From 46ee1507db08f3f47ab32c3b5bdbecad4b0da184 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Tue, 28 Oct 2025 20:33:39 +0800 Subject: [PATCH 08/52] add local scheduler --- areal/api/scheduler_api.py | 190 +++- areal/scheduler/exceptions.py | 117 ++ areal/scheduler/local_scheduler.py | 1035 ++++++++++++++++++ areal/scheduler/rpc/rpc_client.py | 137 --- areal/scheduler/rpc/rpc_server.py | 321 ++++-- areal/tests/test_local_scheduler.py | 1562 +++++++++++++++++++++++++++ 6 files changed, 3094 insertions(+), 268 deletions(-) create mode 100644 areal/scheduler/exceptions.py create mode 100644 areal/scheduler/local_scheduler.py delete mode 100644 areal/scheduler/rpc/rpc_client.py create mode 100644 areal/tests/test_local_scheduler.py diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index f7e9fb941..226fa2e28 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -1,69 +1,225 @@ import abc from dataclasses import dataclass, field -from typing import Dict, List +from typing import Any @dataclass class Worker: + """ + Represents a worker process in the distributed system. + + Attributes: + id: Unique identifier for the worker (e.g., "rollout/0", "actor/1"). + ip: IP address where the worker is running. + ports: List of port numbers (as strings) allocated to this worker for RPC communication. + """ + id: str ip: str - ports: List[str] = field(default_factory=list) + ports: list[str] = field(default_factory=list) @dataclass class ContainerSpec: + """ + Resource specification for a worker container/process. + + Attributes: + cpu: Number of CPU cores to allocate. + gpu: Number of GPUs to allocate. + mem: Memory in MB to allocate. + container_image: Docker container image (for containerized deployments). + cmd: Command to execute when starting the worker. + env_vars: Environment variables to set for the worker process. + port_count: Number of ports to allocate for this worker. + """ + cpu: int = 0 gpu: int = 0 mem: int = 0 container_image: str = "" cmd: str = "" - env_vars: Dict[str, str] = field(default_factory=dict) + env_vars: dict[str, str] = field(default_factory=dict) port_count: int = 2 @dataclass class ScheduleStrategy: + """ + Scheduling strategy configuration. + + Supported strategies: + - "new": Allocate new GPUs using round-robin (default). + - "colocate": Schedule workers on the same GPUs as another role. + + Attributes: + type: Type of scheduling strategy ("new" or "colocate"). + uid: For "colocate" strategy, the role name to colocate with (e.g., "actor"). + For "new" strategy, this field is optional. + """ + type: str = "" uid: str = "" @dataclass class SchedulingConfig: + """ + Complete configuration for scheduling a group of workers. + + Attributes: + replicas: Number of worker replicas to create. + specs: List of container specifications, one per replica (or a single spec for all). + schedule_strategy: Optional scheduling strategy to use. + role: Role name for this group of workers (e.g., "rollout", "actor", "critic"). + """ + replicas: int = 0 - specs: List[ContainerSpec] = field(default_factory=list) + specs: list[ContainerSpec] = field(default_factory=list) schedule_strategy: ScheduleStrategy | None = None role: str = "" class Scheduler(abc.ABC): - def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str: + """ + Abstract base class for schedulers that manage distributed worker processes. + + A scheduler is responsible for: + - Creating and managing worker processes/containers. + - Allocating resources (GPUs, ports, memory). + - Creating and managing engine instances on workers. + - Facilitating RPC calls to engine methods. + """ + + @abc.abstractmethod + def create_workers( + self, role: str, scheduler_config: SchedulingConfig, *args, **kwargs + ) -> list[str]: + """ + Create and start worker processes for a specific role. + + Args: + role: Role name for this group of workers (e.g., "rollout", "actor", "critic"). + scheduler_config: Configuration specifying replicas, resources, and scheduling strategy. + *args: Additional positional arguments (implementation-specific). + **kwargs: Additional keyword arguments (implementation-specific). + + Returns: + List of worker IDs created (e.g., ["rollout/0", "rollout/1"]). + + Raises: + WorkerCreationError: If worker creation fails. + ValueError: If scheduler_config is invalid. """ - Start workers, return job id + raise NotImplementedError() + + @abc.abstractmethod + def get_workers(self, role: str, timeout: int | None = None) -> list[Worker]: + """ + Wait for workers to be ready and return their information. + + This method blocks until all workers for the specified role are ready + to accept RPC requests, or until the timeout is reached. + + Args: + role: Role name to query (e.g., "rollout", "actor"). + timeout: Maximum time to wait in seconds. None means use the default timeout. + + Returns: + List of Worker objects containing worker ID, IP address, and allocated ports. + + Raises: + WorkerNotFoundError: If no workers exist for the specified role. + WorkerFailedError: If any worker process has failed. + WorkerTimeoutError: If timeout is exceeded while waiting for workers. """ + raise NotImplementedError() - def get_workers(self, worker_key, timeout=None) -> List[Worker]: + @abc.abstractmethod + def delete_workers(self, role: str | None = None): """ - Wait and return worker list, including scheduling results such as ip and engine ports - (worker id, ip, ports) + Stop and clean up worker processes. + + Args: + role: Specific role to delete. If None, all workers are deleted. + + Raises: + WorkerNotFoundError: If the specified role doesn't exist. + + Note: + This method should gracefully terminate workers and clean up resources. + It should not raise an exception if workers have already stopped. """ raise NotImplementedError() - def delete_workers(self): - """stop all workers + @abc.abstractmethod + async def create_engine(self, worker_id: str, engine: str, *args, **kwargs) -> Any: + """ + Create an engine instance on a remote worker. + + The engine parameter is a string import path (e.g., "areal.engine.ppo.actor.FSDPPPOActor") + that will be dynamically imported and instantiated on the worker. + + Args: + worker_id: ID of the worker to create the engine on (e.g., "rollout/0"). + engine: Import path to the engine class (e.g., "areal.engine.ppo.actor.FSDPPPOActor"). + *args: Positional arguments passed to engine initialization. + **kwargs: Keyword arguments passed to engine initialization. - Raises exception if there is no such job, but passes if the job - has stopped either successfully or not. + Returns: + Result from engine initialization. + + Raises: + WorkerNotFoundError: If the specified worker doesn't exist. + WorkerFailedError: If the worker process has failed. + EngineCreationError: If engine creation or initialization fails. """ raise NotImplementedError() - async def create_engine(self, worker_id, engine_obj, *args, **kwargs): + @abc.abstractmethod + def call_engine(self, worker_id: str, method: str, *args, **kwargs) -> Any: """ - Create engine instance remotely + Call a method on an engine instance running on a worker (data plane operation). + + This is the synchronous version. Use `async_call_engine` for async operations. + + Args: + worker_id: ID of the worker hosting the engine (e.g., "rollout/0"). + method: Name of the method to call on the engine. + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + Result from the engine method call. + + Raises: + WorkerNotFoundError: If the specified worker doesn't exist. + WorkerFailedError: If the worker process has failed. + EngineCallError: If the method call fails. """ raise NotImplementedError() - def call_engine(self, worker_id, method, *args, **kwargs): + @abc.abstractmethod + async def async_call_engine( + self, worker_id: str, method: str, *args, **kwargs + ) -> Any: """ - Data plane call + Async version of call_engine for calling engine methods asynchronously. + + This is useful for concurrent operations or when integrating with async frameworks. + + Args: + worker_id: ID of the worker hosting the engine (e.g., "rollout/0"). + method: Name of the method to call on the engine. + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + Result from the engine method call. + + Raises: + WorkerNotFoundError: If the specified worker doesn't exist. + WorkerFailedError: If the worker process has failed. + EngineCallError: If the method call fails. """ raise NotImplementedError() diff --git a/areal/scheduler/exceptions.py b/areal/scheduler/exceptions.py new file mode 100644 index 000000000..93a007b64 --- /dev/null +++ b/areal/scheduler/exceptions.py @@ -0,0 +1,117 @@ +"""Custom exceptions for the scheduler module.""" + + +class SchedulerError(Exception): + """Base exception for all scheduler-related errors.""" + + pass + + +class WorkerCreationError(SchedulerError): + """Raised when worker creation fails during subprocess spawn or initialization.""" + + def __init__(self, worker_key: str, reason: str, details: str = ""): + self.worker_key = worker_key + self.reason = reason + self.details = details + message = f"Failed to create worker '{worker_key}': {reason}" + if details: + message += f"\nDetails: {details}" + super().__init__(message) + + +class WorkerFailedError(SchedulerError): + """Raised when a worker process fails or exits unexpectedly.""" + + def __init__(self, worker_id: str, exit_code: int, stderr: str = ""): + self.worker_id = worker_id + self.exit_code = exit_code + self.stderr = stderr + message = f"Worker '{worker_id}' failed with exit code {exit_code}" + if stderr: + message += f"\nStderr output:\n{stderr}" + super().__init__(message) + + +class WorkerNotFoundError(SchedulerError): + """Raised when attempting to access a worker that doesn't exist.""" + + def __init__(self, worker_id: str): + self.worker_id = worker_id + super().__init__(f"Worker '{worker_id}' not found") + + +class EngineCreationError(SchedulerError): + """Raised when engine creation fails on a worker.""" + + def __init__(self, worker_id: str, reason: str, status_code: int = None): + self.worker_id = worker_id + self.reason = reason + self.status_code = status_code + message = f"Failed to create engine on worker '{worker_id}': {reason}" + if status_code: + message += f" (HTTP {status_code})" + super().__init__(message) + + +class EngineCallError(SchedulerError): + """Raised when calling an engine method fails.""" + + def __init__(self, worker_id: str, method: str, reason: str, attempt: int = 1): + self.worker_id = worker_id + self.method = method + self.reason = reason + self.attempt = attempt + message = f"Failed to call method '{method}' on worker '{worker_id}': {reason}" + if attempt > 1: + message += f" (after {attempt} attempts)" + super().__init__(message) + + +class WorkerTimeoutError(SchedulerError): + """Raised when waiting for a worker exceeds the timeout.""" + + def __init__(self, worker_key: str, timeout: float): + self.worker_key = worker_key + self.timeout = timeout + super().__init__( + f"Timeout waiting for worker '{worker_key}' (waited {timeout}s)" + ) + + +class PortAllocationError(SchedulerError): + """Raised when port allocation fails.""" + + def __init__(self, reason: str): + self.reason = reason + super().__init__(f"Failed to allocate ports: {reason}") + + +class GPUAllocationError(SchedulerError): + """Raised when GPU allocation fails.""" + + def __init__(self, reason: str): + self.reason = reason + super().__init__(f"Failed to allocate GPU resources: {reason}") + + +class RPCConnectionError(SchedulerError): + """Raised when RPC connection to a worker fails.""" + + def __init__(self, worker_id: str, host: str, port: int, reason: str): + self.worker_id = worker_id + self.host = host + self.port = port + self.reason = reason + super().__init__( + f"Failed to connect to worker '{worker_id}' at {host}:{port}: {reason}" + ) + + +class EngineImportError(SchedulerError): + """Raised when importing an engine class fails on the worker.""" + + def __init__(self, import_path: str, reason: str): + self.import_path = import_path + self.reason = reason + super().__init__(f"Failed to import engine '{import_path}': {reason}") diff --git a/areal/scheduler/local_scheduler.py b/areal/scheduler/local_scheduler.py new file mode 100644 index 000000000..693dd47db --- /dev/null +++ b/areal/scheduler/local_scheduler.py @@ -0,0 +1,1035 @@ +"""Local scheduler for managing worker subprocesses on a single GPU node.""" + +import os +import shlex +import subprocess +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import httpx +import orjson +import psutil + +from areal.api.scheduler_api import ContainerSpec, Scheduler, SchedulingConfig, Worker +from areal.scheduler.exceptions import ( + EngineCallError, + EngineCreationError, + EngineImportError, + GPUAllocationError, + PortAllocationError, + RPCConnectionError, + SchedulerError, + WorkerCreationError, + WorkerFailedError, + WorkerNotFoundError, + WorkerTimeoutError, +) +from areal.utils import logging +from areal.utils.network import find_free_ports, gethostip + +logger = logging.getLogger("LocalScheduler") + + +@dataclass +class WorkerInfo: + """Internal tracking information for a worker process.""" + + worker: Worker # Public Worker object with id, ip, ports + process: subprocess.Popen # The subprocess handle + role: str # Worker role (e.g., "rollout", "actor", "critic") + gpu_devices: list[int] # Allocated GPU device IDs + created_at: float # Timestamp when worker was created + log_file: str # Path to stderr log file + env_vars: dict[str, str] = field(default_factory=dict) # Environment variables + + +class LocalScheduler(Scheduler): + """ + Local scheduler that manages worker subprocesses on a single GPU node. + + This scheduler spawns worker processes running RPC servers and manages their lifecycle. + It supports different worker types (rollout, actor, critic) through a unified interface. + + Features: + - Dynamic port allocation + - Round-robin GPU assignment + - Process health monitoring + - Comprehensive error handling + - Graceful cleanup + """ + + def __init__( + self, + gpu_devices: list[int] | None = None, + log_dir: str = "./logs/workers", + startup_timeout: float = 30.0, + health_check_interval: float = 1.0, + ): + """ + Initialize the local scheduler. + + Args: + gpu_devices: List of GPU device IDs to use. If None, uses CUDA_VISIBLE_DEVICES or all GPUs. + log_dir: Directory for worker log files + startup_timeout: Maximum time to wait for worker startup (seconds) + health_check_interval: Interval for health checks (seconds) + """ + self.gpu_devices = gpu_devices or self._detect_gpus() + self.log_dir = Path(log_dir) + self.startup_timeout = startup_timeout + self.health_check_interval = health_check_interval + + # Create log directory + self.log_dir.mkdir(parents=True, exist_ok=True) + + # Track workers by worker_key + self._workers: dict[str, list[WorkerInfo]] = {} + + # GPU allocation counter for round-robin + self._gpu_counter = 0 + + # Track all allocated ports + self._allocated_ports = set() + + # HTTP clients for RPC communication + self._http_client = httpx.Client(timeout=3600.0) # Sync client - 1 hour timeout + self._async_http_client = httpx.AsyncClient(timeout=3600.0) # Async client + + logger.info( + f"LocalScheduler initialized with GPU devices: {self.gpu_devices}, " + f"log directory: {self.log_dir}" + ) + + def _detect_gpus(self) -> list[int]: + """Detect available GPU devices.""" + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + if cuda_visible: + try: + return [int(x) for x in cuda_visible.split(",")] + except ValueError: + logger.warning( + f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using default [0]" + ) + return [0] + # Default to single GPU + return [0] + + def _allocate_gpus(self, num_gpus: int) -> list[int]: + """ + Allocate GPUs using round-robin strategy. + + Args: + num_gpus: Number of GPUs to allocate + + Returns: + List of GPU device IDs + + Raises: + GPUAllocationError: If not enough GPUs available + """ + if num_gpus > len(self.gpu_devices): + raise GPUAllocationError( + f"Requested {num_gpus} GPUs but only {len(self.gpu_devices)} available" + ) + + allocated = [] + for _ in range(num_gpus): + gpu_id = self.gpu_devices[self._gpu_counter % len(self.gpu_devices)] + allocated.append(gpu_id) + self._gpu_counter += 1 + + return allocated + + def _get_colocated_gpus(self, target_role: str, worker_idx: int) -> list[int]: + """ + Get GPU allocation from another role for colocation. + + Args: + target_role: The role to colocate with + worker_idx: Index of the worker to get GPUs from + + Returns: + List of GPU device IDs used by the target worker + + Raises: + WorkerNotFoundError: If target role doesn't exist + ValueError: If worker index is out of range + """ + if target_role not in self._workers: + raise WorkerNotFoundError( + f"Cannot colocate with role '{target_role}' - role not found" + ) + + target_workers = self._workers[target_role] + if worker_idx >= len(target_workers): + raise ValueError( + f"Cannot colocate with {target_role}/{worker_idx} - only {len(target_workers)} workers exist" + ) + + return target_workers[worker_idx].gpu_devices + + def _allocate_ports(self, count: int) -> list[int]: + """ + Allocate free ports. + + Args: + count: Number of ports to allocate + + Returns: + List of allocated port numbers + + Raises: + PortAllocationError: If port allocation fails + """ + try: + # Pass a copy of allocated_ports to avoid reference issues + ports = find_free_ports(count, exclude_ports=set(self._allocated_ports)) + self._allocated_ports.update(ports) + return ports + except ValueError as e: + raise PortAllocationError(str(e)) from e + + def _prepare_worker_specs( + self, role: str, num_workers: int, specs: list[ContainerSpec] | None + ) -> list[ContainerSpec]: + """ + Prepare worker specs for a given number of workers. + + Args: + role: Worker role name + num_workers: Number of workers to create + specs: Optional list of specs + + Returns: + List of ContainerSpec objects (one per worker) + + Raises: + WorkerCreationError: If specs configuration is invalid + """ + if not specs: + # Default spec: 1 GPU, 2 ports + return [ContainerSpec(gpu=1, port_count=2)] * num_workers + + # If a single spec is provided, use it for all workers + if len(specs) == 1: + return [specs[0]] * num_workers + + # If per-worker specs, validate length matches + if len(specs) == num_workers: + return specs + + # Invalid configuration + raise WorkerCreationError( + role, + "Invalid configuration", + f"specs length ({len(specs)}) must be 1 or equal to replicas ({num_workers})", + ) + + def create_workers( + self, role: str, scheduler_config: SchedulingConfig, *args, **kwargs + ) -> list[str]: + """ + Create worker subprocesses. + + Args: + role: Role name for this group of workers (e.g., "rollout", "actor", "critic") + scheduler_config: Scheduling configuration with replicas, specs, and strategy + *args: Additional arguments passed to worker command + **kwargs: Additional keyword arguments + + Returns: + List of worker IDs created (e.g., ["rollout/0", "rollout/1"]) + + Raises: + WorkerCreationError: If worker creation fails + GPUAllocationError: If GPU allocation fails + PortAllocationError: If port allocation fails + """ + if role in self._workers: + raise WorkerCreationError( + role, + "Worker group already exists", + f"Use delete_workers('{role}') first to remove existing workers", + ) + + # Extract configuration + num_workers = scheduler_config.replicas + if num_workers == 0: + raise WorkerCreationError( + role, "Invalid configuration", "replicas must be greater than 0" + ) + + # Prepare worker specs + specs = self._prepare_worker_specs(role, num_workers, scheduler_config.specs) + + # Determine scheduling strategy + strategy = scheduler_config.schedule_strategy + if strategy is None: + strategy_type = "new" + colocate_role = None + else: + strategy_type = strategy.type or "new" + colocate_role = strategy.uid if strategy_type == "colocate" else None + + logger.info( + f"Creating {num_workers} workers for role '{role}' " + f"(strategy: {strategy_type}, colocate_with: {colocate_role})" + ) + + workers = [] + worker_ids = [] + try: + for idx in range(num_workers): + worker_id = f"{role}/{idx}" + spec = specs[idx] + + # Allocate resources based on strategy + try: + # GPU allocation + if strategy_type == "colocate": + if not colocate_role: + raise WorkerCreationError( + role, + "Invalid strategy", + "Colocate strategy requires uid (target role) to be specified", + ) + gpu_devices = self._get_colocated_gpus(colocate_role, idx) + logger.debug( + f"Worker {worker_id} colocated with {colocate_role}/{idx} on GPUs {gpu_devices}" + ) + else: # "new" or default + gpu_devices = self._allocate_gpus(spec.gpu) + logger.debug( + f"Worker {worker_id} allocated new GPUs {gpu_devices}" + ) + + ports = self._allocate_ports(spec.port_count) + except ( + GPUAllocationError, + PortAllocationError, + WorkerNotFoundError, + ValueError, + ) as e: + # Clean up partially created workers + self._cleanup_workers(workers) + raise WorkerCreationError( + role, f"Resource allocation failed for worker {idx}", str(e) + ) from e + + # Prepare environment + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_devices)) + env["WORKER_ID"] = worker_id + + # Merge user-provided environment variables from spec + if spec.env_vars: + env.update(spec.env_vars) + + # Prepare log file + log_file = self.log_dir / f"{worker_id.replace('/', '_')}.log" + + # Build command to start RPC server + if spec.cmd: + # Use custom command from spec + cmd = shlex.split(spec.cmd) + else: + # Default: start RPC server + cmd = [ + sys.executable, + "-m", + "areal.scheduler.rpc.rpc_server", + "--port", + str(ports[0]), # Main RPC port + ] + + # Add any additional arguments + if args: + cmd.extend(args) + + logger.debug(f"Starting worker {worker_id}: {' '.join(cmd)}") + + # Spawn subprocess + try: + with open(log_file, "w") as log_f: + process = subprocess.Popen( + cmd, + env=env, + stdout=log_f, + stderr=subprocess.STDOUT, + start_new_session=True, # Create new process group + ) + except Exception as e: + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"Failed to spawn subprocess for worker {idx}", + str(e), + ) from e + + # Check if process started successfully + time.sleep(0.1) # Brief delay to catch immediate failures + if process.poll() is not None: + stderr = self._read_log_tail(log_file) + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"Worker {worker_id} exited immediately with code {process.returncode}", + stderr, + ) + + # Create worker info + worker = Worker( + id=worker_id, + ip=gethostip(), + ports=[str(p) for p in ports], + ) + + worker_info = WorkerInfo( + worker=worker, + process=process, + role=role, + gpu_devices=gpu_devices, + created_at=time.time(), + log_file=str(log_file), + env_vars=env, + ) + + workers.append(worker_info) + worker_ids.append(worker_id) + logger.info( + f"Worker {worker_id} started (PID: {process.pid}, " + f"GPUs: {gpu_devices}, ports: {ports})" + ) + + # Store workers + self._workers[role] = workers + + logger.info( + f"Successfully created {len(workers)} workers for role '{role}'" + ) + return worker_ids + + except Exception as e: + # Clean up any workers created before the failure + self._cleanup_workers(workers) + if isinstance(e, SchedulerError): + raise + raise WorkerCreationError(role, "Unexpected error", str(e)) from e + + def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: + """ + Get workers and wait for them to be ready. + + Args: + role: Worker role name + timeout: Maximum time to wait for workers to be ready (None = use default) + + Returns: + List of Worker objects + + Raises: + WorkerNotFoundError: If role doesn't exist + WorkerFailedError: If any worker process failed + WorkerTimeoutError: If timeout exceeded waiting for workers + """ + if role not in self._workers: + raise WorkerNotFoundError(role) + + workers = self._workers[role] + timeout = timeout if timeout is not None else self.startup_timeout + + # First check that all processes are still alive + self._check_worker_health(role) + + # Wait for RPC servers to be ready + start_time = time.time() + ready_workers = set() + + while len(ready_workers) < len(workers): + if time.time() - start_time > timeout: + raise WorkerTimeoutError( + role, + timeout, + ) + + for worker_info in workers: + if worker_info.worker.id in ready_workers: + continue + + # Check if process is still alive + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_info.worker.id, + worker_info.process.returncode, + stderr, + ) + + # Check if RPC server is ready + if self._is_worker_ready(worker_info): + ready_workers.add(worker_info.worker.id) + logger.debug(f"Worker {worker_info.worker.id} is ready") + + if len(ready_workers) < len(workers): + time.sleep(self.health_check_interval) + + logger.info(f"All {len(workers)} workers for role '{role}' are ready") + return [w.worker for w in workers] + + def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: + """Check if worker's RPC server is ready via HTTP health check.""" + port = int(worker_info.worker.ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/health" + + try: + response = self._http_client.get(url, timeout=2.0) + return response.status_code == 200 + except Exception: + return False + + def _check_worker_health(self, role: str): + """ + Check health of all workers in a group. + + Raises: + WorkerFailedError: If any worker has failed + """ + if role not in self._workers: + return + + for worker_info in self._workers[role]: + returncode = worker_info.process.poll() + if returncode is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_info.worker.id, + returncode, + stderr, + ) + + def delete_workers(self, role: str | None = None): + """ + Delete workers and clean up resources. + + Args: + role: Specific worker role to delete, or None to delete all + """ + if role is None: + # Delete all workers + roles = list(self._workers.keys()) + for r in roles: + self.delete_workers(r) + return + + if role not in self._workers: + logger.warning(f"Worker role '{role}' not found, skipping deletion") + return + + workers = self._workers[role] + logger.info(f"Deleting {len(workers)} workers for role '{role}'") + + self._cleanup_workers(workers) + + # Remove from tracking + del self._workers[role] + + logger.info(f"Successfully deleted workers for role '{role}'") + + def _cleanup_workers(self, workers: list[WorkerInfo]): + """Clean up worker processes and resources.""" + for worker_info in workers: + try: + # Release ports + for port_str in worker_info.worker.ports: + self._allocated_ports.discard(int(port_str)) + + # Terminate process tree + self._terminate_process_tree(worker_info.process.pid) + + logger.debug(f"Cleaned up worker {worker_info.worker.id}") + except Exception as e: + logger.error( + f"Error cleaning up worker {worker_info.worker.id}: {e}", + exc_info=True, + ) + + def _terminate_process_tree(self, pid: int): + """Terminate a process and all its children.""" + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + + # Try graceful termination first + for child in children: + try: + child.terminate() + except psutil.NoSuchProcess: + pass + + try: + parent.terminate() + except psutil.NoSuchProcess: + return + + # Wait for graceful termination + _, alive = psutil.wait_procs([parent] + children, timeout=3) + + # Force kill remaining processes + for proc in alive: + try: + proc.kill() + except psutil.NoSuchProcess: + pass + + except psutil.NoSuchProcess: + # Process already gone + pass + except Exception as e: + logger.warning(f"Error terminating process tree {pid}: {e}") + + def _read_log_tail(self, log_file: str, lines: int = 50) -> str: + """Read the last N lines from a log file.""" + try: + with open(log_file) as f: + all_lines = f.readlines() + return "".join(all_lines[-lines:]) + except Exception as e: + return f"[Could not read log file: {e}]" + + async def create_engine( + self, + worker_id: str, + engine: str, + *args, + **kwargs, + ) -> Any: + """ + Create an engine instance on a remote worker. + + The engine parameter is a string import path (e.g., "areal.engine.ppo.actor.FSDPPPOActor") + that will be dynamically imported and instantiated on the worker. + + Args: + worker_id: Worker ID in format "role/index" + engine: Import path to the engine class (e.g., "areal.engine.ppo.actor.FSDPPPOActor") + *args: Initialization arguments + **kwargs: Initialization keyword arguments + + Returns: + Result from engine initialization + + Raises: + WorkerNotFoundError: If worker doesn't exist + WorkerFailedError: If worker process has failed + EngineCreationError: If engine creation fails + """ + # Verify worker exists and is alive + worker_info = self._verify_worker_alive(worker_id) + + # Validate engine is a string import path + if not isinstance(engine, str): + raise EngineCreationError( + worker_id, + f"Engine must be a string import path, got {type(engine)}", + ) + + # Build JSON payload + payload = { + "engine": engine, + "init_args": list(args), + "init_kwargs": kwargs, + } + + # Send HTTP request to create engine + port = int(worker_info.worker.ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/create_engine" + + try: + logger.info(f"Creating engine '{engine}' on worker '{worker_id}'") + + response = self._http_client.post( + url, + content=orjson.dumps(payload), + headers={"Content-Type": "application/json"}, + timeout=300.0, + ) + + if response.status_code == 200: + result = response.json() + logger.info(f"Engine created successfully on worker '{worker_id}'") + return result.get("result") + elif response.status_code == 400: + # Import error or bad request + error_detail = response.json().get("detail", "Unknown error") + if "Failed to import" in error_detail: + raise EngineImportError(engine, error_detail) + else: + raise EngineCreationError(worker_id, error_detail, 400) + elif response.status_code == 500: + # Engine initialization failed + error_detail = response.json().get("detail", "Unknown error") + raise EngineCreationError(worker_id, error_detail, 500) + else: + raise EngineCreationError( + worker_id, + f"Unexpected status code: {response.status_code}", + response.status_code, + ) + + except httpx.ConnectError as e: + # Check if worker died + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, worker_info.process.returncode, stderr + ) from e + raise RPCConnectionError( + worker_id, worker_info.worker.ip, port, str(e) + ) from e + + except httpx.TimeoutException as e: + raise EngineCreationError(worker_id, f"Request timed out: {e}") from e + + except (EngineCreationError, EngineImportError, RPCConnectionError): + raise + + except Exception as e: + raise EngineCreationError(worker_id, f"Unexpected error: {str(e)}") from e + + def call_engine( + self, + worker_id: str, + method: str, + *args, + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> Any: + """ + Call a method on an engine. + + Args: + worker_id: Worker ID in format "role/index" + method: Method name to call + *args: Method arguments + max_retries: Maximum number of retry attempts + retry_delay: Initial delay between retries (exponential backoff) + **kwargs: Method keyword arguments + + Returns: + Result from method call + + Raises: + WorkerNotFoundError: If worker doesn't exist + WorkerFailedError: If worker process has failed + EngineCallError: If method call fails + """ + # Get worker info (initial verification) + worker_info = self._find_worker_by_id(worker_id) + if worker_info is None: + raise WorkerNotFoundError(worker_id) + + # Build JSON payload + payload = { + "method": method, + "args": list(args), + "kwargs": kwargs, + } + + # Retry logic with exponential backoff + port = int(worker_info.worker.ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/call" + last_error = None + + for attempt in range(1, max_retries + 1): + # Check worker health before each attempt + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, + worker_info.process.returncode, + stderr, + ) + + try: + logger.debug( + f"Calling method '{method}' on worker '{worker_id}' (attempt {attempt})" + ) + + response = self._http_client.post( + url, + content=orjson.dumps(payload), + headers={"Content-Type": "application/json"}, + timeout=7200.0, # 2 hours for long-running operations + ) + + result, should_retry, error_msg = self._handle_call_response( + response, worker_id, method, attempt + ) + if not should_retry: + if attempt > 1: + logger.info( + f"Method '{method}' succeeded on worker '{worker_id}' " + f"after {attempt} attempts" + ) + return result + last_error = error_msg + + except Exception as e: + last_error = self._handle_call_exception(e, worker_info, worker_id) + + # Retry with exponential backoff + if attempt < max_retries: + delay = retry_delay * (2 ** (attempt - 1)) + logger.warning( + f"Method '{method}' failed on worker '{worker_id}' " + f"(attempt {attempt}/{max_retries}): {last_error}. " + f"Retrying in {delay:.1f}s..." + ) + time.sleep(delay) + + # All retries exhausted + raise EngineCallError( + worker_id, + method, + last_error or "Max retries exceeded", + attempt=max_retries, + ) + + async def async_call_engine( + self, + worker_id: str, + method: str, + *args, + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> Any: + """ + Async version of call_engine for calling engine methods asynchronously. + + Args: + worker_id: Worker ID in format "role/index" + method: Method name to call + *args: Method arguments + max_retries: Maximum number of retry attempts + retry_delay: Initial delay between retries (exponential backoff) + **kwargs: Method keyword arguments + + Returns: + Result from method call + + Raises: + WorkerNotFoundError: If worker doesn't exist + WorkerFailedError: If worker process has failed + EngineCallError: If method call fails + """ + # Get worker info (initial verification) + worker_info = self._find_worker_by_id(worker_id) + if worker_info is None: + raise WorkerNotFoundError(worker_id) + + # Build JSON payload + payload = { + "method": method, + "args": list(args), + "kwargs": kwargs, + } + + # Retry logic with exponential backoff + port = int(worker_info.worker.ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/call" + last_error = None + + for attempt in range(1, max_retries + 1): + # Check worker health before each attempt + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, + worker_info.process.returncode, + stderr, + ) + + try: + logger.debug( + f"Async calling method '{method}' on worker '{worker_id}' (attempt {attempt})" + ) + + response = await self._async_http_client.post( + url, + content=orjson.dumps(payload), + headers={"Content-Type": "application/json"}, + timeout=7200.0, # 2 hours for long-running operations + ) + + result, should_retry, error_msg = self._handle_call_response( + response, worker_id, method, attempt + ) + if not should_retry: + if attempt > 1: + logger.info( + f"Method '{method}' succeeded on worker '{worker_id}' " + f"after {attempt} attempts" + ) + return result + last_error = error_msg + + except Exception as e: + last_error = self._handle_call_exception(e, worker_info, worker_id) + + # Retry with exponential backoff + if attempt < max_retries: + delay = retry_delay * (2 ** (attempt - 1)) + logger.warning( + f"Method '{method}' failed on worker '{worker_id}' " + f"(attempt {attempt}/{max_retries}): {last_error}. " + f"Retrying in {delay:.1f}s..." + ) + import asyncio + + await asyncio.sleep(delay) + + # All retries exhausted + raise EngineCallError( + worker_id, + method, + last_error or "Max retries exceeded", + attempt=max_retries, + ) + + def _find_worker_by_id(self, worker_id: str) -> WorkerInfo | None: + """Find a worker by its ID.""" + for workers in self._workers.values(): + for worker_info in workers: + if worker_info.worker.id == worker_id: + return worker_info + return None + + def _verify_worker_alive(self, worker_id: str) -> WorkerInfo: + """ + Verify a worker exists and is alive. + + Args: + worker_id: Worker ID to verify + + Returns: + WorkerInfo object + + Raises: + WorkerNotFoundError: If worker doesn't exist + WorkerFailedError: If worker process has failed + """ + worker_info = self._find_worker_by_id(worker_id) + if worker_info is None: + raise WorkerNotFoundError(worker_id) + + # Check if process has exited + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, + worker_info.process.returncode, + stderr, + ) + + return worker_info + + def _handle_call_response( + self, response, worker_id: str, method: str, attempt: int + ): + """ + Handle HTTP response from engine call. + + Args: + response: HTTP response object + worker_id: Worker ID + method: Method name being called + attempt: Current retry attempt number + + Returns: + Tuple of (result, should_retry, error_message) + - result: The result from the call if successful, None otherwise + - should_retry: Whether to retry the request + - error_message: Error message if failed, None if successful + """ + if response.status_code == 200: + return response.json().get("result"), False, None + elif response.status_code == 400: + # Bad request (e.g., method doesn't exist) - don't retry + error_detail = response.json().get("detail", "Unknown error") + raise EngineCallError(worker_id, method, error_detail, attempt) + elif response.status_code == 500: + # Engine method failed - don't retry + error_detail = response.json().get("detail", "Unknown error") + raise EngineCallError(worker_id, method, error_detail, attempt) + elif response.status_code == 503: + # Service unavailable - retry + return None, True, "Service unavailable" + else: + # Other errors - retry + return None, True, f"HTTP {response.status_code}: {response.text}" + + def _handle_call_exception( + self, e: Exception, worker_info: WorkerInfo, worker_id: str + ) -> str: + """ + Handle exceptions during engine calls and return error message. + + Args: + e: The exception that occurred + worker_info: Worker information + worker_id: Worker ID + + Returns: + Error message string + + Raises: + WorkerFailedError: If worker has died + EngineCallError: If non-retryable error + """ + if isinstance(e, httpx.ConnectError): + # Check if worker died + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, + worker_info.process.returncode, + stderr, + ) from e + return f"Connection error: {e}" + elif isinstance(e, httpx.TimeoutException): + return f"Timeout: {e}" + elif isinstance(e, EngineCallError): + raise + else: + return f"Unexpected error: {e}" + + def __del__(self): + """Cleanup on deletion.""" + try: + self.delete_workers() + except Exception: + pass + try: + self._http_client.close() + except Exception: + pass + try: + import asyncio + + # Close async client if event loop is available + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(self._async_http_client.aclose()) + else: + loop.run_until_complete(self._async_http_client.aclose()) + except RuntimeError: + # No event loop, sync close + pass + except Exception: + pass diff --git a/areal/scheduler/rpc/rpc_client.py b/areal/scheduler/rpc/rpc_client.py deleted file mode 100644 index 28f4b8082..000000000 --- a/areal/scheduler/rpc/rpc_client.py +++ /dev/null @@ -1,137 +0,0 @@ -import gzip -import time -from http import HTTPStatus -from typing import Any, Union - -import cloudpickle -import requests - -from areal.api.cli_args import InferenceEngineConfig, TrainEngineConfig -from areal.api.engine_api import InferenceEngine, TrainEngine -from areal.utils import logging -from areal.utils.http import response_ok, response_retryable - -logger = logging.getLogger("RPCClient") - - -class RPCClient: - def __init__(self): - self._addrs = {} - - def register(self, worker_id: str, ip: str, port: int) -> None: - self._addrs[worker_id] = (ip, port) - logger.info(f"Registered worker {worker_id} at {ip}:{port}") - - def create_engine( - self, - worker_id: str, - engine_obj: Union[InferenceEngine, TrainEngine], - init_config: Union[InferenceEngineConfig, TrainEngineConfig], - ) -> None: - ip, port = self._addrs[worker_id] - url = f"http://{ip}:{port}/create_engine" - logger.info(f"send create_engine to {worker_id} ({ip}:{port})") - payload = (engine_obj, init_config) - serialized_data = cloudpickle.dumps(payload) - serialized_obj = gzip.compress(serialized_data) - resp = requests.post(url, data=serialized_obj) - logger.info( - f"send create_engine to {worker_id} ({ip}:{port}), status={resp.status_code}" - ) - if resp.status_code == HTTPStatus.OK: - logger.info(f"create engine success.") - return cloudpickle.loads(resp.content) - else: - logger.error(f"Failed to create engine, {resp.status_code}, {resp.content}") - raise RuntimeError( - f"Failed to create engine, {resp.status_code}, {resp.content}" - ) - - def call_engine( - self, worker_id: str, method: str, max_retries: int = 3, *args, **kwargs - ) -> Any: - """ - call the rpc server with method name and args, retry on failure - - Parameters - ---------- - worker_id: str - the id of the worker to call - method: str - the method name to call - max_retries: int - max retries on failure - *args: - args to pass to the method - **kwargs: - kwargs to pass to the method - - Returns - ------- - the deserialized result from the rpc server - """ - req = (method, args, kwargs) - serialized_data = cloudpickle.dumps(req) - - return self._call_engine_with_serialized_data( - worker_id, serialized_data, max_retries - ) - - def _call_engine_with_serialized_data( - self, worker_id: str, serialized_data: bytes, max_retries=3 - ) -> Any: - """ - call the rpc server with serialized data, retry on failure - - Parameters - ---------- - worker_id: str - the id of the worker to call - serialized_data: bytes - the serialized data to send - max_retries: int - max retries on failure - - Returns - ------- - the deserialized result from the rpc server - """ - ip, port = self._addrs[worker_id] - url = f"http://{ip}:{port}/call" - last_exception = None - - for attempt in range(max_retries): - try: - resp = requests.post(url, data=serialized_data, timeout=7200) - logger.info( - f"Sent call to {worker_id} ({ip}:{port}), status={resp.status_code}, attempt {attempt + 1}/{max_retries}" - ) - - if response_ok(resp.status_code): - return cloudpickle.loads(resp.content) - elif response_retryable(resp.status_code): - last_exception = RuntimeError( - f"Retryable HTTP status {resp.status_code}: {resp.content}" - ) - else: - raise RuntimeError( - f"Non-retryable HTTP error: {resp.status_code} - {resp.content}" - ) - - except (RuntimeError, TimeoutError) as e: - logger.error(f"stop retrying, error on attempt {attempt + 1}: {e}") - raise e - except Exception as e: - last_exception = e - logger.error(f"error on attempt {attempt + 1}: {e}") - - if last_exception is not None: - if attempt < max_retries - 1: - logger.warning( - f"Retrying in 1 second... ({attempt + 1}/{max_retries})" - ) - time.sleep(1) - continue - else: - logger.error(f"Max retries exceeded for {url}") - raise last_exception diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index b2bc3d612..3ea1574f9 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -1,149 +1,242 @@ +"""Modern FastAPI-based RPC server for engine workers. + +This server runs on worker nodes to expose engine methods via HTTP/JSON RPC. +It uses safe JSON serialization instead of cloudpickle. +""" + import argparse -import gzip -import os +import importlib import traceback -from http import HTTPStatus -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from typing import AnyStr +from contextlib import asynccontextmanager -import cloudpickle -from tensordict import TensorDict +import orjson +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import ORJSONResponse -from areal.api.controller_api import DistributedBatch -from areal.controller.batch import DistributedBatchMemory +from areal.api.engine_api import InferenceEngine, TrainEngine from areal.utils import logging logger = logging.getLogger("RPCServer") +# Global engine instance - must be TrainEngine or InferenceEngine +_engine: TrainEngine | InferenceEngine | None = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for startup and shutdown events.""" + # Startup + logger.info("RPC server starting up...") + yield + # Shutdown + global _engine + logger.info("Shutting down RPC server...") + if _engine is not None: + try: + # Call destroy method if available + if hasattr(_engine, "destroy"): + _engine.destroy() + logger.info("Engine destroyed successfully") + except Exception as e: + logger.error(f"Error destroying engine: {e}") + _engine = None + -def process_input_to_distributed_batch(*args, **kwargs): - for i in range(len(args)): - if isinstance(args[i], DistributedBatch): - args = list(args) - args[i] = args[i].get_data() - args = tuple(args) +app = FastAPI( + title="AReaL Worker RPC Server", + description="FastAPI-based RPC server for remote engine operations", + default_response_class=ORJSONResponse, + lifespan=lifespan, +) - for k in list(kwargs.keys()): - if isinstance(kwargs[k], DistributedBatch): - kwargs[k] = kwargs[k].get_data() - return args, kwargs +@app.get("/health") +async def health_check(): + """Health check endpoint to verify server is alive.""" + return {"status": "healthy", "engine_initialized": _engine is not None} -def process_output_to_distributed_batch(result): - if isinstance(result, dict): - return DistributedBatchMemory.from_dict(result) - elif isinstance(result, TensorDict): - return DistributedBatchMemory.from_dict(result.to_dict()) - elif isinstance(result, (list, tuple)): - return DistributedBatchMemory.from_list(list(result)) - else: - return result +@app.post("/create_engine") +async def create_engine(request: Request): + """ + Create and initialize an engine instance on this worker. + Expected JSON payload: + { + "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path + "init_args": [...], # Positional arguments + "init_kwargs": {...} # Keyword arguments + } + """ + global _engine -class EngineRPCServer(BaseHTTPRequestHandler): - engine = None + try: + body = await request.body() + data = orjson.loads(body) - def _read_body(self, timeout=120.0) -> AnyStr: - old_timeout = None + engine_path = data.get("engine") + init_args = data.get("init_args", []) + init_kwargs = data.get("init_kwargs", {}) + + if not engine_path: + raise HTTPException( + status_code=400, detail="Missing 'engine' field in request" + ) + + # Dynamic import try: - length = int(self.headers["Content-Length"]) - old_timeout = self.request.gettimeout() - logger.info(f"Receive rpc call, path: {self.path}, timeout: {old_timeout}") - # set max read timeout = 120s here, if read hang raise exception - self.request.settimeout(timeout) - return self.rfile.read(length) - except Exception as e: - raise e - finally: - self.request.settimeout(old_timeout) + module_path, class_name = engine_path.rsplit(".", 1) + module = importlib.import_module(module_path) + engine_class = getattr(module, class_name) + + # Validate that the class is a TrainEngine or InferenceEngine + if not ( + issubclass(engine_class, TrainEngine) + or issubclass(engine_class, InferenceEngine) + ): + raise TypeError( + f"Engine class must be a subclass of TrainEngine or InferenceEngine, " + f"got {engine_class}" + ) + except (ValueError, ImportError, AttributeError) as e: + logger.error(f"Failed to import engine '{engine_path}': {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to import engine '{engine_path}': {str(e)}", + ) + except TypeError as e: + logger.error(f"Invalid engine type: {e}") + raise HTTPException( + status_code=400, + detail=str(e), + ) - def do_POST(self): - data = None + # Instantiate engine try: - data = self._read_body() + _engine = engine_class(*init_args, **init_kwargs) + logger.info(f"Engine '{engine_path}' instantiated successfully") except Exception as e: - self.send_response( - HTTPStatus.REQUEST_TIMEOUT - ) # 408 means read request timeout - self.end_headers() - self.wfile.write( - f"Exception: {e}\n{traceback.format_exc()}".encode("utf-8") + logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") + raise HTTPException( + status_code=500, + detail=f"Failed to instantiate engine: {str(e)}", ) - logger.error(f"Exception in do_POST: {e}\n{traceback.format_exc()}") - return + # Initialize engine if it has initialize method try: - if self.path == "/create_engine": - decompressed_data = gzip.decompress(data) - engine_obj, init_args = cloudpickle.loads(decompressed_data) - EngineRPCServer.engine = engine_obj - result = EngineRPCServer.engine.initialize(init_args) - logger.info(f"Engine created and initialized on RPC server: {result}") - self.send_response(HTTPStatus.OK) - self.end_headers() - self.wfile.write(cloudpickle.dumps(result)) - elif self.path == "/call": - if EngineRPCServer.engine is None: - self.send_response(HTTPStatus.INTERNAL_SERVER_ERROR) - self.end_headers() - self.wfile.write(b"Engine is none") - logger.error("Call received but engine is none.") - return - action, args, kwargs = cloudpickle.loads(data) - method = getattr(EngineRPCServer.engine, action) - # NOTE: DO NOT print args here, args may be a very huge tensor - logger.info(f"RPC server calling engine method: {action}") - args, kwargs = process_input_to_distributed_batch(*args, **kwargs) - result = method(*args, **kwargs) - result = process_output_to_distributed_batch(result) - self.send_response(HTTPStatus.OK) - self.end_headers() - self.wfile.write(cloudpickle.dumps(result)) - else: - self.send_response(HTTPStatus.NOT_FOUND) - self.end_headers() + result = _engine.initialize(*init_args, **init_kwargs) + logger.info(f"Engine initialized with result: {result}") + return { + "status": "success", + "message": f"Engine '{engine_path}' created and initialized", + "result": result, + } except Exception as e: - self.send_response(HTTPStatus.INTERNAL_SERVER_ERROR) - self.end_headers() - self.wfile.write( - f"Exception: {e}\n{traceback.format_exc()}".encode("utf-8") + logger.error(f"Failed to initialize engine: {e}\n{traceback.format_exc()}") + raise HTTPException( + status_code=500, detail=f"Failed to initialize engine: {str(e)}" ) - logger.error(f"Exception in do_POST: {e}\n{traceback.format_exc()}") + except HTTPException: + raise + except Exception as e: + logger.error( + f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" + ) + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +@app.post("/call") +async def call_engine_method(request: Request): + """ + Call a method on the engine instance. + + Expected JSON payload: + { + "method": "train_batch", + "args": [...], + "kwargs": {...} + } + """ + global _engine + + if _engine is None: + raise HTTPException( + status_code=503, + detail="Engine not initialized. Call /create_engine first.", + ) + + try: + body = await request.body() + data = orjson.loads(body) + + method_name = data.get("method") + args = data.get("args", []) + kwargs = data.get("kwargs", {}) + + if not method_name: + raise HTTPException( + status_code=400, detail="Missing 'method' field in request" + ) -def start_rpc_server(port): - server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer) - server.serve_forever() + # Call method directly (no need for hasattr/getattr with typed engine) + logger.info(f"Calling engine method: {method_name}") + try: + # Get the method - will raise AttributeError if it doesn't exist + method = getattr(_engine, method_name) + result = method(*args, **kwargs) + + # Serialize result + # Note: This assumes the result is JSON-serializable + # For complex types (tensors, etc.), you may need custom serialization + return {"status": "success", "result": result} + + except AttributeError as e: + logger.error(f"Method '{method_name}' not found on engine: {e}") + raise HTTPException( + status_code=400, + detail=f"Engine does not have method '{method_name}'", + ) + except Exception as e: + logger.error( + f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=500, + detail=f"Engine method '{method_name}' failed: {str(e)}", + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") -def get_serve_port(args): - port = args.port - port_str = os.environ.get("PORT_LIST", "").strip() - - # Check if PORT_LIST is set - if port_str: - # Split by comma and strip whitespace - ports = [p.strip() for p in port_str.split(",")] - # Use the first valid port from the list - if ports and ports[0]: - try: - return int(ports[0]) - except ValueError: - logger.warning( - f"Invalid port '{ports[0]}' in PORT_LIST. Falling back to --port argument." - ) - return port +def main(): + """Main entry point for the RPC server.""" + parser = argparse.ArgumentParser(description="AReaL Worker RPC Server") + parser.add_argument("--port", type=int, required=True, help="Port to serve on") + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" + ) -if __name__ == "__main__": - parser = argparse.ArgumentParser() + args, unknown = parser.parse_known_args() + port = args.port - parser.add_argument("--port", type=int, required=False) + logger.info(f"Starting RPC server on {args.host}:{port}") - args, unknown = parser.parse_known_args() - port = get_serve_port(args) + # Run uvicorn server with a single worker (required for GPU workloads) + uvicorn.run( + app, + host=args.host, + port=port, + workers=1, # Single worker required for GPU memory management + log_level="info", + access_log=True, + ) - logger.info(f"About to start RPC server on {port}") - start_rpc_server(port) +if __name__ == "__main__": + main() diff --git a/areal/tests/test_local_scheduler.py b/areal/tests/test_local_scheduler.py new file mode 100644 index 000000000..1e700a784 --- /dev/null +++ b/areal/tests/test_local_scheduler.py @@ -0,0 +1,1562 @@ +""" +Comprehensive unit tests for LocalScheduler. + +This test suite covers: +1. Initialization and GPU detection +2. Worker creation with various configurations +3. GPU allocation strategies (new, colocate, round-robin) +4. Port allocation and tracking +5. Worker health checks and readiness +6. Engine creation and method calls (sync and async) +7. Error handling for all exception types +8. Resource cleanup and process termination +9. Edge cases (duplicate workers, worker not found, GPU exhaustion, port conflicts) +10. Log file handling +11. HTTP client interactions +""" + +import asyncio +import os +import time +from unittest.mock import AsyncMock, Mock, call, patch + +import httpx +import psutil +import pytest + +from areal.api.scheduler_api import ( + ContainerSpec, + ScheduleStrategy, + SchedulingConfig, + Worker, +) +from areal.scheduler.exceptions import ( + EngineCallError, + EngineCreationError, + EngineImportError, + GPUAllocationError, + PortAllocationError, + RPCConnectionError, + WorkerCreationError, + WorkerFailedError, + WorkerNotFoundError, + WorkerTimeoutError, +) +from areal.scheduler.local_scheduler import LocalScheduler, WorkerInfo + +# ============================================================================ +# Fixtures and Helper Functions +# ============================================================================ + + +@pytest.fixture +def scheduler(tmp_path): + """Create a LocalScheduler instance with default configuration.""" + return LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + +@pytest.fixture +def multi_gpu_scheduler(tmp_path): + """Create a LocalScheduler instance with multiple GPUs.""" + return LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) + + +def create_mock_process(pid=1234, is_alive=True, exit_code=None): + """Create a mock subprocess.Popen process. + + Args: + pid: Process ID + is_alive: Whether process is still running + exit_code: Exit code if process has terminated + + Returns: + Mock process object + """ + mock_proc = Mock() + mock_proc.pid = pid + mock_proc.poll.return_value = None if is_alive else exit_code + if not is_alive: + mock_proc.returncode = exit_code + return mock_proc + + +def create_worker_info( + worker_id="test/0", + role="test", + ip="127.0.0.1", + ports=None, + gpu_devices=None, + log_file="/tmp/test.log", + process=None, +): + """Create a WorkerInfo instance with sensible defaults. + + Args: + worker_id: Worker identifier + role: Worker role name + ip: IP address + ports: List of port strings + gpu_devices: List of GPU device IDs + log_file: Path to log file + process: Mock process object (created if not provided) + + Returns: + WorkerInfo instance + """ + if ports is None: + ports = ["8000"] + if gpu_devices is None: + gpu_devices = [0] + if process is None: + process = create_mock_process() + + return WorkerInfo( + worker=Worker(id=worker_id, ip=ip, ports=ports), + process=process, + role=role, + gpu_devices=gpu_devices, + created_at=time.time(), + log_file=log_file, + ) + + +def create_mock_http_response(status_code=200, json_data=None): + """Create a mock HTTP response. + + Args: + status_code: HTTP status code + json_data: Dictionary to return from response.json() + + Returns: + Mock response object + """ + mock_response = Mock() + mock_response.status_code = status_code + if json_data is not None: + mock_response.json.return_value = json_data + return mock_response + + +class TestLocalSchedulerInitialization: + """Test LocalScheduler initialization and GPU detection.""" + + def test_init_with_explicit_gpu_devices(self, tmp_path): + """Should initialize with explicitly provided GPU devices.""" + scheduler = LocalScheduler( + gpu_devices=[0, 1, 2], + log_dir=str(tmp_path), + startup_timeout=60.0, + health_check_interval=2.0, + ) + + assert scheduler.gpu_devices == [0, 1, 2] + assert scheduler.log_dir == tmp_path + assert scheduler.startup_timeout == 60.0 + assert scheduler.health_check_interval == 2.0 + assert scheduler._gpu_counter == 0 + assert len(scheduler._allocated_ports) == 0 + assert len(scheduler._workers) == 0 + assert tmp_path.exists() + + def test_init_without_gpu_devices_uses_cuda_visible_devices(self, tmp_path): + """Should detect GPUs from CUDA_VISIBLE_DEVICES environment variable.""" + with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1,3"}): + scheduler = LocalScheduler(log_dir=str(tmp_path)) + assert scheduler.gpu_devices == [0, 1, 3] + + def test_init_with_invalid_cuda_visible_devices(self, tmp_path): + """Should fall back to default [0] when CUDA_VISIBLE_DEVICES is invalid.""" + with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "invalid,gpu,ids"}): + scheduler = LocalScheduler(log_dir=str(tmp_path)) + assert scheduler.gpu_devices == [0] + + def test_init_without_cuda_visible_devices(self, tmp_path): + """Should default to [0] when CUDA_VISIBLE_DEVICES is not set.""" + with patch.dict(os.environ, {}, clear=True): + if "CUDA_VISIBLE_DEVICES" in os.environ: + del os.environ["CUDA_VISIBLE_DEVICES"] + scheduler = LocalScheduler(log_dir=str(tmp_path)) + assert scheduler.gpu_devices == [0] + + def test_init_creates_log_directory(self, tmp_path): + """Should create log directory if it doesn't exist.""" + log_dir = tmp_path / "nested" / "log" / "dir" + assert not log_dir.exists() + + scheduler = LocalScheduler(log_dir=str(log_dir)) + + assert log_dir.exists() + assert scheduler.log_dir == log_dir + + def test_init_creates_http_clients(self, tmp_path): + """Should initialize both sync and async HTTP clients.""" + scheduler = LocalScheduler(log_dir=str(tmp_path)) + + assert isinstance(scheduler._http_client, httpx.Client) + assert isinstance(scheduler._async_http_client, httpx.AsyncClient) + + +class TestGPUAllocation: + """Test GPU allocation strategies.""" + + def test_allocate_gpus_round_robin(self, tmp_path): + """Should allocate GPUs in round-robin fashion.""" + scheduler = LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) + + # First allocation + gpus1 = scheduler._allocate_gpus(2) + assert gpus1 == [0, 1] + + # Second allocation (wraps around) + gpus2 = scheduler._allocate_gpus(3) + assert gpus2 == [2, 0, 1] + + # Third allocation + gpus3 = scheduler._allocate_gpus(1) + assert gpus3 == [2] + + def test_allocate_gpus_exceeds_available(self, tmp_path): + """Should raise GPUAllocationError when requesting more GPUs than available.""" + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + with pytest.raises(GPUAllocationError) as exc_info: + scheduler._allocate_gpus(3) + + assert "Requested 3 GPUs but only 2 available" in str(exc_info.value) + + def test_allocate_gpus_single_gpu_multiple_times(self, scheduler): + """Should allow multiple workers to share a single GPU via round-robin.""" + # Multiple allocations should all get GPU 0 + for _ in range(5): + gpus = scheduler._allocate_gpus(1) + assert gpus == [0] + + def test_get_colocated_gpus_success(self, multi_gpu_scheduler): + """Should return GPU devices from target worker for colocation.""" + # Create mock workers for target role + worker1 = create_worker_info( + worker_id="actor/0", role="actor", ports=["8000"], gpu_devices=[0, 1] + ) + worker2 = create_worker_info( + worker_id="actor/1", role="actor", ports=["8001"], gpu_devices=[2] + ) + multi_gpu_scheduler._workers["actor"] = [worker1, worker2] + + # Get colocated GPUs + gpus = multi_gpu_scheduler._get_colocated_gpus("actor", 0) + assert gpus == [0, 1] + + gpus = multi_gpu_scheduler._get_colocated_gpus("actor", 1) + assert gpus == [2] + + def test_get_colocated_gpus_role_not_found(self, scheduler): + """Should raise WorkerNotFoundError when target role doesn't exist.""" + with pytest.raises(WorkerNotFoundError) as exc_info: + scheduler._get_colocated_gpus("nonexistent", 0) + + assert "Cannot colocate with role 'nonexistent' - role not found" in str( + exc_info.value + ) + + def test_get_colocated_gpus_worker_index_out_of_range(self, scheduler): + """Should raise ValueError when worker index is out of range.""" + # Create only one worker for target role + worker = create_worker_info(worker_id="actor/0", role="actor", gpu_devices=[0]) + scheduler._workers["actor"] = [worker] + + with pytest.raises(ValueError) as exc_info: + scheduler._get_colocated_gpus("actor", 5) + + assert "only 1 workers exist" in str(exc_info.value) + + +class TestPortAllocation: + """Test port allocation and tracking.""" + + def test_allocate_ports_success(self, tmp_path): + """Should allocate requested number of free ports.""" + with patch( + "areal.scheduler.local_scheduler.find_free_ports" + ) as mock_find_ports: + mock_find_ports.return_value = [8000, 8001, 8002] + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + ports = scheduler._allocate_ports(3) + + assert ports == [8000, 8001, 8002] + assert scheduler._allocated_ports == {8000, 8001, 8002} + mock_find_ports.assert_called_once_with(3, exclude_ports=set()) + + def test_allocate_ports_excludes_already_allocated(self, tmp_path): + """Should exclude already allocated ports from search.""" + with patch( + "areal.scheduler.local_scheduler.find_free_ports" + ) as mock_find_ports: + mock_find_ports.side_effect = [ + [8000, 8001], + [8002, 8003], + ] + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + # First allocation + ports1 = scheduler._allocate_ports(2) + assert ports1 == [8000, 8001] + + # Second allocation should exclude previously allocated ports + ports2 = scheduler._allocate_ports(2) + assert ports2 == [8002, 8003] + assert scheduler._allocated_ports == {8000, 8001, 8002, 8003} + + # Verify excluded ports were passed + calls = mock_find_ports.call_args_list + assert calls[0] == call(2, exclude_ports=set()) + assert calls[1] == call(2, exclude_ports={8000, 8001}) + + def test_allocate_ports_failure(self, tmp_path): + """Should raise PortAllocationError when port allocation fails.""" + with patch( + "areal.scheduler.local_scheduler.find_free_ports" + ) as mock_find_ports: + mock_find_ports.side_effect = ValueError("No free ports available") + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + with pytest.raises(PortAllocationError) as exc_info: + scheduler._allocate_ports(5) + + assert "No free ports available" in str(exc_info.value) + + +class TestWorkerCreation: + """Test worker creation with various configurations.""" + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_default_spec( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should create workers with default spec (1 GPU, 2 ports) when no specs provided.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [[8000, 8001], [8002, 8003]] + + # Mock process + mock_process1 = Mock() + mock_process1.pid = 1234 + mock_process1.poll.return_value = None + mock_process2 = Mock() + mock_process2.pid = 1235 + mock_process2.poll.return_value = None + mock_popen.side_effect = [mock_process1, mock_process2] + + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + config = SchedulingConfig(replicas=2, role="rollout") + worker_ids = scheduler.create_workers("rollout", config) + + assert worker_ids == ["rollout/0", "rollout/1"] + assert "rollout" in scheduler._workers + assert len(scheduler._workers["rollout"]) == 2 + + # Verify default spec was used + assert mock_popen.call_count == 2 + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_single_spec_for_all( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should use single spec for all workers when specs length is 1.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [[8000, 8001, 8002]] * 3 + + # Mock processes + mock_processes = [] + for i in range(3): + mock_proc = Mock() + mock_proc.pid = 1000 + i + mock_proc.poll.return_value = None + mock_processes.append(mock_proc) + mock_popen.side_effect = mock_processes + + scheduler = LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) + + config = SchedulingConfig( + replicas=3, + role="actor", + specs=[ContainerSpec(gpu=2, port_count=3)], + ) + worker_ids = scheduler.create_workers("actor", config) + + assert len(worker_ids) == 3 + assert mock_popen.call_count == 3 + + # All workers should use the same spec + for worker_info in scheduler._workers["actor"]: + assert len(worker_info.worker.ports) == 3 + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_per_worker_specs( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should use individual specs when specs length equals replicas.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [[8000], [8001, 8002]] + + # Mock processes + mock_proc1 = Mock() + mock_proc1.pid = 1000 + mock_proc1.poll.return_value = None + mock_proc2 = Mock() + mock_proc2.pid = 1001 + mock_proc2.poll.return_value = None + mock_popen.side_effect = [mock_proc1, mock_proc2] + + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + config = SchedulingConfig( + replicas=2, + role="critic", + specs=[ + ContainerSpec(gpu=1, port_count=1), + ContainerSpec(gpu=1, port_count=2), + ], + ) + worker_ids = scheduler.create_workers("critic", config) + + assert len(worker_ids) == 2 + assert len(scheduler._workers["critic"][0].worker.ports) == 1 + assert len(scheduler._workers["critic"][1].worker.ports) == 2 + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_custom_command( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should use custom command from spec when provided.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + mock_proc = Mock() + mock_proc.pid = 1234 + mock_proc.poll.return_value = None + mock_popen.return_value = mock_proc + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + config = SchedulingConfig( + replicas=1, + role="custom", + specs=[ + ContainerSpec( + gpu=1, port_count=2, cmd="python my_custom_server.py --port 8000" + ) + ], + ) + worker_ids = scheduler.create_workers("custom", config) + + assert len(worker_ids) == 1 + + # Verify custom command was used + popen_call = mock_popen.call_args + cmd_args = popen_call[0][0] + assert cmd_args == ["python", "my_custom_server.py", "--port", "8000"] + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_environment_variables( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should merge environment variables from spec into worker environment.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + mock_proc = Mock() + mock_proc.pid = 1234 + mock_proc.poll.return_value = None + mock_popen.return_value = mock_proc + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + config = SchedulingConfig( + replicas=1, + role="envtest", + specs=[ + ContainerSpec( + gpu=1, + port_count=2, + env_vars={"CUSTOM_VAR": "custom_value", "ANOTHER_VAR": "123"}, + ) + ], + ) + worker_ids = scheduler.create_workers("envtest", config) + + assert len(worker_ids) == 1 + + # Verify environment variables were passed + popen_call = mock_popen.call_args + env = popen_call[1]["env"] + assert env["CUSTOM_VAR"] == "custom_value" + assert env["ANOTHER_VAR"] == "123" + assert env["CUDA_VISIBLE_DEVICES"] == "0" + assert env["WORKER_ID"] == "envtest/0" + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_colocate_strategy( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should colocate workers on same GPUs as target role when colocate strategy is used.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + mock_processes = [] + for i in range(4): + mock_proc = Mock() + mock_proc.pid = 1000 + i + mock_proc.poll.return_value = None + mock_processes.append(mock_proc) + mock_popen.side_effect = mock_processes + + scheduler = LocalScheduler(gpu_devices=[0, 1, 2, 3], log_dir=str(tmp_path)) + + # Create target workers (actors) + actor_config = SchedulingConfig( + replicas=2, role="actor", specs=[ContainerSpec(gpu=2, port_count=2)] + ) + scheduler.create_workers("actor", actor_config) + + # Get GPU allocations for actors + actor_gpus_0 = scheduler._workers["actor"][0].gpu_devices + actor_gpus_1 = scheduler._workers["actor"][1].gpu_devices + + # Reset mock + mock_find_ports.reset_mock() + mock_find_ports.return_value = [8010, 8011] + + # Create colocated workers (critics) + critic_config = SchedulingConfig( + replicas=2, + role="critic", + specs=[ContainerSpec(gpu=2, port_count=2)], + schedule_strategy=ScheduleStrategy(type="colocate", uid="actor"), + ) + critic_ids = scheduler.create_workers("critic", critic_config) + + assert len(critic_ids) == 2 + + # Verify critics are colocated with actors + critic_gpus_0 = scheduler._workers["critic"][0].gpu_devices + critic_gpus_1 = scheduler._workers["critic"][1].gpu_devices + + assert critic_gpus_0 == actor_gpus_0 + assert critic_gpus_1 == actor_gpus_1 + + def test_create_workers_duplicate_role_error(self, tmp_path): + """Should raise WorkerCreationError when attempting to create workers for existing role.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + with ( + patch("areal.scheduler.local_scheduler.subprocess.Popen") as mock_popen, + patch("areal.scheduler.local_scheduler.find_free_ports") as mock_find_ports, + patch("areal.scheduler.local_scheduler.gethostip") as mock_gethostip, + ): + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + mock_proc = Mock() + mock_proc.pid = 1234 + mock_proc.poll.return_value = None + mock_popen.return_value = mock_proc + + config = SchedulingConfig(replicas=1, role="test") + scheduler.create_workers("test", config) + + # Try to create again + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers("test", config) + + assert "Worker group already exists" in str(exc_info.value) + assert exc_info.value.worker_key == "test" + + def test_create_workers_zero_replicas_error(self, tmp_path): + """Should raise WorkerCreationError when replicas is 0.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + config = SchedulingConfig(replicas=0, role="test") + + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers("test", config) + + assert "replicas must be greater than 0" in str(exc_info.value) + + def test_create_workers_invalid_specs_length(self, tmp_path): + """Should raise WorkerCreationError when specs length is invalid.""" + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + config = SchedulingConfig( + replicas=3, + role="test", + specs=[ + ContainerSpec(gpu=1, port_count=2), + ContainerSpec(gpu=1, port_count=2), + ], # 2 specs for 3 replicas + ) + + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers("test", config) + + assert "specs length (2) must be 1 or equal to replicas (3)" in str( + exc_info.value + ) + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_subprocess_fails_immediately( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should raise WorkerCreationError when subprocess exits immediately.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + # Mock process that exits immediately + mock_proc = Mock() + mock_proc.pid = 1234 + mock_proc.poll.return_value = 1 # Exit code 1 + mock_proc.returncode = 1 + mock_popen.return_value = mock_proc + + # Create log file with error message + log_file = tmp_path / "test_0.log" + log_file.write_text("Error: Failed to start server\n") + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + config = SchedulingConfig(replicas=1, role="test") + + with patch.object( + scheduler, "_read_log_tail", return_value="Error: Failed to start server" + ): + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers("test", config) + + assert "exited immediately with code 1" in str(exc_info.value) + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_cleanup_on_partial_failure( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should clean up successfully created workers when a later worker fails.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [ + [8000, 8001], # First worker succeeds + ValueError("No free ports"), # Second worker fails + ] + + # First process succeeds + mock_proc1 = Mock() + mock_proc1.pid = 1234 + mock_proc1.poll.return_value = None + mock_popen.return_value = mock_proc1 + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + config = SchedulingConfig(replicas=2, role="test") + + with patch.object(scheduler, "_cleanup_workers") as mock_cleanup: + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers("test", config) + + # Verify cleanup was called + assert mock_cleanup.called + assert "Resource allocation failed" in str(exc_info.value) + + def test_create_workers_colocate_strategy_missing_uid(self, tmp_path): + """Should raise WorkerCreationError when colocate strategy is missing target role uid.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + config = SchedulingConfig( + replicas=1, + role="test", + specs=[ContainerSpec(gpu=1, port_count=2)], + schedule_strategy=ScheduleStrategy(type="colocate", uid=""), # Missing uid + ) + + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers("test", config) + + assert "Colocate strategy requires uid" in str(exc_info.value) + + +class TestGetWorkers: + """Test getting workers and waiting for readiness.""" + + def test_get_workers_role_not_found(self, scheduler): + """Should raise WorkerNotFoundError when role doesn't exist.""" + with pytest.raises(WorkerNotFoundError) as exc_info: + scheduler.get_workers("nonexistent") + + assert exc_info.value.worker_id == "nonexistent" + + @patch("areal.scheduler.local_scheduler.time.sleep") + def test_get_workers_success(self, mock_sleep, scheduler, tmp_path): + """Should return workers when all are ready.""" + # Create mock workers + worker1 = create_worker_info( + worker_id="test/0", ports=["8000"], log_file=str(tmp_path / "test_0.log") + ) + worker2 = create_worker_info( + worker_id="test/1", ports=["8001"], log_file=str(tmp_path / "test_1.log") + ) + + scheduler._workers["test"] = [worker1, worker2] + + with patch.object(scheduler, "_is_worker_ready", return_value=True): + workers = scheduler.get_workers("test", timeout=10.0) + + assert len(workers) == 2 + assert workers[0].id == "test/0" + assert workers[1].id == "test/1" + + @patch("areal.scheduler.local_scheduler.time.time") + @patch("areal.scheduler.local_scheduler.time.sleep") + def test_get_workers_timeout(self, mock_sleep, mock_time, scheduler, tmp_path): + """Should raise WorkerTimeoutError when timeout is exceeded.""" + # Mock time progression - provide enough values + mock_time.side_effect = [0.0] + [i for i in range(1, 20)] + + worker = create_worker_info(log_file=str(tmp_path / "test_0.log")) + worker.created_at = 0.0 + + scheduler._workers["test"] = [worker] + + # Worker never becomes ready + with patch.object(scheduler, "_is_worker_ready", return_value=False): + with pytest.raises(WorkerTimeoutError) as exc_info: + scheduler.get_workers("test", timeout=5.0) + + assert exc_info.value.worker_key == "test" + assert exc_info.value.timeout == 5.0 + + def test_get_workers_process_died(self, scheduler, tmp_path): + """Should raise WorkerFailedError when worker process dies during readiness check.""" + log_file = tmp_path / "test_0.log" + log_file.write_text("Error: Connection refused\n") + + # Process dies after first check + mock_proc = create_mock_process() + mock_proc.poll.side_effect = [None, 1] # None (alive), then 1 (dead) + mock_proc.returncode = 1 + + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + scheduler._workers["test"] = [worker] + + with patch.object(scheduler, "_is_worker_ready", return_value=False): + with pytest.raises(WorkerFailedError) as exc_info: + scheduler.get_workers("test", timeout=10.0) + + assert exc_info.value.worker_id == "test/0" + assert exc_info.value.exit_code == 1 + + @patch("areal.scheduler.local_scheduler.time.sleep") + def test_get_workers_gradual_readiness(self, mock_sleep, scheduler, tmp_path): + """Should wait for all workers to become ready gradually.""" + worker1 = create_worker_info( + worker_id="test/0", ports=["8000"], log_file=str(tmp_path / "test_0.log") + ) + worker2 = create_worker_info( + worker_id="test/1", ports=["8001"], log_file=str(tmp_path / "test_1.log") + ) + + scheduler._workers["test"] = [worker1, worker2] + + # Worker 1 ready immediately, worker 2 ready on second check + ready_calls = [True, False, True, True] + with patch.object(scheduler, "_is_worker_ready", side_effect=ready_calls): + workers = scheduler.get_workers("test", timeout=10.0) + + assert len(workers) == 2 + + +class TestWorkerHealthCheck: + """Test worker health checking functionality.""" + + @pytest.mark.parametrize( + "status_code,expected", + [ + (200, True), # Success + (503, False), # Service unavailable + (500, False), # Internal server error + ], + ) + def test_is_worker_ready_http_status( + self, scheduler, tmp_path, status_code, expected + ): + """Should return appropriate result based on HTTP status code.""" + worker_info = create_worker_info(log_file=str(tmp_path / "test.log")) + mock_response = create_mock_http_response(status_code=status_code) + + with patch.object(scheduler._http_client, "get", return_value=mock_response): + assert scheduler._is_worker_ready(worker_info) is expected + + def test_is_worker_ready_connection_error(self, scheduler, tmp_path): + """Should return False when connection to worker fails.""" + worker_info = create_worker_info(log_file=str(tmp_path / "test.log")) + + with patch.object( + scheduler._http_client, + "get", + side_effect=httpx.ConnectError("Connection refused"), + ): + assert scheduler._is_worker_ready(worker_info) is False + + def test_check_worker_health_all_healthy(self, scheduler, tmp_path): + """Should pass when all workers are healthy.""" + worker1 = create_worker_info( + worker_id="test/0", ports=["8000"], log_file=str(tmp_path / "test_0.log") + ) + worker2 = create_worker_info( + worker_id="test/1", ports=["8001"], log_file=str(tmp_path / "test_1.log") + ) + + scheduler._workers["test"] = [worker1, worker2] + + # Should not raise + scheduler._check_worker_health("test") + + def test_check_worker_health_worker_failed(self, scheduler, tmp_path): + """Should raise WorkerFailedError when a worker has failed.""" + log_file = tmp_path / "test_0.log" + log_file.write_text("Killed by signal\n") + + mock_proc = create_mock_process(is_alive=False, exit_code=137) + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + + scheduler._workers["test"] = [worker] + + with pytest.raises(WorkerFailedError) as exc_info: + scheduler._check_worker_health("test") + + assert exc_info.value.worker_id == "test/0" + assert exc_info.value.exit_code == 137 + + def test_check_worker_health_nonexistent_role(self, scheduler): + """Should silently pass when role doesn't exist.""" + # Should not raise + scheduler._check_worker_health("nonexistent") + + +class TestDeleteWorkers: + """Test worker deletion and cleanup.""" + + def test_delete_workers_specific_role(self, scheduler, tmp_path): + """Should delete workers for specific role.""" + # Create mock workers for multiple roles + worker1 = create_worker_info( + worker_id="role1/0", + role="role1", + ports=["8000"], + log_file=str(tmp_path / "role1_0.log"), + ) + worker2 = create_worker_info( + worker_id="role2/0", + role="role2", + ports=["8001"], + log_file=str(tmp_path / "role2_0.log"), + ) + + scheduler._workers["role1"] = [worker1] + scheduler._workers["role2"] = [worker2] + scheduler._allocated_ports = {8000, 8001} + + with patch.object(scheduler, "_terminate_process_tree"): + scheduler.delete_workers("role1") + + # role1 should be deleted, role2 should remain + assert "role1" not in scheduler._workers + assert "role2" in scheduler._workers + assert 8000 not in scheduler._allocated_ports + assert 8001 in scheduler._allocated_ports + + def test_delete_workers_all_roles(self, scheduler, tmp_path): + """Should delete all workers when role is None.""" + worker1 = create_worker_info( + worker_id="role1/0", + role="role1", + ports=["8000"], + log_file=str(tmp_path / "role1_0.log"), + ) + worker2 = create_worker_info( + worker_id="role2/0", + role="role2", + ports=["8001"], + log_file=str(tmp_path / "role2_0.log"), + ) + + scheduler._workers["role1"] = [worker1] + scheduler._workers["role2"] = [worker2] + scheduler._allocated_ports = {8000, 8001} + + with patch.object(scheduler, "_terminate_process_tree"): + scheduler.delete_workers(None) + + # All workers should be deleted + assert len(scheduler._workers) == 0 + assert len(scheduler._allocated_ports) == 0 + + def test_delete_workers_nonexistent_role(self, scheduler): + """Should log warning and return when role doesn't exist.""" + # Should not raise + scheduler.delete_workers("nonexistent") + + def test_cleanup_workers_releases_ports(self, scheduler, tmp_path): + """Should release allocated ports when cleaning up workers.""" + worker = create_worker_info( + ports=["8000", "8001"], log_file=str(tmp_path / "test.log") + ) + scheduler._allocated_ports = {8000, 8001, 8002} + + with patch.object(scheduler, "_terminate_process_tree"): + scheduler._cleanup_workers([worker]) + + # Ports 8000 and 8001 should be released + assert scheduler._allocated_ports == {8002} + + def test_cleanup_workers_handles_errors(self, scheduler, tmp_path): + """Should continue cleanup even if terminating a process fails.""" + worker1 = create_worker_info( + worker_id="test/0", ports=["8000"], log_file=str(tmp_path / "test_0.log") + ) + worker2 = create_worker_info( + worker_id="test/1", ports=["8001"], log_file=str(tmp_path / "test_1.log") + ) + + # First termination fails, second succeeds + with patch.object( + scheduler, + "_terminate_process_tree", + side_effect=[Exception("Failed to terminate"), None], + ): + # Should not raise, just log error + scheduler._cleanup_workers([worker1, worker2]) + + +class TestProcessTermination: + """Test process termination functionality.""" + + @patch("areal.scheduler.local_scheduler.psutil.Process") + @patch("areal.scheduler.local_scheduler.psutil.wait_procs") + def test_terminate_process_tree_graceful( + self, mock_wait_procs, mock_process_class, tmp_path + ): + """Should gracefully terminate process tree.""" + # Mock parent process + mock_parent = Mock() + mock_child1 = Mock() + mock_child2 = Mock() + + mock_parent.children.return_value = [mock_child1, mock_child2] + mock_process_class.return_value = mock_parent + + # All processes terminate gracefully + mock_wait_procs.return_value = ([], []) # (gone, alive) + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + scheduler._terminate_process_tree(1234) + + # Verify termination sequence + mock_child1.terminate.assert_called_once() + mock_child2.terminate.assert_called_once() + mock_parent.terminate.assert_called_once() + + # Should not call kill since all terminated gracefully + mock_child1.kill.assert_not_called() + mock_child2.kill.assert_not_called() + mock_parent.kill.assert_not_called() + + @patch("areal.scheduler.local_scheduler.psutil.Process") + @patch("areal.scheduler.local_scheduler.psutil.wait_procs") + def test_terminate_process_tree_force_kill( + self, mock_wait_procs, mock_process_class, tmp_path + ): + """Should force kill processes that don't terminate gracefully.""" + mock_parent = Mock() + mock_child = Mock() + + mock_parent.children.return_value = [mock_child] + + # Return mock_parent only when called with pid=1234, otherwise raise NoSuchProcess + # This prevents interference from __del__ cleanup of previous test's schedulers + def process_side_effect(pid): + if pid == 1234: + return mock_parent + raise psutil.NoSuchProcess(pid) + + mock_process_class.side_effect = process_side_effect + + # Child doesn't terminate gracefully + mock_wait_procs.return_value = ([], [mock_child]) # (gone, alive) + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + scheduler._terminate_process_tree(1234) + + # Verify force kill was called + mock_child.terminate.assert_called_once() + mock_child.kill.assert_called_once() + + @patch("areal.scheduler.local_scheduler.psutil.Process") + def test_terminate_process_tree_no_such_process(self, mock_process_class, tmp_path): + """Should handle gracefully when process doesn't exist.""" + mock_process_class.side_effect = psutil.NoSuchProcess(1234) + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + # Should not raise + scheduler._terminate_process_tree(1234) + + @patch("areal.scheduler.local_scheduler.psutil.Process") + def test_terminate_process_tree_handles_child_no_such_process( + self, mock_process_class, tmp_path + ): + """Should handle when child process disappears during termination.""" + mock_parent = Mock() + mock_child = Mock() + mock_child.terminate.side_effect = psutil.NoSuchProcess(1235) + + mock_parent.children.return_value = [mock_child] + mock_process_class.return_value = mock_parent + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + # Should not raise + scheduler._terminate_process_tree(1234) + + +class TestLogFileHandling: + """Test log file reading and handling.""" + + def test_read_log_tail_success(self, tmp_path): + """Should read last N lines from log file.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + log_file = tmp_path / "test.log" + log_lines = [f"Line {i}\n" for i in range(100)] + log_file.write_text("".join(log_lines)) + + tail = scheduler._read_log_tail(str(log_file), lines=10) + + # Should contain last 10 lines + assert "Line 90" in tail + assert "Line 99" in tail + assert "Line 89" not in tail + + def test_read_log_tail_file_not_found(self, tmp_path): + """Should return error message when log file doesn't exist.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + tail = scheduler._read_log_tail("/nonexistent/file.log") + + assert "Could not read log file" in tail + + def test_read_log_tail_fewer_lines_than_requested(self, tmp_path): + """Should return all lines when file has fewer lines than requested.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + log_file = tmp_path / "test.log" + log_file.write_text("Line 1\nLine 2\nLine 3\n") + + tail = scheduler._read_log_tail(str(log_file), lines=50) + + assert "Line 1" in tail + assert "Line 2" in tail + assert "Line 3" in tail + + +class TestEngineCreation: + """Test engine creation on workers.""" + + def test_create_engine_success(self, scheduler, tmp_path): + """Should successfully create engine on worker.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=200, + json_data={"result": {"status": "initialized", "name": "TestEngine"}}, + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + result = asyncio.run( + scheduler.create_engine( + "test/0", "test_engines.DummyEngine", name="TestEngine", param=123 + ) + ) + + assert result == {"status": "initialized", "name": "TestEngine"} + + def test_create_engine_worker_not_found(self, scheduler): + """Should raise WorkerNotFoundError when worker doesn't exist.""" + with pytest.raises(WorkerNotFoundError) as exc_info: + asyncio.run( + scheduler.create_engine("nonexistent/0", "test_engines.DummyEngine") + ) + + assert exc_info.value.worker_id == "nonexistent/0" + + def test_create_engine_worker_died(self, scheduler, tmp_path): + """Should raise WorkerFailedError when worker process has died.""" + log_file = tmp_path / "test.log" + log_file.write_text("Worker crashed\n") + + mock_proc = create_mock_process(is_alive=False, exit_code=1) + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + scheduler._workers["test"] = [worker] + + with pytest.raises(WorkerFailedError) as exc_info: + asyncio.run(scheduler.create_engine("test/0", "test_engines.DummyEngine")) + + assert exc_info.value.worker_id == "test/0" + assert exc_info.value.exit_code == 1 + + def test_create_engine_invalid_engine_type(self, scheduler, tmp_path): + """Should raise EngineCreationError when engine is not a string.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + with pytest.raises(EngineCreationError) as exc_info: + asyncio.run(scheduler.create_engine("test/0", 123)) # Invalid type + + assert "Engine must be a string import path" in str(exc_info.value) + + def test_create_engine_import_error(self, scheduler, tmp_path): + """Should raise EngineImportError when engine import fails.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=400, + json_data={"detail": "Failed to import 'nonexistent.Engine'"}, + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + with pytest.raises(EngineImportError) as exc_info: + asyncio.run(scheduler.create_engine("test/0", "nonexistent.Engine")) + + assert "nonexistent.Engine" in str(exc_info.value) + + def test_create_engine_initialization_error(self, scheduler, tmp_path): + """Should raise EngineCreationError when engine initialization fails.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=500, + json_data={"detail": "Engine initialization failed: out of memory"}, + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + with pytest.raises(EngineCreationError) as exc_info: + asyncio.run( + scheduler.create_engine("test/0", "test_engines.DummyEngine") + ) + + assert "out of memory" in str(exc_info.value) + assert exc_info.value.status_code == 500 + + def test_create_engine_connection_error_worker_died(self, scheduler, tmp_path): + """Should raise WorkerFailedError when connection fails and worker is dead.""" + log_file = tmp_path / "test.log" + log_file.write_text("Worker crashed during engine creation\n") + + # First call returns None (alive), second call returns exit code (dead) + mock_proc = create_mock_process() + mock_proc.poll.side_effect = [None, 1] + mock_proc.returncode = 1 + + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + scheduler._workers["test"] = [worker] + + with patch.object( + scheduler._http_client, + "post", + side_effect=httpx.ConnectError("Connection refused"), + ): + with pytest.raises(WorkerFailedError) as exc_info: + asyncio.run( + scheduler.create_engine("test/0", "test_engines.DummyEngine") + ) + + assert exc_info.value.worker_id == "test/0" + + def test_create_engine_connection_error_worker_alive(self, scheduler, tmp_path): + """Should raise RPCConnectionError when connection fails but worker is alive.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + with patch.object( + scheduler._http_client, + "post", + side_effect=httpx.ConnectError("Connection refused"), + ): + with pytest.raises(RPCConnectionError) as exc_info: + asyncio.run( + scheduler.create_engine("test/0", "test_engines.DummyEngine") + ) + + assert exc_info.value.worker_id == "test/0" + assert exc_info.value.host == "127.0.0.1" + assert exc_info.value.port == 8000 + + def test_create_engine_timeout(self, scheduler, tmp_path): + """Should raise EngineCreationError when request times out.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + with patch.object( + scheduler._http_client, + "post", + side_effect=httpx.TimeoutException("Request timeout"), + ): + with pytest.raises(EngineCreationError) as exc_info: + asyncio.run( + scheduler.create_engine("test/0", "test_engines.DummyEngine") + ) + + assert "Request timed out" in str(exc_info.value) + + +class TestEngineMethodCalls: + """Test calling methods on engines (sync and async).""" + + def test_call_engine_success(self, scheduler, tmp_path): + """Should successfully call engine method synchronously.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=200, json_data={"result": 42} + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + result = scheduler.call_engine("test/0", "compute", arg1=10, arg2=20) + + assert result == 42 + + def test_call_engine_worker_not_found(self, scheduler): + """Should raise WorkerNotFoundError when worker doesn't exist.""" + with pytest.raises(WorkerNotFoundError): + scheduler.call_engine("nonexistent/0", "method") + + def test_call_engine_worker_died(self, scheduler, tmp_path): + """Should raise WorkerFailedError when worker dies before call.""" + log_file = tmp_path / "test.log" + log_file.write_text("Worker crashed\n") + + mock_proc = create_mock_process(is_alive=False, exit_code=1) + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + scheduler._workers["test"] = [worker] + + with pytest.raises(WorkerFailedError): + scheduler.call_engine("test/0", "method") + + def test_call_engine_method_error(self, scheduler, tmp_path): + """Should raise EngineCallError when method call returns 400/500.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=400, json_data={"detail": "Method 'nonexistent' not found"} + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + with pytest.raises(EngineCallError) as exc_info: + scheduler.call_engine("test/0", "nonexistent") + + assert "Method 'nonexistent' not found" in str(exc_info.value) + + @patch("areal.scheduler.local_scheduler.time.sleep") + def test_call_engine_retry_on_503(self, mock_sleep, scheduler, tmp_path): + """Should retry on 503 Service Unavailable.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + # First call returns 503, second call succeeds + mock_response_503 = create_mock_http_response(status_code=503) + mock_response_200 = create_mock_http_response( + status_code=200, json_data={"result": "success"} + ) + + with patch.object( + scheduler._http_client, + "post", + side_effect=[mock_response_503, mock_response_200], + ): + result = scheduler.call_engine("test/0", "method", max_retries=3) + + assert result == "success" + assert mock_sleep.called + + @patch("areal.scheduler.local_scheduler.time.sleep") + def test_call_engine_max_retries_exhausted(self, mock_sleep, scheduler, tmp_path): + """Should raise EngineCallError after max retries.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response(status_code=503) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + with pytest.raises(EngineCallError) as exc_info: + scheduler.call_engine("test/0", "method", max_retries=3) + + assert "Max retries exceeded" in str( + exc_info.value + ) or "Service unavailable" in str(exc_info.value) + assert exc_info.value.attempt == 3 + + @patch("areal.scheduler.local_scheduler.time.sleep") + def test_call_engine_exponential_backoff(self, mock_sleep, scheduler, tmp_path): + """Should use exponential backoff for retries.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response(status_code=503) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + try: + scheduler.call_engine( + "test/0", "method", max_retries=3, retry_delay=1.0 + ) + except EngineCallError: + pass + + # Verify exponential backoff: 1.0, 2.0 + sleep_calls = [call_args[0][0] for call_args in mock_sleep.call_args_list] + assert sleep_calls[0] == 1.0 # First retry + assert sleep_calls[1] == 2.0 # Second retry + + def test_async_call_engine_success(self, scheduler, tmp_path): + """Should successfully call engine method asynchronously.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + # Use Mock instead of AsyncMock for the response object + mock_response = create_mock_http_response( + status_code=200, json_data={"result": 42} + ) + + # But AsyncMock for post() since it's an async method + async_mock_post = AsyncMock(return_value=mock_response) + with patch.object(scheduler._async_http_client, "post", async_mock_post): + result = asyncio.run( + scheduler.async_call_engine("test/0", "compute", arg1=10, arg2=20) + ) + + assert result == 42 + + def test_async_call_engine_worker_not_found(self, scheduler): + """Should raise WorkerNotFoundError when worker doesn't exist (async).""" + with pytest.raises(WorkerNotFoundError): + asyncio.run(scheduler.async_call_engine("nonexistent/0", "method")) + + def test_async_call_engine_retry_with_backoff(self, scheduler, tmp_path): + """Should retry with exponential backoff in async mode.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + # First call returns 503, second call succeeds + # Use Mock (not AsyncMock) for response objects since response.json() is synchronous + mock_response_503 = create_mock_http_response(status_code=503) + mock_response_200 = create_mock_http_response( + status_code=200, json_data={"result": "success"} + ) + + # AsyncMock for post() since it's an async method + async_mock_post = AsyncMock(side_effect=[mock_response_503, mock_response_200]) + with patch.object(scheduler._async_http_client, "post", async_mock_post): + with patch("asyncio.sleep") as mock_async_sleep: + result = asyncio.run( + scheduler.async_call_engine("test/0", "method", max_retries=3) + ) + + assert result == "success" + assert mock_async_sleep.called + + +class TestFindWorkerById: + """Test finding workers by ID.""" + + def test_find_worker_by_id_success(self, scheduler, tmp_path): + """Should find worker by ID.""" + worker1 = create_worker_info( + worker_id="role1/0", + role="role1", + ports=["8000"], + log_file=str(tmp_path / "role1_0.log"), + ) + worker2 = create_worker_info( + worker_id="role2/0", + role="role2", + ports=["8001"], + log_file=str(tmp_path / "role2_0.log"), + ) + + scheduler._workers["role1"] = [worker1] + scheduler._workers["role2"] = [worker2] + + found = scheduler._find_worker_by_id("role2/0") + + assert found is worker2 + assert found.worker.id == "role2/0" + + def test_find_worker_by_id_not_found(self, scheduler, tmp_path): + """Should return None when worker ID is not found.""" + worker = create_worker_info( + worker_id="role1/0", role="role1", log_file=str(tmp_path / "role1_0.log") + ) + scheduler._workers["role1"] = [worker] + + found = scheduler._find_worker_by_id("nonexistent/99") + + assert found is None + + +class TestSchedulerCleanup: + """Test scheduler cleanup and destructor.""" + + def test_destructor_deletes_all_workers(self, scheduler, tmp_path): + """Should delete all workers when scheduler is destroyed.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + with patch.object(scheduler, "delete_workers") as mock_delete: + scheduler.__del__() + + mock_delete.assert_called_once() + + def test_destructor_closes_http_clients(self, scheduler): + """Should close HTTP clients when scheduler is destroyed.""" + with patch.object(scheduler._http_client, "close") as mock_close: + scheduler.__del__() + + mock_close.assert_called_once() + + def test_destructor_handles_errors_gracefully(self, scheduler): + """Should handle errors gracefully in destructor.""" + with patch.object(scheduler, "delete_workers", side_effect=Exception("Error")): + # Should not raise + scheduler.__del__() + + +class TestEdgeCases: + """Test various edge cases and corner scenarios.""" + + def test_gpu_counter_wraps_correctly(self, tmp_path): + """Should correctly wrap GPU counter for round-robin allocation.""" + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + # Allocate many times to ensure wrapping + for i in range(10): + gpus = scheduler._allocate_gpus(1) + expected_gpu = i % 2 + assert gpus == [expected_gpu] + + def test_port_allocation_accumulates_correctly(self, tmp_path): + """Should correctly accumulate allocated ports over multiple allocations.""" + with patch( + "areal.scheduler.local_scheduler.find_free_ports" + ) as mock_find_ports: + mock_find_ports.side_effect = [ + [8000, 8001], + [8002, 8003], + [8004, 8005, 8006], + ] + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + scheduler._allocate_ports(2) + scheduler._allocate_ports(2) + scheduler._allocate_ports(3) + + assert scheduler._allocated_ports == { + 8000, + 8001, + 8002, + 8003, + 8004, + 8005, + 8006, + } + + @patch("areal.scheduler.local_scheduler.gethostip") + @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_worker_id_format( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should create worker IDs in correct format (role/index).""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + mock_processes = [] + for i in range(5): + mock_proc = Mock() + mock_proc.pid = 1000 + i + mock_proc.poll.return_value = None + mock_processes.append(mock_proc) + mock_popen.side_effect = mock_processes + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + config = SchedulingConfig(replicas=5, role="worker") + worker_ids = scheduler.create_workers("worker", config) + + assert worker_ids == [ + "worker/0", + "worker/1", + "worker/2", + "worker/3", + "worker/4", + ] + + def test_empty_workers_dict_operations(self, tmp_path): + """Should handle operations on empty workers dictionary gracefully.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + # Delete all workers when none exist + scheduler.delete_workers(None) + + # Check health of non-existent role + scheduler._check_worker_health("nonexistent") + + # Find worker by ID when no workers exist + assert scheduler._find_worker_by_id("any/0") is None + + def test_concurrent_gpu_allocations(self, tmp_path): + """Should handle concurrent GPU allocations correctly.""" + scheduler = LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) + + # Simulate multiple workers requesting GPUs simultaneously + results = [] + for _ in range(6): + gpus = scheduler._allocate_gpus(1) + results.append(gpus[0]) + + # Should cycle through GPUs in order + assert results == [0, 1, 2, 0, 1, 2] + + def test_log_directory_with_special_characters(self, tmp_path): + """Should handle log directory paths with special characters.""" + log_dir = tmp_path / "logs with spaces" / "special-chars_123" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(log_dir)) + + assert log_dir.exists() + assert scheduler.log_dir == log_dir From 266d6d6ed3b7a8f7877696ebd98618e8a939de6d Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Tue, 28 Oct 2025 22:42:08 +0800 Subject: [PATCH 09/52] implement run workflow endpoint and rolllout controller --- areal/controller/__init__.py | 9 + areal/controller/rollout_controller.py | 713 +++++++++++++++++++++++++ areal/scheduler/local_scheduler.py | 23 +- areal/scheduler/rpc/rpc_server.py | 134 +++++ 4 files changed, 870 insertions(+), 9 deletions(-) create mode 100644 areal/controller/__init__.py create mode 100644 areal/controller/rollout_controller.py diff --git a/areal/controller/__init__.py b/areal/controller/__init__.py new file mode 100644 index 000000000..d6905531d --- /dev/null +++ b/areal/controller/__init__.py @@ -0,0 +1,9 @@ +"""Controller components for managing distributed training and inference.""" + +from areal.controller.batch import DistributedBatchMemory +from areal.controller.rollout_controller import RolloutController + +__all__ = [ + "DistributedBatchMemory", + "RolloutController", +] diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py new file mode 100644 index 000000000..7680f6f43 --- /dev/null +++ b/areal/controller/rollout_controller.py @@ -0,0 +1,713 @@ +"""RolloutController implementation using LocalScheduler and RPC workers.""" + +from __future__ import annotations + +import asyncio +import queue +import random +import time +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.cli_args import InferenceEngineConfig +from areal.api.controller_api import DistributedBatch +from areal.api.controller_api import RolloutController as RolloutControllerAPI +from areal.api.engine_api import InferenceEngine +from areal.api.io_struct import ModelRequest, ModelResponse, ParamSpec, WeightUpdateMeta +from areal.api.scheduler_api import Scheduler, SchedulingConfig, Worker +from areal.controller.batch import DistributedBatchMemory +from areal.core.async_task_runner import AsyncTaskRunner, TaskQueueFullError +from areal.core.staleness_manager import StalenessManager +from areal.utils import logging +from areal.utils.data import cycle_dataloader + +if TYPE_CHECKING: + pass + + +@dataclass +class _RemoteRolloutTaskInput: + data: dict[str, Any] + workflow_path: str + workflow_kwargs: dict[str, Any] + should_accept_path: str | None = None + + +class RolloutController(RolloutControllerAPI): + """A centralized controller managing multiple InferenceEngine workers for rollout generation. + + This controller orchestrates distributed inference by: + 1. Launching local inference engines on workers via scheduler + 2. Scheduling requests to specific engines via round-robin + 3. Delegating actual execution to AsyncTaskRunner + 4. Aggregating results from workers into DistributedBatch + + Parameters + ---------- + inf_engine : InferenceEngine + The inference engine class to instantiate on each worker + config : InferenceEngineConfig + Configuration for inference engines + scheduler : Scheduler + Scheduler for worker management + """ + + def __init__( + self, + inf_engine: InferenceEngine, + config: InferenceEngineConfig, + scheduler: Scheduler, + ): + """Initialize the RolloutController. + + Parameters + ---------- + inf_engine : InferenceEngine + The inference engine class (not instance) to create on workers + config : InferenceEngineConfig + Configuration for the inference engines + scheduler : Scheduler + Scheduler for managing workers + """ + super().__init__(inf_engine, config, scheduler) + + # Worker management + self.workers: list[Worker] = [] # List of Worker objects from scheduler + self.num_workers = 0 + self._worker_role = "rollout" # Role name for workers + + # Round-robin scheduling + self._current_worker_idx = 0 + + # Async task execution + self.runner: AsyncTaskRunner | None = None + + # Thread pool for weight updates + self.executor: ThreadPoolExecutor | None = None + + # Logging + self.logger = None + + # State + self._initialized = False + self._version = 0 + + # Staleness management + self.staleness_manager: StalenessManager | None = None + self._pending_inputs: list[ + _RemoteRolloutTaskInput + ] = [] # Queue for inputs waiting for capacity + + def initialize(self, num_workers: int = 1, *args, **kwargs): + """Initialize the controller by creating workers and deploying engines. + + Parameters + ---------- + num_workers : int + Number of worker instances to create (default: 1) + *args + Additional positional arguments + **kwargs + Additional keyword arguments including: + - scheduling_config: SchedulingConfig for worker creation + - engine_init_args: List of args to pass to engine initialization + - engine_init_kwargs: Dict of kwargs to pass to engine initialization + - timeout: Timeout for worker creation + """ + if self._initialized: + self.logger.warning("RolloutController already initialized, skipping...") + return + + self.logger = logging.getLogger("[RolloutController]") + self.logger.info( + f"Initializing RolloutController with {num_workers} workers..." + ) + + self.num_workers = num_workers + + # Get scheduling config from kwargs or use defaults + scheduling_config = kwargs.get( + "scheduling_config", + SchedulingConfig(replicas=num_workers), + ) + + # Get engine initialization parameters + engine_init_args = kwargs.get("engine_init_args", []) + engine_init_kwargs = kwargs.get("engine_init_kwargs", {}) + timeout = kwargs.get("timeout", 60.0) + + # Use asyncio.run to call async scheduler methods synchronously + asyncio.run( + self._async_initialize( + scheduling_config, + engine_init_args, + engine_init_kwargs, + timeout, + ) + ) + + # Initialize AsyncTaskRunner for task execution + max_queue_size = getattr(self.config, "max_queue_size", 1024) + self.runner = AsyncTaskRunner(max_queue_size=max_queue_size) + self.runner.initialize(logger=self.logger) + + # Initialize thread pool for weight updates + self.executor = ThreadPoolExecutor(max_workers=num_workers) + + # Initialize staleness manager for global capacity control + max_concurrent_rollouts = ( + self.config.max_concurrent_rollouts or self.config.consumer_batch_size + ) + consumer_batch_size = self.config.consumer_batch_size + + self.staleness_manager = StalenessManager( + max_concurrent_rollouts=max_concurrent_rollouts, + consumer_batch_size=consumer_batch_size, + max_staleness=self.config.max_head_offpolicyness, + ) + + self._initialized = True + self.logger.info(f"RolloutController initialized with {num_workers} workers") + + async def _async_initialize( + self, + scheduling_config: SchedulingConfig, + engine_init_args: list, + engine_init_kwargs: dict, + timeout: float, + ): + """Async helper to initialize workers and engines. + + Parameters + ---------- + scheduling_config : SchedulingConfig + Configuration for worker creation + engine_init_args : list + Positional arguments for engine initialization + engine_init_kwargs : dict + Keyword arguments for engine initialization + timeout : float + Timeout for worker readiness + """ + # Create workers via scheduler + self.logger.info("Creating workers via scheduler...") + worker_ids = self.scheduler.create_workers( + role=self._worker_role, + scheduler_config=scheduling_config, + ) + self.logger.info(f"Workers created: {worker_ids}") + + # Wait for workers to be ready + self.logger.info("Waiting for workers to be ready...") + self.workers = self.scheduler.get_workers( + role=self._worker_role, + timeout=timeout, + ) + self.logger.info(f"Workers ready: {[w.id for w in self.workers]}") + + # Get engine class path for dynamic import on workers + engine_class = self.inf_engine + engine_path = f"{engine_class.__module__}.{engine_class.__name__}" + + # Create and initialize engines on workers + for i, worker in enumerate(self.workers): + self.logger.info(f"Creating engine on worker {worker.id}...") + + # Create engine on worker + await self.scheduler.create_engine( + worker_id=worker.id, + engine=engine_path, + init_args=engine_init_args, + init_kwargs={ + **engine_init_kwargs, + "engine_id": f"worker_{i}", + }, + ) + self.logger.info(f"Engine created on worker {worker.id}") + + def destroy(self): + """Destroy the controller and clean up resources.""" + if not self._initialized: + return + + self.logger.info("Destroying RolloutController...") + + # Destroy task runner + if self.runner is not None: + self.runner.destroy() + self.runner = None + + # Delete workers via scheduler + try: + self.scheduler.delete_workers(role=self._worker_role) + self.logger.info("Workers deleted") + except Exception as e: + self.logger.error(f"Error deleting workers: {e}") + + self.workers.clear() + + # Shutdown executor + if self.executor is not None: + self.executor.shutdown(wait=True) + self.executor = None + + self._initialized = False + self.logger.info("RolloutController destroyed") + + def get_capacity(self) -> int: + """Get current available capacity for new rollouts based on staleness. + + Returns + ------- + int + Number of new rollout slots available based on staleness constraints + """ + if not self._initialized: + return 0 + version = self.get_version() # Use controller's global version + return self.staleness_manager.get_capacity(version) + + def _choose_worker(self) -> Worker: + """Choose a worker for the next request using round-robin scheduling. + + Returns + ------- + Worker + The chosen worker object + """ + if self.num_workers == 0: + raise RuntimeError("No workers available") + + worker = self.workers[self._current_worker_idx] + self._current_worker_idx = (self._current_worker_idx + 1) % self.num_workers + return worker + + async def _run_workflow_on_worker( + self, + worker: Worker, + data: dict[str, Any], + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> dict[str, Any] | None: + # Call run_workflow on worker via scheduler + # This will hit the /run_workflow endpoint + result = await self.scheduler.async_call_engine( + worker_id=worker.id, + method="run_workflow", + workflow=workflow_path, + workflow_kwargs=workflow_kwargs, + data=data, + should_accept_path=should_accept_path, + check_trajectory_format=self.config.check_trajectory_format, + ) + + # The RPCServer will return None if the + # trajectory is rejected. + if result is not None: + self.staleness_manager.on_rollout_accepted() + if self.config.enable_rollout_tracing: + stat = self.staleness_manager.get_stats() + self.logger.info( + f"Finish and accept rollout. " + f"Submit: {stat.submitted}, " + f"running: {stat.running}, " + f"accepted: {stat.accepted}." + ) + return result + else: + self.staleness_manager.on_rollout_rejected() + if self.config.enable_rollout_tracing: + stat = self.staleness_manager.get_stats() + self.logger.info( + f"Finish but reject rollout. " + f"Submit: {stat.submitted}, " + f"running: {stat.running}, " + f"accepted: {stat.accepted}." + ) + return None + + def submit( + self, + data: dict[str, Any], + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> None: + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + # Add to pending queue (will be submitted when capacity allows) + self._pending_inputs.append( + _RemoteRolloutTaskInput( + data=data, + workflow_kwargs=workflow_kwargs, + workflow_path=workflow_path, + should_accept_path=should_accept_path, + ) + ) + + def _commit_one_to_runner(self): + """Commit one pending input to task runner with staleness tracking.""" + task_input = self._pending_inputs.pop(0) + + # Choose worker via round-robin + worker = self._choose_worker() + + # Submit to AsyncTaskRunner + try: + self.runner.submit( + self._run_workflow_on_worker, + worker, + task_input.data, + task_input.workflow_path, + task_input.workflow_kwargs, + task_input.should_accept_path, + ) + except TaskQueueFullError: + raise queue.Full("Input queue full") + + # Notify staleness manager AFTER successful submission + self.staleness_manager.on_rollout_submitted() + if self.config.enable_rollout_tracing: + stat = self.staleness_manager.get_stats() + self.logger.info( + f"Submit rollout. " + f"Submit: {stat.submitted}, " + f"running: {stat.running}, " + f"accepted: {stat.accepted}." + ) + + def wait(self, count: int, timeout: float | None = None) -> DistributedBatch: + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + ####################################################### + # The following logic is copied from WorkflowExecutor # + ####################################################### + start_time = time.perf_counter() + timeout = timeout or float(7 * 24 * 3600) + + # Keep requesting results from runner until we have enough accepted + # (non-None) results. Use short timeout (1 second) for each wait call + # to allow periodic checking. This matches original behavior where + # wait() would poll and allow prepare_batch() to continue + while True: + # Submit pending inputs + # Check capacity before submitting + capacity = self.get_capacity() + # Submit pending tasks + for _ in range(capacity): + if len(self._pending_inputs) == 0: + break + self._commit_one_to_runner() + + if len(self._pending_results) >= count: + break + + elapsed = time.perf_counter() - start_time + remaining_timeout = timeout - elapsed + + if remaining_timeout <= 0: + raise TimeoutError( + f"Timed out waiting for {count} rollouts, only received " + f"{len(self._pending_results)}." + ) + + # Try to get at least the number we still need, but request at least 1 + # Note: runner.wait() might return fewer due to rejections (None results) + needed = max(1, count - len(self._pending_results)) + + try: + # Use short timeout to allow periodic returns (matches original + # polling behavior) + batch = self.runner.wait( + count=needed, timeout=min(0.1, remaining_timeout) + ) + + # Filter out None results (rejected trajectories) + # runner.wait() returns List[T] where T can be None for + # rejected rollouts + accepted = [result for result in batch if result is not None] + self._pending_results.extend(accepted) + except TimeoutError: + pass + + if self.config.enable_rollout_tracing: + self.logger.info("Rollout results are ready!") + + # Extract requested number of results + results = self._pending_results[:count] + self._pending_results = self._pending_results[count:] + + # Shuffle for randomness (helps with data diversity) + random.shuffle(results) + + # Convert to DistributedBatch + if len(results) == 0: + return DistributedBatchMemory.from_dict({}) + + return DistributedBatchMemory.concat( + [DistributedBatchMemory.from_dict(r) for r in results] + ) + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> DistributedBatch: + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + # Submit all requests + for item in data: + self.submit( + item, + workflow_kwargs=workflow_kwargs, + workflow_path=workflow_path, + should_accept_path=should_accept_path, + ) + + # Wait for all results + return self.wait(count=len(data)) + + def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> DistributedBatch: + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + ####################################################### + # The following logic is copied from WorkflowExecutor # + ####################################################### + if not hasattr(self, "data_generator"): + self.data_generator = cycle_dataloader(dataloader) + assert dataloader.batch_size is not None + while True: + # Submit at least two batches to allow maximum overlap + if ( + self.get_capacity() + dataloader.batch_size > 0 + and self.runner.get_input_queue_size() + dataloader.batch_size + < self.runner.max_queue_size + ): + data = next(self.data_generator) + for item in data: + try: + self.submit( + item, + workflow_kwargs=workflow_kwargs, + workflow_path=workflow_path, + should_accept_path=should_accept_path, + ) + except queue.Full: + # Capacity exhausted during batch submission, stop and wait + break + try: + return self.wait(dataloader.batch_size, timeout=1) + except TimeoutError: + pass + + async def agenerate(self, req: ModelRequest) -> ModelResponse: + """Asynchronously generate a response for the given request. + + Parameters + ---------- + req : ModelRequest + Model request containing input data and generation parameters + + Returns + ------- + ModelResponse + Generated response from the model + """ + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + # Choose worker and delegate + worker = self._choose_worker() + + # Call agenerate on engine via scheduler + return await self.scheduler.async_call_engine( + worker_id=worker.id, + method="agenerate", + req=req, + ) + + def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: + """Initialize weight update process group for all workers. + + Parameters + ---------- + meta : WeightUpdateMeta + Metadata containing weight update information + + Returns + ------- + Future[None] + Future representing the async initialization operation + """ + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + async def _init_all_workers(): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="init_weights_update_group", + meta=meta, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + + def init_all_workers(): + asyncio.run(_init_all_workers()) + + return self.executor.submit(init_all_workers) + + def update_weights_from_distributed( + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] + ) -> Future[None]: + """Update weights from distributed memory for all workers. + + Parameters + ---------- + meta : WeightUpdateMeta + Metadata containing weight update information + param_specs : list[ParamSpec] + Parameter specifications for weights to update + + Returns + ------- + Future[None] + Future representing the async update operation + """ + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + async def _update_all_workers(): + tasks = [ + self.scheduler.call_engine( + worker_id=worker.id, + method="update_weights_from_distributed", + meta=meta, + param_specs=param_specs, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + + def update_all_workers(): + asyncio.run(_update_all_workers()) + + return self.executor.submit(update_all_workers) + + def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: + """Update weights from disk for all workers. + + Parameters + ---------- + meta : WeightUpdateMeta + Metadata containing weight update information + + Returns + ------- + Future[None] + Future representing the async update operation + """ + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + async def _update_all_workers(): + tasks = [ + self.scheduler.call_engine( + worker_id=worker.id, + method="update_weights_from_disk", + meta=meta, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + + def update_all_workers(): + asyncio.run(_update_all_workers) + + return self.executor.submit(update_all_workers) + + def set_version(self, version: int) -> None: + """Set the current weight version for all workers. + + Parameters + ---------- + version : int + Weight version number to set + """ + if not self._initialized: + raise RuntimeError("RolloutController not initialized") + + self._version = version + for worker in self.workers: + try: + self.scheduler.call_engine( + worker_id=worker.id, + method="set_version", + version=version, + ) + except Exception as e: + self.logger.error(f"Error setting version for worker {worker.id}: {e}") + + def get_version(self) -> int: + """Get the current weight version. + + Returns + ------- + int + Current weight version number + """ + return self._version + + def pause(self): + """Pause request submission for async rollout on all workers.""" + if not self._initialized: + return + + for worker in self.workers: + try: + self.scheduler.call_engine( + worker_id=worker.id, + method="pause", + ) + except Exception as e: + self.logger.error(f"Error pausing worker {worker.id}: {e}") + + def resume(self): + """Resume request submission for async rollout on all workers.""" + if not self._initialized: + return + + for worker in self.workers: + try: + self.scheduler.call_engine( + worker_id=worker.id, + method="resume", + ) + except Exception as e: + self.logger.error(f"Error resuming worker {worker.id}: {e}") + + def register_callback_to_all_worker( + self, method: str, callback: Callable, **kwargs + ): + raise NotImplementedError() + + def abort_all_requests(self) -> None: + raise NotImplementedError() diff --git a/areal/scheduler/local_scheduler.py b/areal/scheduler/local_scheduler.py index 693dd47db..8755c0996 100644 --- a/areal/scheduler/local_scheduler.py +++ b/areal/scheduler/local_scheduler.py @@ -832,16 +832,21 @@ async def async_call_engine( if worker_info is None: raise WorkerNotFoundError(worker_id) - # Build JSON payload - payload = { - "method": method, - "args": list(args), - "kwargs": kwargs, - } - - # Retry logic with exponential backoff + # Route to different endpoint based on method port = int(worker_info.worker.ports[0]) - url = f"http://{worker_info.worker.ip}:{port}/call" + if method == "run_workflow": + # Special routing for workflow execution + url = f"http://{worker_info.worker.ip}:{port}/run_workflow" + payload = kwargs + else: + # Standard engine method call + url = f"http://{worker_info.worker.ip}:{port}/call" + payload = { + "method": method, + "args": list(args), + "kwargs": kwargs, + } + last_error = None for attempt in range(1, max_retries + 1): diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index 3ea1574f9..7d4fb22fe 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -49,6 +49,7 @@ async def lifespan(app: FastAPI): default_response_class=ORJSONResponse, lifespan=lifespan, ) +app._expected_trajectory_keys = None @app.get("/health") @@ -214,6 +215,139 @@ async def call_engine_method(request: Request): raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") +@app.post("/run_workflow") +async def run_workflow(request: Request): + """ + Run a workflow's arun_episode method directly without using the engine. + + Expected JSON payload: + { + "workflow": "areal.api.workflow_api.RolloutWorkflow", # Import path + "workflow_kwargs": {...}, # Keyword arguments for workflow instantiation + "data": {...} # Data to pass to arun_episode + } + """ + try: + body = await request.body() + data = orjson.loads(body) + + workflow_path = data.get("workflow") + workflow_kwargs = data.get("workflow_kwargs") + episode_data = data.get("data") + should_accept_path = data.get("should_accept_path", None) + check_trajectory_format = data.get("check_trajectory_format") + + if not workflow_path: + raise HTTPException( + status_code=400, detail="Missing 'workflow' field in request" + ) + + if episode_data is None: + raise HTTPException( + status_code=400, detail="Missing 'data' field in request" + ) + + # Dynamic import workflow + try: + module_path, class_name = workflow_path.rsplit(".", 1) + module = importlib.import_module(module_path) + workflow_class = getattr(module, class_name) + logger.info(f"Imported workflow class: {workflow_path}") + except (ValueError, ImportError, AttributeError) as e: + logger.error(f"Failed to import workflow '{workflow_path}': {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to import workflow '{workflow_path}': {str(e)}", + ) + # Instantiate workflow + try: + workflow = workflow_class(**workflow_kwargs) + logger.info(f"Workflow '{workflow_path}' instantiated successfully") + except Exception as e: + logger.error( + f"Failed to instantiate workflow: {e}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=500, + detail=f"Failed to instantiate workflow: {str(e)}", + ) + + should_accept = None + if should_accept_path is not None: + # Dynamic import filtering function + try: + module_path, fn_name = should_accept_path.rsplit(".", 1) + module = importlib.import_module(module_path) + should_accept = getattr(module, fn_name) + logger.info(f"Imported filtering function: {should_accept_path}") + except (ValueError, ImportError, AttributeError) as e: + logger.error( + f"Failed to import filtering function '{should_accept_path}': {e}" + ) + raise HTTPException( + status_code=400, + detail=f"Failed to import filtering function '{should_accept_path}': {str(e)}", + ) + + # Run episode + try: + traj = await workflow.arun_episode(episode_data) + + global app + if check_trajectory_format and traj is not None: + from areal.core.workflow_executor import ( + check_trajectory_format as check_fn, + ) + + check_fn( + traj, + expected_keys=app._expected_trajectory_keys, + logger=logger, + ) + # Track expected keys for consistency checking + if isinstance(traj, dict) and "input_ids" in traj: + if app._expected_trajectory_keys is None: + app._expected_trajectory_keys = set(traj.keys()) + logger.info( + f"Trajectory format check: tracking keys " + f"{app._expected_trajectory_keys}" + ) + + from areal.experimental.openai.types import InteractionWithTokenLogpReward + from areal.utils.data import concat_padded_tensors + + # Convert InteractionWithTokenLogpReward to tensor dict if needed + if isinstance(traj, dict) and all( + isinstance(v, InteractionWithTokenLogpReward) for v in traj.values() + ): + traj = concat_padded_tensors( + [v.to_tensor_dict() for v in traj.values()] + ) + + assert traj is None or isinstance(traj, dict), traj + + # Apply should_accept filtering + accept_this = traj is not None and ( + should_accept is None or should_accept(traj) + ) + if accept_this: + return {"status": "success", "result": traj} + else: + return {"status": "success", "result": None} + except Exception as e: + logger.error(f"Workflow arun_episode failed: {e}\n{traceback.format_exc()}") + raise HTTPException( + status_code=500, + detail=f"Workflow arun_episode failed: {str(e)}", + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error in run_workflow: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + def main(): """Main entry point for the RPC server.""" parser = argparse.ArgumentParser(description="AReaL Worker RPC Server") From f67dd6091847d9c2b058d3564c216dc3295e11b4 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 09:42:26 +0800 Subject: [PATCH 10/52] add tensor serialization --- areal/controller/rollout_controller.py | 129 ++------ areal/scheduler/local_scheduler.py | 26 +- areal/scheduler/rpc/rpc_server.py | 24 +- areal/scheduler/rpc/serialization.py | 213 ++++++++++++++ areal/tests/test_serialization.py | 391 +++++++++++++++++++++++++ areal/tests/utils.py | 24 ++ 6 files changed, 684 insertions(+), 123 deletions(-) create mode 100644 areal/scheduler/rpc/serialization.py create mode 100644 areal/tests/test_serialization.py diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index 7680f6f43..f162539af 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -9,10 +9,11 @@ from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import Any from torchdata.stateful_dataloader import StatefulDataLoader +from areal.api.alloc_mode import AllocationMode from areal.api.cli_args import InferenceEngineConfig from areal.api.controller_api import DistributedBatch from areal.api.controller_api import RolloutController as RolloutControllerAPI @@ -25,8 +26,7 @@ from areal.utils import logging from areal.utils.data import cycle_dataloader -if TYPE_CHECKING: - pass +CREATE_WORKER_TIMEOUT = 60.0 @dataclass @@ -77,7 +77,6 @@ def __init__( # Worker management self.workers: list[Worker] = [] # List of Worker objects from scheduler - self.num_workers = 0 self._worker_role = "rollout" # Role name for workers # Round-robin scheduling @@ -93,7 +92,6 @@ def __init__( self.logger = None # State - self._initialized = False self._version = 0 # Staleness management @@ -102,61 +100,28 @@ def __init__( _RemoteRolloutTaskInput ] = [] # Queue for inputs waiting for capacity - def initialize(self, num_workers: int = 1, *args, **kwargs): - """Initialize the controller by creating workers and deploying engines. - - Parameters - ---------- - num_workers : int - Number of worker instances to create (default: 1) - *args - Additional positional arguments - **kwargs - Additional keyword arguments including: - - scheduling_config: SchedulingConfig for worker creation - - engine_init_args: List of args to pass to engine initialization - - engine_init_kwargs: Dict of kwargs to pass to engine initialization - - timeout: Timeout for worker creation - """ - if self._initialized: - self.logger.warning("RolloutController already initialized, skipping...") - return - + def initialize( + self, + alloc_mode: AllocationMode, + ): self.logger = logging.getLogger("[RolloutController]") - self.logger.info( - f"Initializing RolloutController with {num_workers} workers..." - ) - - self.num_workers = num_workers # Get scheduling config from kwargs or use defaults - scheduling_config = kwargs.get( - "scheduling_config", - SchedulingConfig(replicas=num_workers), - ) - - # Get engine initialization parameters - engine_init_args = kwargs.get("engine_init_args", []) - engine_init_kwargs = kwargs.get("engine_init_kwargs", {}) - timeout = kwargs.get("timeout", 60.0) + # FIXME: Should get scheduling config in a more strategical way + scheduling_config = SchedulingConfig(replicas=alloc_mode.gen.dp_size) # Use asyncio.run to call async scheduler methods synchronously - asyncio.run( - self._async_initialize( - scheduling_config, - engine_init_args, - engine_init_kwargs, - timeout, - ) - ) + asyncio.run(self._async_initialize(scheduling_config)) # Initialize AsyncTaskRunner for task execution - max_queue_size = getattr(self.config, "max_queue_size", 1024) - self.runner = AsyncTaskRunner(max_queue_size=max_queue_size) + self.runner = AsyncTaskRunner( + max_queue_size=self.config.queue_size, + enable_tracing=self.config.enable_rollout_tracing, + ) self.runner.initialize(logger=self.logger) # Initialize thread pool for weight updates - self.executor = ThreadPoolExecutor(max_workers=num_workers) + self.executor = ThreadPoolExecutor(max_workers=alloc_mode.gen.dp_size) # Initialize staleness manager for global capacity control max_concurrent_rollouts = ( @@ -170,29 +135,10 @@ def initialize(self, num_workers: int = 1, *args, **kwargs): max_staleness=self.config.max_head_offpolicyness, ) - self._initialized = True - self.logger.info(f"RolloutController initialized with {num_workers} workers") - async def _async_initialize( self, scheduling_config: SchedulingConfig, - engine_init_args: list, - engine_init_kwargs: dict, - timeout: float, ): - """Async helper to initialize workers and engines. - - Parameters - ---------- - scheduling_config : SchedulingConfig - Configuration for worker creation - engine_init_args : list - Positional arguments for engine initialization - engine_init_kwargs : dict - Keyword arguments for engine initialization - timeout : float - Timeout for worker readiness - """ # Create workers via scheduler self.logger.info("Creating workers via scheduler...") worker_ids = self.scheduler.create_workers( @@ -205,7 +151,7 @@ async def _async_initialize( self.logger.info("Waiting for workers to be ready...") self.workers = self.scheduler.get_workers( role=self._worker_role, - timeout=timeout, + timeout=CREATE_WORKER_TIMEOUT, ) self.logger.info(f"Workers ready: {[w.id for w in self.workers]}") @@ -221,19 +167,12 @@ async def _async_initialize( await self.scheduler.create_engine( worker_id=worker.id, engine=engine_path, - init_args=engine_init_args, - init_kwargs={ - **engine_init_kwargs, - "engine_id": f"worker_{i}", - }, + init_kwargs=dict(config=self.config), ) self.logger.info(f"Engine created on worker {worker.id}") def destroy(self): """Destroy the controller and clean up resources.""" - if not self._initialized: - return - self.logger.info("Destroying RolloutController...") # Destroy task runner @@ -255,7 +194,6 @@ def destroy(self): self.executor.shutdown(wait=True) self.executor = None - self._initialized = False self.logger.info("RolloutController destroyed") def get_capacity(self) -> int: @@ -266,8 +204,6 @@ def get_capacity(self) -> int: int Number of new rollout slots available based on staleness constraints """ - if not self._initialized: - return 0 version = self.get_version() # Use controller's global version return self.staleness_manager.get_capacity(version) @@ -279,11 +215,9 @@ def _choose_worker(self) -> Worker: Worker The chosen worker object """ - if self.num_workers == 0: - raise RuntimeError("No workers available") worker = self.workers[self._current_worker_idx] - self._current_worker_idx = (self._current_worker_idx + 1) % self.num_workers + self._current_worker_idx = (self._current_worker_idx + 1) % len(self.workers) return worker async def _run_workflow_on_worker( @@ -338,9 +272,6 @@ def submit( workflow_kwargs: dict[str, Any], should_accept_path: str | None = None, ) -> None: - if not self._initialized: - raise RuntimeError("RolloutController not initialized") - # Add to pending queue (will be submitted when capacity allows) self._pending_inputs.append( _RemoteRolloutTaskInput( @@ -383,9 +314,6 @@ def _commit_one_to_runner(self): ) def wait(self, count: int, timeout: float | None = None) -> DistributedBatch: - if not self._initialized: - raise RuntimeError("RolloutController not initialized") - ####################################################### # The following logic is copied from WorkflowExecutor # ####################################################### @@ -462,9 +390,6 @@ def rollout_batch( workflow_kwargs: dict[str, Any], should_accept_path: str | None = None, ) -> DistributedBatch: - if not self._initialized: - raise RuntimeError("RolloutController not initialized") - # Submit all requests for item in data: self.submit( @@ -484,9 +409,6 @@ def prepare_batch( workflow_kwargs: dict[str, Any], should_accept_path: str | None = None, ) -> DistributedBatch: - if not self._initialized: - raise RuntimeError("RolloutController not initialized") - ####################################################### # The following logic is copied from WorkflowExecutor # ####################################################### @@ -530,8 +452,6 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: ModelResponse Generated response from the model """ - if not self._initialized: - raise RuntimeError("RolloutController not initialized") # Choose worker and delegate worker = self._choose_worker() @@ -556,8 +476,6 @@ def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: Future[None] Future representing the async initialization operation """ - if not self._initialized: - raise RuntimeError("RolloutController not initialized") async def _init_all_workers(): tasks = [ @@ -592,8 +510,6 @@ def update_weights_from_distributed( Future[None] Future representing the async update operation """ - if not self._initialized: - raise RuntimeError("RolloutController not initialized") async def _update_all_workers(): tasks = [ @@ -625,8 +541,6 @@ def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: Future[None] Future representing the async update operation """ - if not self._initialized: - raise RuntimeError("RolloutController not initialized") async def _update_all_workers(): tasks = [ @@ -652,8 +566,6 @@ def set_version(self, version: int) -> None: version : int Weight version number to set """ - if not self._initialized: - raise RuntimeError("RolloutController not initialized") self._version = version for worker in self.workers: @@ -678,8 +590,6 @@ def get_version(self) -> int: def pause(self): """Pause request submission for async rollout on all workers.""" - if not self._initialized: - return for worker in self.workers: try: @@ -692,9 +602,6 @@ def pause(self): def resume(self): """Resume request submission for async rollout on all workers.""" - if not self._initialized: - return - for worker in self.workers: try: self.scheduler.call_engine( diff --git a/areal/scheduler/local_scheduler.py b/areal/scheduler/local_scheduler.py index 8755c0996..dbcb82794 100644 --- a/areal/scheduler/local_scheduler.py +++ b/areal/scheduler/local_scheduler.py @@ -27,6 +27,7 @@ WorkerNotFoundError, WorkerTimeoutError, ) +from areal.scheduler.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging from areal.utils.network import find_free_ports, gethostip @@ -732,11 +733,15 @@ def call_engine( if worker_info is None: raise WorkerNotFoundError(worker_id) + # Serialize args and kwargs (convert tensors to SerializedTensor dicts) + serialized_args = serialize_value(list(args)) + serialized_kwargs = serialize_value(kwargs) + # Build JSON payload payload = { "method": method, - "args": list(args), - "kwargs": kwargs, + "args": serialized_args, + "kwargs": serialized_kwargs, } # Retry logic with exponential backoff @@ -837,18 +842,23 @@ async def async_call_engine( if method == "run_workflow": # Special routing for workflow execution url = f"http://{worker_info.worker.ip}:{port}/run_workflow" - payload = kwargs + # Serialize kwargs for workflow execution + payload = serialize_value(kwargs) else: # Standard engine method call url = f"http://{worker_info.worker.ip}:{port}/call" + # Serialize args and kwargs + serialized_args = serialize_value(list(args)) + serialized_kwargs = serialize_value(kwargs) payload = { "method": method, - "args": list(args), - "kwargs": kwargs, + "args": serialized_args, + "kwargs": serialized_kwargs, } last_error = None + print(url) for attempt in range(1, max_retries + 1): # Check worker health before each attempt if worker_info.process.poll() is not None: @@ -870,6 +880,7 @@ async def async_call_engine( headers={"Content-Type": "application/json"}, timeout=7200.0, # 2 hours for long-running operations ) + print(response, payload, response.json()) result, should_retry, error_msg = self._handle_call_response( response, worker_id, method, attempt @@ -962,7 +973,10 @@ def _handle_call_response( - error_message: Error message if failed, None if successful """ if response.status_code == 200: - return response.json().get("result"), False, None + result = response.json().get("result") + # Deserialize result (convert SerializedTensor dicts back to tensors) + deserialized_result = deserialize_value(result) + return deserialized_result, False, None elif response.status_code == 400: # Bad request (e.g., method doesn't exist) - don't retry error_detail = response.json().get("detail", "Unknown error") diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index 7d4fb22fe..1b5efd057 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -15,6 +15,7 @@ from fastapi.responses import ORJSONResponse from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.scheduler.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging logger = logging.getLogger("RPCServer") @@ -181,6 +182,10 @@ async def call_engine_method(request: Request): status_code=400, detail="Missing 'method' field in request" ) + # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) + args = deserialize_value(args) + kwargs = deserialize_value(kwargs) + # Call method directly (no need for hasattr/getattr with typed engine) logger.info(f"Calling engine method: {method_name}") try: @@ -188,10 +193,9 @@ async def call_engine_method(request: Request): method = getattr(_engine, method_name) result = method(*args, **kwargs) - # Serialize result - # Note: This assumes the result is JSON-serializable - # For complex types (tensors, etc.), you may need custom serialization - return {"status": "success", "result": result} + # Serialize result (convert tensors to SerializedTensor dicts) + serialized_result = serialize_value(result) + return {"status": "success", "result": serialized_result} except AttributeError as e: logger.error(f"Method '{method_name}' not found on engine: {e}") @@ -247,6 +251,9 @@ async def run_workflow(request: Request): status_code=400, detail="Missing 'data' field in request" ) + # Deserialize episode_data (may contain tensors) + episode_data = deserialize_value(episode_data) + # Dynamic import workflow try: module_path, class_name = workflow_path.rsplit(".", 1) @@ -291,7 +298,8 @@ async def run_workflow(request: Request): # Run episode try: - traj = await workflow.arun_episode(episode_data) + global _engine + traj = await workflow.arun_episode(_engine, episode_data) global app if check_trajectory_format and traj is not None: @@ -330,8 +338,12 @@ async def run_workflow(request: Request): accept_this = traj is not None and ( should_accept is None or should_accept(traj) ) + print(">>>>>>>>>", traj, accept_this, flush=True) + + # Serialize trajectory result (convert tensors to SerializedTensor dicts) if accept_this: - return {"status": "success", "result": traj} + serialized_traj = serialize_value(traj) + return {"status": "success", "result": serialized_traj} else: return {"status": "success", "result": None} except Exception as e: diff --git a/areal/scheduler/rpc/serialization.py b/areal/scheduler/rpc/serialization.py new file mode 100644 index 000000000..afa423ad4 --- /dev/null +++ b/areal/scheduler/rpc/serialization.py @@ -0,0 +1,213 @@ +"""Tensor serialization utilities for RPC communication. + +This module provides utilities to serialize and deserialize PyTorch tensors +for transmission over HTTP/JSON. Tensors are encoded as base64 strings with +metadata stored in Pydantic models. + +Assumptions: +- All tensors are on CPU +- Gradient tracking (requires_grad) is not preserved +""" + +import base64 +from typing import Any, Literal + +import torch +from pydantic import BaseModel, Field + + +class SerializedTensor(BaseModel): + """Pydantic model for serialized tensor with metadata. + + Attributes + ---------- + type : str + Type marker, always "tensor" + data : str + Base64-encoded tensor data + shape : List[int] + Tensor shape + dtype : str + String representation of dtype (e.g., "torch.float32") + """ + + type: Literal["tensor"] = Field(default="tensor") + data: str + shape: list[int] + dtype: str + + @classmethod + def from_tensor(cls, tensor: torch.Tensor) -> "SerializedTensor": + """Create SerializedTensor from a PyTorch tensor. + + Assumes tensor is on CPU or will be moved to CPU for serialization. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor to serialize + + Returns + ------- + SerializedTensor + Serialized tensor with metadata + """ + # Move to CPU for serialization (detach to avoid gradient tracking) + cpu_tensor = tensor.detach().cpu() + + # Convert to bytes and encode as base64 + buffer = cpu_tensor.numpy().tobytes() + data_b64 = base64.b64encode(buffer).decode("utf-8") + + return cls( + data=data_b64, + shape=list(tensor.shape), + dtype=str(tensor.dtype), + ) + + def to_tensor(self) -> torch.Tensor: + """Reconstruct PyTorch tensor from serialized data. + + Returns CPU tensor without gradient tracking. + + Returns + ------- + torch.Tensor + Reconstructed CPU tensor + """ + # Decode base64 to bytes + buffer = base64.b64decode(self.data.encode("utf-8")) + + # Parse dtype string (e.g., "torch.float32" -> torch.float32) + dtype_str = self.dtype.replace("torch.", "") + dtype = getattr(torch, dtype_str) + + # Reconstruct tensor from bytes + import numpy as np + + np_array = np.frombuffer(buffer, dtype=self._torch_dtype_to_numpy(dtype)) + # Copy the array to make it writable before converting to tensor + np_array = np_array.copy() + tensor = torch.from_numpy(np_array).reshape(self.shape) + + # Cast to correct dtype (numpy might have different dtype) + tensor = tensor.to(dtype) + + return tensor + + @staticmethod + def _torch_dtype_to_numpy(torch_dtype: torch.dtype): + """Convert torch dtype to numpy dtype for buffer reading. + + Parameters + ---------- + torch_dtype : torch.dtype + PyTorch data type + + Returns + ------- + numpy.dtype + Corresponding NumPy data type + """ + import numpy as np + + dtype_map = { + torch.float32: np.float32, + torch.float64: np.float64, + torch.float16: np.float16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.int16: np.int16, + torch.int8: np.int8, + torch.uint8: np.uint8, + torch.bool: np.bool_, + } + return dtype_map.get(torch_dtype, np.float32) + + +def serialize_value(value: Any) -> Any: + """Recursively serialize a value, converting tensors to SerializedTensor dicts. + + This function transparently handles: + - torch.Tensor -> SerializedTensor dict (CPU only, no gradient tracking) + - dict -> recursively serialize values + - list/tuple -> recursively serialize elements + - primitives (int, float, str, bool, None) -> unchanged + + Parameters + ---------- + value : Any + Value to serialize (can be nested structure) + + Returns + ------- + Any + Serialized value (JSON-compatible with SerializedTensor dicts) + """ + # Handle None + if value is None: + return None + + # Handle torch.Tensor + if isinstance(value, torch.Tensor): + return SerializedTensor.from_tensor(value).model_dump() + + # Handle dict - recursively serialize values + if isinstance(value, dict): + return {key: serialize_value(val) for key, val in value.items()} + + # Handle list - recursively serialize elements + if isinstance(value, list): + return [serialize_value(item) for item in value] + + # Handle tuple - convert to list and recursively serialize + if isinstance(value, tuple): + return [serialize_value(item) for item in value] + + # Primitives (int, float, str, bool) pass through unchanged + return value + + +def deserialize_value(value: Any) -> Any: + """Recursively deserialize a value, converting SerializedTensor dicts back to tensors. + + This function transparently handles: + - SerializedTensor dict -> torch.Tensor (CPU, no gradient tracking) + - dict -> recursively deserialize values + - list -> recursively deserialize elements + - primitives -> unchanged + + Parameters + ---------- + value : Any + Value to deserialize (potentially containing SerializedTensor dicts) + + Returns + ------- + Any + Deserialized value with torch.Tensor objects restored + """ + # Handle None + if value is None: + return None + + # Handle dict - check if it's a SerializedTensor + if isinstance(value, dict): + # Check for SerializedTensor marker + if value.get("type") == "tensor": + try: + serialized_tensor = SerializedTensor.model_validate(value) + return serialized_tensor.to_tensor() + except Exception: + # If parsing fails, treat as regular dict + pass + + # Regular dict - recursively deserialize values + return {key: deserialize_value(val) for key, val in value.items()} + + # Handle list - recursively deserialize elements + if isinstance(value, list): + return [deserialize_value(item) for item in value] + + # Primitives pass through unchanged + return value diff --git a/areal/tests/test_serialization.py b/areal/tests/test_serialization.py new file mode 100644 index 000000000..0ee0c29bb --- /dev/null +++ b/areal/tests/test_serialization.py @@ -0,0 +1,391 @@ +"""Pytest test suite for tensor serialization utilities.""" + +import pytest +import torch + +from areal.scheduler.rpc.serialization import ( + SerializedTensor, + deserialize_value, + serialize_value, +) + + +class TestSerializedTensor: + """Test suite for SerializedTensor Pydantic model.""" + + def test_from_tensor_float32(self): + """Test serialization of float32 tensor.""" + tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + serialized = SerializedTensor.from_tensor(tensor) + + assert serialized.type == "tensor" + assert serialized.shape == [3] + assert serialized.dtype == "torch.float32" + + def test_from_tensor_various_dtypes(self): + """Test serialization of tensors with various dtypes.""" + dtypes = [ + torch.float32, + torch.float64, + torch.int32, + torch.int64, + torch.bool, + torch.uint8, + ] + + for dtype in dtypes: + tensor = torch.tensor([1, 2, 3], dtype=dtype) + serialized = SerializedTensor.from_tensor(tensor) + assert serialized.dtype == str(dtype) + + def test_from_tensor_various_shapes(self): + """Test serialization of tensors with various shapes.""" + shapes = [ + (), # scalar + (5,), # 1D + (3, 4), # 2D + (2, 3, 4), # 3D + (2, 3, 4, 5), # 4D + ] + + for shape in shapes: + tensor = torch.randn(shape) + serialized = SerializedTensor.from_tensor(tensor) + assert serialized.shape == list(shape) + + def test_from_tensor_with_requires_grad(self): + """Test serialization ignores requires_grad flag.""" + tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + serialized = SerializedTensor.from_tensor(tensor) + # Serialization should work but requires_grad is not preserved + assert serialized.type == "tensor" + + def test_roundtrip_float32(self): + """Test serialize-deserialize roundtrip for float32 tensor.""" + original = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + serialized = SerializedTensor.from_tensor(original) + reconstructed = serialized.to_tensor() + + assert torch.allclose(original, reconstructed) + assert reconstructed.dtype == original.dtype + assert reconstructed.shape == original.shape + + def test_roundtrip_various_dtypes(self): + """Test roundtrip for various dtypes.""" + test_cases = [ + (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), torch.float32), + (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64), torch.float64), + (torch.tensor([1, 2, 3], dtype=torch.int32), torch.int32), + (torch.tensor([1, 2, 3], dtype=torch.int64), torch.int64), + (torch.tensor([True, False, True], dtype=torch.bool), torch.bool), + ] + + for original, expected_dtype in test_cases: + serialized = SerializedTensor.from_tensor(original) + reconstructed = serialized.to_tensor() + + assert reconstructed.dtype == expected_dtype + if expected_dtype == torch.bool: + assert torch.equal(original, reconstructed) + else: + assert torch.allclose(original.float(), reconstructed.float()) + + def test_roundtrip_ignores_requires_grad(self): + """Test roundtrip does not preserve requires_grad.""" + original = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + serialized = SerializedTensor.from_tensor(original) + reconstructed = serialized.to_tensor() + + # requires_grad is not preserved + assert reconstructed.requires_grad is False + assert torch.allclose(original.detach(), reconstructed) + + def test_empty_tensor(self): + """Test serialization of empty tensor.""" + tensor = torch.tensor([]) + serialized = SerializedTensor.from_tensor(tensor) + reconstructed = serialized.to_tensor() + + assert reconstructed.shape == torch.Size([0]) + assert torch.equal(tensor, reconstructed) + + def test_large_tensor(self): + """Test serialization of large tensor.""" + tensor = torch.randn(100, 100) + serialized = SerializedTensor.from_tensor(tensor) + reconstructed = serialized.to_tensor() + + assert torch.allclose(tensor, reconstructed) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_cuda_tensor(self): + """Test serialization of CUDA tensor (moves to CPU).""" + tensor = torch.tensor([1.0, 2.0, 3.0]).cuda() + serialized = SerializedTensor.from_tensor(tensor) + + # Serialization works with CUDA tensors + assert serialized.type == "tensor" + + # Reconstructed tensor is always on CPU + reconstructed = serialized.to_tensor() + assert reconstructed.device.type == "cpu" + assert torch.allclose(tensor.cpu(), reconstructed) + + +class TestSerializeValue: + """Test suite for serialize_value function.""" + + def test_serialize_none(self): + """Test serialization of None.""" + assert serialize_value(None) is None + + def test_serialize_primitives(self): + """Test serialization of primitive types.""" + assert serialize_value(42) == 42 + assert serialize_value(3.14) == 3.14 + assert serialize_value("hello") == "hello" + assert serialize_value(True) is True + + def test_serialize_tensor(self): + """Test serialization of torch tensor.""" + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = serialize_value(tensor) + + assert isinstance(result, dict) + assert result["type"] == "tensor" + assert "data" in result + assert "shape" in result + assert "dtype" in result + + def test_serialize_list_of_primitives(self): + """Test serialization of list of primitives.""" + original = [1, 2, 3, "hello", True] + result = serialize_value(original) + + assert result == original + + def test_serialize_list_of_tensors(self): + """Test serialization of list of tensors.""" + tensors = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] + result = serialize_value(tensors) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(item["type"] == "tensor" for item in result) + + def test_serialize_dict_of_primitives(self): + """Test serialization of dict with primitives.""" + original = {"a": 1, "b": "hello", "c": 3.14} + result = serialize_value(original) + + assert result == original + + def test_serialize_dict_of_tensors(self): + """Test serialization of dict with tensors.""" + tensors = { + "input_ids": torch.tensor([1, 2, 3]), + "attention_mask": torch.tensor([1, 1, 1]), + } + result = serialize_value(tensors) + + assert isinstance(result, dict) + assert all(result[key]["type"] == "tensor" for key in result) + + def test_serialize_nested_dict(self): + """Test serialization of nested dictionary.""" + nested = { + "level1": {"level2": {"tensor": torch.tensor([1.0, 2.0]), "value": 42}} + } + result = serialize_value(nested) + + assert isinstance(result, dict) + assert isinstance(result["level1"], dict) + assert isinstance(result["level1"]["level2"], dict) + assert result["level1"]["level2"]["tensor"]["type"] == "tensor" + assert result["level1"]["level2"]["value"] == 42 + + def test_serialize_mixed_structure(self): + """Test serialization of complex mixed structure.""" + mixed = { + "tensors": [torch.tensor([1.0]), torch.tensor([2.0])], + "metadata": {"batch_size": 32, "device": "cpu"}, + "mask": torch.tensor([True, False, True]), + } + result = serialize_value(mixed) + + assert isinstance(result["tensors"], list) + assert result["tensors"][0]["type"] == "tensor" + assert result["metadata"]["batch_size"] == 32 + assert result["mask"]["type"] == "tensor" + + def test_serialize_tuple(self): + """Test serialization of tuple (converts to list).""" + original = (1, 2, torch.tensor([3.0])) + result = serialize_value(original) + + assert isinstance(result, list) + assert result[0] == 1 + assert result[1] == 2 + assert result[2]["type"] == "tensor" + + +class TestDeserializeValue: + """Test suite for deserialize_value function.""" + + def test_deserialize_none(self): + """Test deserialization of None.""" + assert deserialize_value(None) is None + + def test_deserialize_primitives(self): + """Test deserialization of primitive types.""" + assert deserialize_value(42) == 42 + assert deserialize_value(3.14) == 3.14 + assert deserialize_value("hello") == "hello" + assert deserialize_value(True) is True + + def test_deserialize_tensor(self): + """Test deserialization of serialized tensor.""" + tensor = torch.tensor([1.0, 2.0, 3.0]) + serialized = serialize_value(tensor) + result = deserialize_value(serialized) + + assert isinstance(result, torch.Tensor) + assert torch.allclose(tensor, result) + + def test_deserialize_list_of_primitives(self): + """Test deserialization of list of primitives.""" + original = [1, 2, 3, "hello", True] + result = deserialize_value(original) + + assert result == original + + def test_deserialize_list_of_tensors(self): + """Test deserialization of list of tensors.""" + tensors = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] + serialized = serialize_value(tensors) + result = deserialize_value(serialized) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, torch.Tensor) for item in result) + assert torch.allclose(tensors[0], result[0]) + assert torch.allclose(tensors[1], result[1]) + + def test_deserialize_dict_of_tensors(self): + """Test deserialization of dict with tensors.""" + original = { + "input_ids": torch.tensor([1, 2, 3]), + "attention_mask": torch.tensor([1, 1, 1]), + } + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result, dict) + assert all(isinstance(result[key], torch.Tensor) for key in result) + assert torch.equal(original["input_ids"], result["input_ids"]) + assert torch.equal(original["attention_mask"], result["attention_mask"]) + + def test_deserialize_nested_structure(self): + """Test deserialization of nested structure.""" + original = { + "level1": { + "tensor": torch.tensor([1.0, 2.0]), + "value": 42, + } + } + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result["level1"]["tensor"], torch.Tensor) + assert result["level1"]["value"] == 42 + assert torch.allclose(original["level1"]["tensor"], result["level1"]["tensor"]) + + def test_deserialize_invalid_tensor_dict(self): + """Test deserialization handles invalid tensor dict gracefully.""" + # Dict with type="tensor" but missing required fields + invalid = {"type": "tensor", "invalid_field": "value"} + result = deserialize_value(invalid) + + # Should treat as regular dict if parsing fails + assert isinstance(result, dict) + assert result["type"] == "tensor" + + +class TestRoundtrip: + """Test suite for full serialize-deserialize roundtrips.""" + + def test_roundtrip_simple_tensor(self): + """Test roundtrip for simple tensor.""" + original = torch.tensor([1.0, 2.0, 3.0]) + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert torch.allclose(original, result) + + def test_roundtrip_trajectory_dict(self): + """Test roundtrip for typical trajectory dictionary.""" + trajectory = { + "input_ids": torch.tensor([101, 102, 103, 104]), + "attention_mask": torch.tensor([1, 1, 1, 1]), + "rewards": torch.tensor([0.1, 0.2, 0.3, 0.4]), + "logprobs": torch.tensor([-1.0, -2.0, -3.0, -4.0]), + } + + serialized = serialize_value(trajectory) + result = deserialize_value(serialized) + + assert isinstance(result, dict) + assert set(result.keys()) == set(trajectory.keys()) + for key in trajectory: + assert torch.allclose(trajectory[key].float(), result[key].float()) + + def test_roundtrip_mixed_types(self): + """Test roundtrip for mixed type structure.""" + original = { + "tensors": [torch.tensor([1.0]), torch.tensor([2.0])], + "metadata": {"count": 2, "name": "test"}, + "value": 42, + "flag": True, + } + + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert len(result["tensors"]) == 2 + assert torch.allclose(original["tensors"][0], result["tensors"][0]) + assert result["metadata"] == original["metadata"] + assert result["value"] == original["value"] + assert result["flag"] == original["flag"] + + def test_roundtrip_with_none_values(self): + """Test roundtrip with None values in structure.""" + original = { + "tensor": torch.tensor([1.0, 2.0]), + "optional": None, + "nested": {"value": 42, "empty": None}, + } + + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert torch.allclose(original["tensor"], result["tensor"]) + assert result["optional"] is None + assert result["nested"]["value"] == 42 + assert result["nested"]["empty"] is None + + def test_roundtrip_empty_structures(self): + """Test roundtrip for empty structures.""" + test_cases = [ + {}, # Empty dict + [], # Empty list + {"empty_list": [], "empty_dict": {}}, # Nested empty + ] + + for original in test_cases: + serialized = serialize_value(original) + result = deserialize_value(serialized) + assert result == original + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/areal/tests/utils.py b/areal/tests/utils.py index 1666eff49..cb9bad52c 100644 --- a/areal/tests/utils.py +++ b/areal/tests/utils.py @@ -1,5 +1,13 @@ +import asyncio import os +import random +from typing import Any +import torch + +from areal.api.engine_api import InferenceEngine +from areal.api.workflow_api import RolloutWorkflow +from areal.experimental.openai.types import InteractionWithTokenLogpReward from areal.utils import logging logger = logging.getLogger("areal.tests.utils") @@ -26,3 +34,19 @@ def get_bool_env_var(name: str, default: str = "false") -> bool: def is_in_ci(): return get_bool_env_var("AREAL_IS_IN_CI") + + +class TestWorkflow(RolloutWorkflow): + async def arun_episode( + self, engine: InferenceEngine, data: dict[str, Any] + ) -> dict[str, Any] | None | dict[str, InteractionWithTokenLogpReward]: + await asyncio.sleep(0.1) + prompt_len = random.randint(2, 8) + gen_len = random.randint(2, 8) + seqlen = prompt_len + gen_len + return dict( + input_ids=torch.randint(0, 100, (seqlen,)), + attention_mask=torch.ones(seqlen, dtype=torch.bool), + loss_mask=torch.tensor([0] * prompt_len + [1] * gen_len, dtype=torch.bool), + rewards=torch.randn(1), + ) From a58c98456a035fda5b8b8a6870de8f7fb29cda64 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 10:40:09 +0800 Subject: [PATCH 11/52] fix test --- areal/controller/batch.py | 36 +--- areal/controller/rollout_controller.py | 47 +++-- areal/core/async_task_runner.py | 2 +- areal/scheduler/local_scheduler.py | 8 +- areal/scheduler/rpc/rpc_server.py | 24 +-- areal/scheduler/rpc/serialization.py | 122 +++++++++++- areal/tests/test_serialization.py | 259 ++++++++++++++++++++++++- areal/tests/utils.py | 15 +- 8 files changed, 437 insertions(+), 76 deletions(-) diff --git a/areal/controller/batch.py b/areal/controller/batch.py index 66df87573..53cd3a776 100644 --- a/areal/controller/batch.py +++ b/areal/controller/batch.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Any import torch from torch import Tensor @@ -9,6 +9,7 @@ convert_list_to_dict, validate_dict_dataset, ) +from areal.utils.data import concat_padded_tensors from areal.utils.datapack import ffd_allocate from areal.utils.errors import FrameworkError @@ -17,7 +18,7 @@ class DistributedBatchMemory(DistributedBatch): dataset = None @classmethod - def from_dict(cls, dict_dataset: Dict[str, Union[Tensor, Any]]): + def from_dict(cls, dict_dataset: dict[str, Tensor | Any]): """Create a DistributedBatchMemory from dictionary format dataset. Parameters @@ -36,7 +37,7 @@ def from_dict(cls, dict_dataset: Dict[str, Union[Tensor, Any]]): return instance @classmethod - def from_list(cls, list_dataset: List[Dict[str, Union[Tensor, Any]]]): + def from_list(cls, list_dataset: list[dict[str, Tensor | Any]]): """Create a DistributedBatchMemory from list format dataset. Parameters @@ -103,9 +104,9 @@ def chunk_by_ffd( List of DistributedBatchMemory objects """ total_size = self._get_total_size() - assert ( - total_size % group_size == 0 - ), "tensor length must be devided by group_size" + assert total_size % group_size == 0, ( + "tensor length must be devided by group_size" + ) # Handle seqlen calculation for both tensor and scalar types if "seqlen" in self.dataset.keys(): @@ -209,7 +210,7 @@ def _get_total_size(self) -> int: # For scalar values, assume it's a single sample return 1 - def get_data(self) -> Dict[str, Union[torch.Tensor, Any]]: + def get_data(self) -> dict[str, torch.Tensor | Any]: """Get all data from the DistributedBatchMemory. Returns @@ -253,24 +254,7 @@ def concat(data: list["DistributedBatchMemory"]) -> "DistributedBatchMemory": batch.dataset = {} return batch - merged_data = {} - for batch in data: - for k, v in batch.dataset.items(): - if k in merged_data: - if isinstance(merged_data[k], torch.Tensor) and isinstance( - v, torch.Tensor - ): - merged_data[k] = torch.cat([merged_data[k], v], dim=0) - elif isinstance(merged_data[k], list) and isinstance(v, list): - merged_data[k] = merged_data[k] + v - else: - # Handle mixed types or scalar values - if isinstance(merged_data[k], list): - merged_data[k].append(v) - else: - merged_data[k] = [merged_data[k], v] - else: - merged_data[k] = v + merged_data = concat_padded_tensors([k.dataset for k in data]) result = DistributedBatchMemory.__new__(DistributedBatchMemory) result.dataset = merged_data return result @@ -319,7 +303,7 @@ def __setitem__(self, key, value): self.dataset[key] = value else: raise FrameworkError( - "FrameworkError", "DistributedBatchMemoryError", f"key must be str" + "FrameworkError", "DistributedBatchMemoryError", "key must be str" ) def __delitem__(self, key): diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index f162539af..9157d4b01 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -27,6 +27,7 @@ from areal.utils.data import cycle_dataloader CREATE_WORKER_TIMEOUT = 60.0 +TASK_RUNNER_MAX_QSIZE = 32768 @dataclass @@ -96,13 +97,15 @@ def __init__( # Staleness management self.staleness_manager: StalenessManager | None = None - self._pending_inputs: list[ - _RemoteRolloutTaskInput - ] = [] # Queue for inputs waiting for capacity + + self._pending_results: list[dict[str, Any]] = [] + self._pending_inputs: list[_RemoteRolloutTaskInput] = [] def initialize( self, alloc_mode: AllocationMode, + *args, + **kwargs, ): self.logger = logging.getLogger("[RolloutController]") @@ -111,11 +114,17 @@ def initialize( scheduling_config = SchedulingConfig(replicas=alloc_mode.gen.dp_size) # Use asyncio.run to call async scheduler methods synchronously - asyncio.run(self._async_initialize(scheduling_config)) + asyncio.run( + self._async_initialize( + scheduling_config, + *args, + **kwargs, + ) + ) # Initialize AsyncTaskRunner for task execution self.runner = AsyncTaskRunner( - max_queue_size=self.config.queue_size, + max_queue_size=TASK_RUNNER_MAX_QSIZE, enable_tracing=self.config.enable_rollout_tracing, ) self.runner.initialize(logger=self.logger) @@ -136,8 +145,7 @@ def initialize( ) async def _async_initialize( - self, - scheduling_config: SchedulingConfig, + self, scheduling_config: SchedulingConfig, *args, **kwargs ): # Create workers via scheduler self.logger.info("Creating workers via scheduler...") @@ -160,16 +168,27 @@ async def _async_initialize( engine_path = f"{engine_class.__module__}.{engine_class.__name__}" # Create and initialize engines on workers - for i, worker in enumerate(self.workers): - self.logger.info(f"Creating engine on worker {worker.id}...") - - # Create engine on worker - await self.scheduler.create_engine( + self.logger.info("Creating engines...") + tasks = [ + self.scheduler.create_engine( worker_id=worker.id, engine=engine_path, - init_kwargs=dict(config=self.config), + config=self.config, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + self.logger.info("Engine created on all workers!") + + self.logger.info("Calling engine initialization...") + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, method="initialize", *args, **kwargs ) - self.logger.info(f"Engine created on worker {worker.id}") + for worker in self.workers + ] + await asyncio.gather(*tasks) + self.logger.info("All engines are initialized...") def destroy(self): """Destroy the controller and clean up resources.""" diff --git a/areal/core/async_task_runner.py b/areal/core/async_task_runner.py index 9c6d3238c..66520f0ae 100644 --- a/areal/core/async_task_runner.py +++ b/areal/core/async_task_runner.py @@ -348,7 +348,7 @@ async def _run_async_loop(self): ) if self.enable_tracing and self.logger: self.logger.info( - f"AsyncTaskRunner: Completed task {tid}. " + f"AsyncTaskRunner: Completed task ID: {tid}. " f"Running: {len(running_tasks)}" ) except queue.Full: diff --git a/areal/scheduler/local_scheduler.py b/areal/scheduler/local_scheduler.py index dbcb82794..ef3159ec4 100644 --- a/areal/scheduler/local_scheduler.py +++ b/areal/scheduler/local_scheduler.py @@ -637,11 +637,11 @@ async def create_engine( f"Engine must be a string import path, got {type(engine)}", ) - # Build JSON payload + # Build JSON payload with serialized args and kwargs payload = { "engine": engine, - "init_args": list(args), - "init_kwargs": kwargs, + "init_args": serialize_value(list(args)), + "init_kwargs": serialize_value(kwargs), } # Send HTTP request to create engine @@ -858,7 +858,6 @@ async def async_call_engine( last_error = None - print(url) for attempt in range(1, max_retries + 1): # Check worker health before each attempt if worker_info.process.poll() is not None: @@ -880,7 +879,6 @@ async def async_call_engine( headers={"Content-Type": "application/json"}, timeout=7200.0, # 2 hours for long-running operations ) - print(response, payload, response.json()) result, should_retry, error_msg = self._handle_call_response( response, worker_id, method, attempt diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index 1b5efd057..ef8884db6 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -78,8 +78,9 @@ async def create_engine(request: Request): data = orjson.loads(body) engine_path = data.get("engine") - init_args = data.get("init_args", []) - init_kwargs = data.get("init_kwargs", {}) + # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) + init_args = deserialize_value(data.get("init_args", [])) + init_kwargs = deserialize_value(data.get("init_kwargs", {})) if not engine_path: raise HTTPException( @@ -118,26 +119,16 @@ async def create_engine(request: Request): try: _engine = engine_class(*init_args, **init_kwargs) logger.info(f"Engine '{engine_path}' instantiated successfully") - except Exception as e: - logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") - raise HTTPException( - status_code=500, - detail=f"Failed to instantiate engine: {str(e)}", - ) - - # Initialize engine if it has initialize method - try: - result = _engine.initialize(*init_args, **init_kwargs) - logger.info(f"Engine initialized with result: {result}") return { "status": "success", "message": f"Engine '{engine_path}' created and initialized", - "result": result, + "result": None, } except Exception as e: - logger.error(f"Failed to initialize engine: {e}\n{traceback.format_exc()}") + logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") raise HTTPException( - status_code=500, detail=f"Failed to initialize engine: {str(e)}" + status_code=500, + detail=f"Failed to instantiate engine: {str(e)}", ) except HTTPException: @@ -338,7 +329,6 @@ async def run_workflow(request: Request): accept_this = traj is not None and ( should_accept is None or should_accept(traj) ) - print(">>>>>>>>>", traj, accept_this, flush=True) # Serialize trajectory result (convert tensors to SerializedTensor dicts) if accept_this: diff --git a/areal/scheduler/rpc/serialization.py b/areal/scheduler/rpc/serialization.py index afa423ad4..ba13ba5cd 100644 --- a/areal/scheduler/rpc/serialization.py +++ b/areal/scheduler/rpc/serialization.py @@ -1,15 +1,19 @@ -"""Tensor serialization utilities for RPC communication. +"""Tensor and dataclass serialization utilities for RPC communication. This module provides utilities to serialize and deserialize PyTorch tensors -for transmission over HTTP/JSON. Tensors are encoded as base64 strings with -metadata stored in Pydantic models. +and dataclass instances for transmission over HTTP/JSON. Tensors are encoded +as base64 strings and dataclasses preserve their type information with metadata +stored in Pydantic models. Assumptions: - All tensors are on CPU - Gradient tracking (requires_grad) is not preserved +- Dataclasses are reconstructed with their original types """ import base64 +import importlib +from dataclasses import is_dataclass from typing import Any, Literal import torch @@ -125,11 +129,81 @@ def _torch_dtype_to_numpy(torch_dtype: torch.dtype): return dtype_map.get(torch_dtype, np.float32) +class SerializedDataclass(BaseModel): + """Pydantic model for serialized dataclass with metadata. + + Attributes + ---------- + type : str + Type marker, always "dataclass" + class_path : str + Full import path to the dataclass (e.g., "areal.api.cli_args.InferenceEngineConfig") + data : dict + Dataclass fields as dictionary (recursively serialized) + """ + + type: Literal["dataclass"] = Field(default="dataclass") + class_path: str + data: dict[str, Any] + + @classmethod + def from_dataclass(cls, dataclass_instance: Any) -> "SerializedDataclass": + """Create SerializedDataclass from a dataclass instance. + + Parameters + ---------- + dataclass_instance : Any + Dataclass instance to serialize + + Returns + ------- + SerializedDataclass + Serialized dataclass with metadata + """ + class_path = ( + f"{dataclass_instance.__class__.__module__}." + f"{dataclass_instance.__class__.__name__}" + ) + # Get fields without recursive conversion to preserve nested dataclass instances + # We'll handle recursive serialization in serialize_value() + from dataclasses import fields + + data = {} + for field in fields(dataclass_instance): + data[field.name] = getattr(dataclass_instance, field.name) + + return cls(class_path=class_path, data=data) + + def to_dataclass(self) -> Any: + """Reconstruct dataclass instance from serialized data. + + Returns + ------- + Any + Reconstructed dataclass instance + + Raises + ------ + ImportError + If the dataclass module cannot be imported + AttributeError + If the dataclass class is not found in the module + """ + # Dynamically import the dataclass type + module_path, class_name = self.class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + dataclass_type = getattr(module, class_name) + + # Return the dataclass type and data for caller to deserialize fields + return dataclass_type, self.data + + def serialize_value(value: Any) -> Any: - """Recursively serialize a value, converting tensors to SerializedTensor dicts. + """Recursively serialize a value, converting tensors and dataclasses to serialized dicts. This function transparently handles: - torch.Tensor -> SerializedTensor dict (CPU only, no gradient tracking) + - dataclass instances -> SerializedDataclass dict (preserves type information) - dict -> recursively serialize values - list/tuple -> recursively serialize elements - primitives (int, float, str, bool, None) -> unchanged @@ -142,7 +216,7 @@ def serialize_value(value: Any) -> Any: Returns ------- Any - Serialized value (JSON-compatible with SerializedTensor dicts) + Serialized value (JSON-compatible with SerializedTensor and SerializedDataclass dicts) """ # Handle None if value is None: @@ -152,6 +226,20 @@ def serialize_value(value: Any) -> Any: if isinstance(value, torch.Tensor): return SerializedTensor.from_tensor(value).model_dump() + # Handle dataclass instances (check before dict, as dataclasses can be dict-like) + # Note: is_dataclass returns True for both classes and instances, so check it's not a type + if is_dataclass(value) and not isinstance(value, type): + serialized_dc = SerializedDataclass.from_dataclass(value) + # Recursively serialize the data fields + serialized_data = { + key: serialize_value(val) for key, val in serialized_dc.data.items() + } + return { + "type": "dataclass", + "class_path": serialized_dc.class_path, + "data": serialized_data, + } + # Handle dict - recursively serialize values if isinstance(value, dict): return {key: serialize_value(val) for key, val in value.items()} @@ -169,10 +257,11 @@ def serialize_value(value: Any) -> Any: def deserialize_value(value: Any) -> Any: - """Recursively deserialize a value, converting SerializedTensor dicts back to tensors. + """Recursively deserialize a value, converting SerializedTensor and SerializedDataclass dicts back. This function transparently handles: - SerializedTensor dict -> torch.Tensor (CPU, no gradient tracking) + - SerializedDataclass dict -> dataclass instance (reconstructed with original type) - dict -> recursively deserialize values - list -> recursively deserialize elements - primitives -> unchanged @@ -180,19 +269,34 @@ def deserialize_value(value: Any) -> Any: Parameters ---------- value : Any - Value to deserialize (potentially containing SerializedTensor dicts) + Value to deserialize (potentially containing SerializedTensor and SerializedDataclass dicts) Returns ------- Any - Deserialized value with torch.Tensor objects restored + Deserialized value with torch.Tensor and dataclass objects restored """ # Handle None if value is None: return None - # Handle dict - check if it's a SerializedTensor + # Handle dict - check if it's a SerializedDataclass or SerializedTensor if isinstance(value, dict): + # Check for SerializedDataclass marker (check before tensor) + if value.get("type") == "dataclass": + try: + serialized_dc = SerializedDataclass.model_validate(value) + dataclass_type, data = serialized_dc.to_dataclass() + # Recursively deserialize the fields + deserialized_data = { + key: deserialize_value(val) for key, val in data.items() + } + # Reconstruct the dataclass instance + return dataclass_type(**deserialized_data) + except Exception: + # If parsing fails, treat as regular dict + pass + # Check for SerializedTensor marker if value.get("type") == "tensor": try: diff --git a/areal/tests/test_serialization.py b/areal/tests/test_serialization.py index 0ee0c29bb..76b7836cb 100644 --- a/areal/tests/test_serialization.py +++ b/areal/tests/test_serialization.py @@ -1,15 +1,44 @@ -"""Pytest test suite for tensor serialization utilities.""" +"""Pytest test suite for tensor and dataclass serialization utilities.""" + +from dataclasses import dataclass import pytest import torch from areal.scheduler.rpc.serialization import ( + SerializedDataclass, SerializedTensor, deserialize_value, serialize_value, ) +# Test dataclasses +@dataclass +class SimpleConfig: + """Simple test dataclass.""" + + batch_size: int + learning_rate: float + name: str + + +@dataclass +class ConfigWithTensor: + """Dataclass containing a tensor field.""" + + data: torch.Tensor + label: str + + +@dataclass +class NestedConfig: + """Dataclass containing another dataclass.""" + + inner: SimpleConfig + outer_value: int + + class TestSerializedTensor: """Test suite for SerializedTensor Pydantic model.""" @@ -387,5 +416,233 @@ def test_roundtrip_empty_structures(self): assert result == original +class TestSerializedDataclass: + """Test suite for SerializedDataclass Pydantic model.""" + + def test_from_dataclass_simple(self): + """Test serialization of simple dataclass.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + serialized = SerializedDataclass.from_dataclass(config) + + assert serialized.type == "dataclass" + assert "SimpleConfig" in serialized.class_path + assert serialized.data["batch_size"] == 32 + assert serialized.data["learning_rate"] == 0.001 + assert serialized.data["name"] == "test" + + def test_to_dataclass_simple(self): + """Test deserialization of simple dataclass.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + serialized = SerializedDataclass.from_dataclass(config) + dataclass_type, data = serialized.to_dataclass() + + reconstructed = dataclass_type(**data) + assert isinstance(reconstructed, SimpleConfig) + assert reconstructed.batch_size == 32 + assert reconstructed.learning_rate == 0.001 + assert reconstructed.name == "test" + + def test_roundtrip_simple_dataclass(self): + """Test serialize-deserialize roundtrip for simple dataclass.""" + original = SimpleConfig(batch_size=64, learning_rate=0.01, name="experiment") + serialized = SerializedDataclass.from_dataclass(original) + dataclass_type, data = serialized.to_dataclass() + reconstructed = dataclass_type(**data) + + assert reconstructed.batch_size == original.batch_size + assert reconstructed.learning_rate == original.learning_rate + assert reconstructed.name == original.name + + +class TestSerializeValueDataclass: + """Test suite for serialize_value with dataclasses.""" + + def test_serialize_simple_dataclass(self): + """Test serialization of simple dataclass.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + result = serialize_value(config) + + assert isinstance(result, dict) + assert result["type"] == "dataclass" + assert "SimpleConfig" in result["class_path"] + assert result["data"]["batch_size"] == 32 + + def test_serialize_dataclass_with_tensor(self): + """Test serialization of dataclass containing tensor.""" + config = ConfigWithTensor(data=torch.tensor([1.0, 2.0, 3.0]), label="example") + result = serialize_value(config) + + assert result["type"] == "dataclass" + assert result["data"]["label"] == "example" + # Tensor should be serialized within dataclass + assert result["data"]["data"]["type"] == "tensor" + + def test_serialize_nested_dataclass(self): + """Test serialization of nested dataclass.""" + inner = SimpleConfig(batch_size=16, learning_rate=0.01, name="inner") + outer = NestedConfig(inner=inner, outer_value=42) + result = serialize_value(outer) + + assert result["type"] == "dataclass" + assert "NestedConfig" in result["class_path"] + # Inner dataclass should also be serialized + assert result["data"]["inner"]["type"] == "dataclass" + assert result["data"]["inner"]["data"]["batch_size"] == 16 + assert result["data"]["outer_value"] == 42 + + def test_serialize_list_of_dataclasses(self): + """Test serialization of list containing dataclasses.""" + configs = [ + SimpleConfig(batch_size=32, learning_rate=0.001, name="config1"), + SimpleConfig(batch_size=64, learning_rate=0.002, name="config2"), + ] + result = serialize_value(configs) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(item["type"] == "dataclass" for item in result) + assert result[0]["data"]["batch_size"] == 32 + assert result[1]["data"]["batch_size"] == 64 + + def test_serialize_dict_with_dataclass_values(self): + """Test serialization of dict with dataclass values.""" + data = { + "config": SimpleConfig(batch_size=32, learning_rate=0.001, name="test"), + "value": 42, + } + result = serialize_value(data) + + assert result["config"]["type"] == "dataclass" + assert result["config"]["data"]["batch_size"] == 32 + assert result["value"] == 42 + + +class TestDeserializeValueDataclass: + """Test suite for deserialize_value with dataclasses.""" + + def test_deserialize_simple_dataclass(self): + """Test deserialization of simple dataclass.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + serialized = serialize_value(config) + result = deserialize_value(serialized) + + assert isinstance(result, SimpleConfig) + assert result.batch_size == 32 + assert result.learning_rate == 0.001 + assert result.name == "test" + + def test_deserialize_dataclass_with_tensor(self): + """Test deserialization of dataclass containing tensor.""" + original_tensor = torch.tensor([1.0, 2.0, 3.0]) + config = ConfigWithTensor(data=original_tensor, label="example") + serialized = serialize_value(config) + result = deserialize_value(serialized) + + assert isinstance(result, ConfigWithTensor) + assert result.label == "example" + assert isinstance(result.data, torch.Tensor) + assert torch.allclose(original_tensor, result.data) + + def test_deserialize_nested_dataclass(self): + """Test deserialization of nested dataclass.""" + inner = SimpleConfig(batch_size=16, learning_rate=0.01, name="inner") + outer = NestedConfig(inner=inner, outer_value=42) + serialized = serialize_value(outer) + result = deserialize_value(serialized) + + assert isinstance(result, NestedConfig) + assert isinstance(result.inner, SimpleConfig) + assert result.inner.batch_size == 16 + assert result.inner.learning_rate == 0.01 + assert result.outer_value == 42 + + def test_deserialize_list_of_dataclasses(self): + """Test deserialization of list containing dataclasses.""" + configs = [ + SimpleConfig(batch_size=32, learning_rate=0.001, name="config1"), + SimpleConfig(batch_size=64, learning_rate=0.002, name="config2"), + ] + serialized = serialize_value(configs) + result = deserialize_value(serialized) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, SimpleConfig) for item in result) + assert result[0].batch_size == 32 + assert result[1].batch_size == 64 + + +class TestRoundtripDataclass: + """Test suite for full serialize-deserialize roundtrips with dataclasses.""" + + def test_roundtrip_simple_dataclass(self): + """Test roundtrip for simple dataclass.""" + original = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert result.batch_size == original.batch_size + assert result.learning_rate == original.learning_rate + assert result.name == original.name + + def test_roundtrip_dataclass_with_tensor(self): + """Test roundtrip for dataclass with tensor field.""" + original = ConfigWithTensor(data=torch.tensor([1.0, 2.0, 3.0]), label="test") + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result, ConfigWithTensor) + assert result.label == original.label + assert torch.allclose(original.data, result.data) + + def test_roundtrip_nested_dataclass(self): + """Test roundtrip for nested dataclass.""" + inner = SimpleConfig(batch_size=16, learning_rate=0.01, name="inner") + original = NestedConfig(inner=inner, outer_value=42) + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result, NestedConfig) + assert isinstance(result.inner, SimpleConfig) + assert result.inner.batch_size == original.inner.batch_size + assert result.outer_value == original.outer_value + + def test_roundtrip_mixed_dataclass_and_tensor(self): + """Test roundtrip for structure with both dataclasses and tensors.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + original = { + "config": config, + "tensor": torch.tensor([1.0, 2.0, 3.0]), + "metadata": {"count": 3, "type": "experiment"}, + } + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result["config"], SimpleConfig) + assert result["config"].batch_size == 32 + assert isinstance(result["tensor"], torch.Tensor) + assert torch.allclose(original["tensor"], result["tensor"]) + assert result["metadata"] == original["metadata"] + + def test_roundtrip_list_of_mixed_types(self): + """Test roundtrip for list containing dataclasses, tensors, and primitives.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + original = [ + config, + torch.tensor([1.0, 2.0]), + 42, + "string", + ] + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result[0], SimpleConfig) + assert result[0].batch_size == 32 + assert isinstance(result[1], torch.Tensor) + assert torch.allclose(original[1], result[1]) + assert result[2] == 42 + assert result[3] == "string" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/areal/tests/utils.py b/areal/tests/utils.py index cb9bad52c..3d61b45ed 100644 --- a/areal/tests/utils.py +++ b/areal/tests/utils.py @@ -45,8 +45,17 @@ async def arun_episode( gen_len = random.randint(2, 8) seqlen = prompt_len + gen_len return dict( - input_ids=torch.randint(0, 100, (seqlen,)), - attention_mask=torch.ones(seqlen, dtype=torch.bool), - loss_mask=torch.tensor([0] * prompt_len + [1] * gen_len, dtype=torch.bool), + input_ids=torch.randint( + 0, + 100, + ( + 1, + seqlen, + ), + ), + attention_mask=torch.ones(1, seqlen, dtype=torch.bool), + loss_mask=torch.tensor( + [0] * prompt_len + [1] * gen_len, dtype=torch.bool + ).unsqueeze(0), rewards=torch.randn(1), ) From d14b53c2be6497ff7ec1e9d37212b97120e1fda2 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 14:01:02 +0800 Subject: [PATCH 12/52] add scheduler and rollout controller test --- areal/experimental/openai/tool_call_parser.py | 12 +- areal/scheduler/rpc/rpc_server.py | 2 +- areal/tests/test_local_scheduler.py | 166 ++- areal/tests/test_rollout_controller.py | 994 ++++++++++++++++++ pyproject.toml | 3 + 5 files changed, 1152 insertions(+), 25 deletions(-) create mode 100644 areal/tests/test_rollout_controller.py diff --git a/areal/experimental/openai/tool_call_parser.py b/areal/experimental/openai/tool_call_parser.py index 75f011a2c..70fba9640 100644 --- a/areal/experimental/openai/tool_call_parser.py +++ b/areal/experimental/openai/tool_call_parser.py @@ -2,11 +2,9 @@ import uuid from typing import Any -from openai.types.chat.chat_completion_message_function_tool_call import ( - ChatCompletionMessageFunctionToolCall, - Function, -) -from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.chat import ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_message_tool_call import Function +from openai.types.responses import ResponseFunctionToolCall from areal.utils import logging @@ -21,7 +19,7 @@ def process_tool_calls( finish_reason: str, use_responses: bool = False, ) -> tuple[ - list[ChatCompletionMessageFunctionToolCall | ResponseFunctionToolCall] | None, + list[ChatCompletionMessageToolCall | ResponseFunctionToolCall] | None, str, str, ]: @@ -69,7 +67,7 @@ def process_tool_calls( ] else: tool_calls = [ - ChatCompletionMessageFunctionToolCall( + ChatCompletionMessageToolCall( type="function", id=f"call_{uuid.uuid4().hex[:24]}", function=Function( diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index ef8884db6..d90e2217f 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -358,7 +358,7 @@ def main(): "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" ) - args, unknown = parser.parse_known_args() + args, _ = parser.parse_known_args() port = args.port logger.info(f"Starting RPC server on {args.host}:{port}") diff --git a/areal/tests/test_local_scheduler.py b/areal/tests/test_local_scheduler.py index 1e700a784..dd8d52e6d 100644 --- a/areal/tests/test_local_scheduler.py +++ b/areal/tests/test_local_scheduler.py @@ -1,20 +1,3 @@ -""" -Comprehensive unit tests for LocalScheduler. - -This test suite covers: -1. Initialization and GPU detection -2. Worker creation with various configurations -3. GPU allocation strategies (new, colocate, round-robin) -4. Port allocation and tracking -5. Worker health checks and readiness -6. Engine creation and method calls (sync and async) -7. Error handling for all exception types -8. Resource cleanup and process termination -9. Edge cases (duplicate workers, worker not found, GPU exhaustion, port conflicts) -10. Log file handling -11. HTTP client interactions -""" - import asyncio import os import time @@ -1560,3 +1543,152 @@ def test_log_directory_with_special_characters(self, tmp_path): assert log_dir.exists() assert scheduler.log_dir == log_dir + + +class TestRPCWorkflowIntegration: + """Integration tests for RPC workflow endpoint execution.""" + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_workflow_endpoint_basic(self, tmp_path): + """Should execute workflow via RPC endpoint with real worker processes.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + try: + config = SchedulingConfig(replicas=1) + worker_ids = scheduler.create_workers(role="test", scheduler_config=config) + assert len(worker_ids) == 1 + + workers = scheduler.get_workers(role="test", timeout=30.0) + worker_id = workers[0].id + + result = await scheduler.async_call_engine( + worker_id=worker_id, + method="run_workflow", + workflow="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + data={"test_id": 123, "value": "test_data"}, + ) + + assert result is not None + assert "input_ids" in result + assert "attention_mask" in result + assert "loss_mask" in result + assert "rewards" in result + + from areal.tests.utils import TestWorkflow + + workflow = TestWorkflow() + ref_result = await workflow.arun_episode(None, None) + assert result.keys() == ref_result.keys() + for key in result: + assert type(result[key]) is type(ref_result[key]) + + finally: + scheduler.delete_workers(role="test") + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_workflow_endpoint_multiple_calls(self, tmp_path): + """Should handle multiple sequential workflow calls via RPC endpoint.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + try: + config = SchedulingConfig(replicas=1) + scheduler.create_workers(role="test", scheduler_config=config) + + workers = scheduler.get_workers(role="test", timeout=30.0) + worker_id = workers[0].id + + from areal.tests.utils import TestWorkflow + + workflow = TestWorkflow() + ref_result = await workflow.arun_episode(None, None) + + for i in range(3): + result = await scheduler.async_call_engine( + worker_id=worker_id, + method="run_workflow", + workflow="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + data={"test_id": i, "value": f"test_{i}"}, + ) + + assert result is not None + assert result.keys() == ref_result.keys() + for key in result: + assert type(result[key]) is type(ref_result[key]) + + finally: + scheduler.delete_workers(role="test") + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_workflow_serialization(self, tmp_path): + """Should correctly serialize and deserialize tensors through RPC.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + try: + config = SchedulingConfig(replicas=1) + scheduler.create_workers(role="test", scheduler_config=config) + + workers = scheduler.get_workers(role="test", timeout=30.0) + worker_id = workers[0].id + + result = await scheduler.async_call_engine( + worker_id=worker_id, + method="run_workflow", + workflow="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + data={"test_id": 456, "value": "serialization_test"}, + ) + + assert result is not None + assert "input_ids" in result + assert "attention_mask" in result + assert "loss_mask" in result + assert "rewards" in result + + import torch + + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["loss_mask"], torch.Tensor) + assert isinstance(result["rewards"], torch.Tensor) + + assert result["input_ids"].dtype == torch.long + assert result["attention_mask"].dtype == torch.bool + assert result["loss_mask"].dtype == torch.bool + + finally: + scheduler.delete_workers(role="test") + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_workflow_with_kwargs(self, tmp_path): + """Should support workflow instantiation with custom kwargs.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + try: + config = SchedulingConfig(replicas=1) + scheduler.create_workers(role="test", scheduler_config=config) + + workers = scheduler.get_workers(role="test", timeout=30.0) + worker_id = workers[0].id + + result = await scheduler.async_call_engine( + worker_id=worker_id, + method="run_workflow", + workflow="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + data={"test_id": 789, "custom_param": "value"}, + ) + + assert result is not None + assert "input_ids" in result + assert "attention_mask" in result + assert "loss_mask" in result + assert "rewards" in result + + finally: + scheduler.delete_workers(role="test") diff --git a/areal/tests/test_rollout_controller.py b/areal/tests/test_rollout_controller.py new file mode 100644 index 000000000..c02dcab41 --- /dev/null +++ b/areal/tests/test_rollout_controller.py @@ -0,0 +1,994 @@ +import asyncio +import queue +from concurrent.futures import Future +from unittest.mock import Mock, patch + +import pytest +import torch + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import InferenceEngineConfig +from areal.api.io_struct import ModelRequest, ParamSpec, WeightUpdateMeta +from areal.api.scheduler_api import Worker +from areal.controller.batch import DistributedBatchMemory +from areal.controller.rollout_controller import RolloutController +from areal.core.async_task_runner import TaskQueueFullError + + +class MockScheduler: + def __init__(self): + self.workers = [] + self.call_count = 0 + self.engine_calls = [] + + def create_workers(self, role, scheduler_config, *args, **kwargs): + worker_ids = [f"{role}/{i}" for i in range(scheduler_config.replicas)] + self.workers = [ + Worker(id=wid, ip="127.0.0.1", ports=["8000", "8001"]) for wid in worker_ids + ] + return worker_ids + + def get_workers(self, role, timeout=None): + return self.workers + + async def create_engine(self, worker_id, engine, config): + pass + + async def async_call_engine(self, worker_id, method, *args, **kwargs): + self.engine_calls.append((worker_id, method, args, kwargs)) + self.call_count += 1 + + if method == "run_workflow": + await asyncio.sleep(0.01) + return { + "input_ids": torch.randint(0, 100, (1, 10)), + "attention_mask": torch.ones(1, 10, dtype=torch.bool), + "loss_mask": torch.tensor( + [0] * 5 + [1] * 5, dtype=torch.bool + ).unsqueeze(0), + "rewards": torch.randn(1), + } + elif method == "agenerate": + return Mock() + return None + + def call_engine(self, worker_id, method, *args, **kwargs): + self.engine_calls.append((worker_id, method, args, kwargs)) + + # For weight update methods that await call_engine, return a coroutine + if method in [ + "update_weights_from_distributed", + "update_weights_from_disk", + "init_weights_update_group", + ]: + return self._async_call_engine_internal(worker_id, method, *args, **kwargs) + + return None + + async def _async_call_engine_internal(self, worker_id, method, *args, **kwargs): + await asyncio.sleep(0.001) + return None + + def delete_workers(self, role): + self.workers.clear() + + +class MockInferenceEngine: + @classmethod + def __module__(cls): + return "areal.tests.test_rollout_controller" + + @classmethod + def __name__(cls): + return "MockInferenceEngine" + + +class TestRolloutControllerInitialization: + def test_constructor(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + assert controller.config == config + assert controller.scheduler == scheduler + assert controller.workers == [] + assert controller._current_worker_idx == 0 + assert controller._version == 0 + assert controller.runner is None + assert controller.executor is None + assert controller.staleness_manager is None + + def test_initialize_creates_workers(self): + config = InferenceEngineConfig( + consumer_batch_size=16, + max_head_offpolicyness=2, + enable_rollout_tracing=False, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + assert len(controller.workers) == 2 + assert controller.runner is not None + assert controller.executor is not None + assert controller.staleness_manager is not None + + controller.destroy() + + def test_initialize_creates_staleness_manager(self): + config = InferenceEngineConfig( + consumer_batch_size=32, + max_head_offpolicyness=5, + max_concurrent_rollouts=100, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + assert controller.staleness_manager.max_concurrent_rollouts == 100 + assert controller.staleness_manager.consumer_batch_size == 32 + assert controller.staleness_manager.max_staleness == 5 + + controller.destroy() + + def test_initialize_uses_consumer_batch_size_as_fallback(self): + config = InferenceEngineConfig( + consumer_batch_size=64, + max_head_offpolicyness=3, + max_concurrent_rollouts=None, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + assert controller.staleness_manager.max_concurrent_rollouts == 64 + + controller.destroy() + + def test_initialize_with_tracing_enabled(self): + config = InferenceEngineConfig( + consumer_batch_size=16, + max_head_offpolicyness=2, + enable_rollout_tracing=True, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + assert controller.runner.enable_tracing is True + + controller.destroy() + + +class TestRolloutControllerDestroy: + def test_destroy_cleans_up_resources(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + assert controller.runner is not None + assert controller.executor is not None + assert len(controller.workers) > 0 + + controller.destroy() + + assert controller.runner is None + assert controller.executor is None + assert len(controller.workers) == 0 + + def test_destroy_deletes_workers_via_scheduler(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + assert len(scheduler.workers) == 2 + + controller.destroy() + + assert len(scheduler.workers) == 0 + + def test_destroy_handles_scheduler_error(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + scheduler.delete_workers = Mock(side_effect=Exception("Test error")) + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.destroy() + + +class TestRolloutControllerCapacity: + def test_get_capacity_initial_state(self): + config = InferenceEngineConfig( + consumer_batch_size=16, + max_concurrent_rollouts=32, + max_head_offpolicyness=2, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + capacity = controller.get_capacity() + assert capacity == 32 + + controller.destroy() + + def test_get_capacity_uses_version(self): + config = InferenceEngineConfig( + consumer_batch_size=8, + max_concurrent_rollouts=1000, + max_head_offpolicyness=2, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + capacity_v0 = controller.get_capacity() + + controller.set_version(5) + capacity_v5 = controller.get_capacity() + + assert capacity_v5 > capacity_v0 + + controller.destroy() + + +class TestRolloutControllerWorkerSelection: + def test_choose_worker_round_robin(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d3") + controller.initialize(alloc_mode=alloc_mode) + + worker_ids = [] + for _ in range(6): + worker = controller._choose_worker() + worker_ids.append(worker.id) + + assert worker_ids[0] == "rollout/0" + assert worker_ids[1] == "rollout/1" + assert worker_ids[2] == "rollout/2" + assert worker_ids[3] == "rollout/0" + assert worker_ids[4] == "rollout/1" + assert worker_ids[5] == "rollout/2" + + controller.destroy() + + +class TestRolloutControllerSubmitAndWait: + def test_submit_adds_to_pending_queue(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + data = {"test": "data"} + controller.submit( + data, workflow_path="areal.tests.utils.TestWorkflow", workflow_kwargs={} + ) + + assert len(controller._pending_inputs) == 1 + assert controller._pending_inputs[0].data == data + + controller.destroy() + + def test_submit_multiple_requests(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + for i in range(5): + controller.submit( + {"id": i}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + assert len(controller._pending_inputs) == 5 + + controller.destroy() + + def test_wait_returns_distributed_batch(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=50 + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + for i in range(3): + controller.submit( + {"id": i}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + batch = controller.wait(count=3, timeout=5.0) + + assert isinstance(batch, DistributedBatchMemory) + assert len(batch) == 3 + + controller.destroy() + + def test_wait_timeout_when_insufficient_results(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=10 + ) + scheduler = MockScheduler() + + async def slow_workflow(*args, **kwargs): + await asyncio.sleep(10.0) + return None + + scheduler.async_call_engine = slow_workflow + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.submit( + {"id": 0}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + with pytest.raises(TimeoutError, match="Timed out waiting for"): + controller.wait(count=1, timeout=0.2) + + controller.destroy() + + def test_wait_handles_rejected_rollouts(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=20 + ) + scheduler = MockScheduler() + + call_count = 0 + + async def mixed_results(*args, **kwargs): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) + if call_count % 2 == 0: + return None + return { + "input_ids": torch.randint(0, 100, (1, 10)), + "attention_mask": torch.ones(1, 10, dtype=torch.bool), + "loss_mask": torch.ones(1, 10, dtype=torch.bool), + "rewards": torch.randn(1), + } + + scheduler.async_call_engine = mixed_results + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + for i in range(6): + controller.submit( + {"id": i}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + batch = controller.wait(count=3, timeout=2.0) + assert len(batch) == 3 + + controller.destroy() + + +class TestRolloutControllerBatchOperations: + def test_rollout_batch_submits_all_data(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=50 + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + batch_data = [{"id": i, "value": f"item_{i}"} for i in range(4)] + batch = controller.rollout_batch( + batch_data, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + assert isinstance(batch, DistributedBatchMemory) + assert len(batch) == 4 + + controller.destroy() + + def test_rollout_batch_waits_for_all_results(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=100 + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + batch_data = [{"id": i} for i in range(10)] + batch = controller.rollout_batch( + batch_data, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + assert len(batch) == 10 + + controller.destroy() + + +class TestRolloutControllerVersionManagement: + def test_get_version_initial(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + assert controller.get_version() == 0 + + def test_set_version_updates_controller_version(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + controller.set_version(42) + assert controller.get_version() == 42 + + controller.destroy() + + def test_set_version_calls_workers(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + controller.set_version(10) + + version_calls = [ + call for call in scheduler.engine_calls if call[1] == "set_version" + ] + assert len(version_calls) == 2 + + controller.destroy() + + def test_set_version_handles_worker_error(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + + def failing_call(*args, **kwargs): + raise Exception("Worker error") + + scheduler.call_engine = failing_call + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.set_version(5) + + +class TestRolloutControllerWeightUpdates: + def test_init_weights_update_group_returns_future(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + meta = WeightUpdateMeta(type="disk", path="/tmp/test") + future = controller.init_weights_update_group(meta) + + assert isinstance(future, Future) + future.result(timeout=5.0) + + controller.destroy() + + def test_update_weights_from_distributed_returns_future(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + meta = WeightUpdateMeta(type="disk", path="/tmp/test") + param_specs = [ParamSpec(name="test", shape=(10, 10), dtype="float32")] + future = controller.update_weights_from_distributed(meta, param_specs) + + assert isinstance(future, Future) + future.result(timeout=5.0) + + controller.destroy() + + def test_update_weights_from_disk_returns_future(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + meta = WeightUpdateMeta(type="disk", path="/tmp/test") + future = controller.update_weights_from_disk(meta) + + assert isinstance(future, Future) + + controller.destroy() + + +class TestRolloutControllerLifecycle: + def test_pause_calls_all_workers(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d3") + controller.initialize(alloc_mode=alloc_mode) + + controller.pause() + + pause_calls = [call for call in scheduler.engine_calls if call[1] == "pause"] + assert len(pause_calls) == 3 + + controller.destroy() + + def test_resume_calls_all_workers(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d3") + controller.initialize(alloc_mode=alloc_mode) + + controller.resume() + + resume_calls = [call for call in scheduler.engine_calls if call[1] == "resume"] + assert len(resume_calls) == 3 + + controller.destroy() + + def test_pause_handles_worker_error(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + + def failing_call(*args, **kwargs): + raise Exception("Worker error") + + scheduler.call_engine = failing_call + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.pause() + + def test_resume_handles_worker_error(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + + def failing_call(*args, **kwargs): + raise Exception("Worker error") + + scheduler.call_engine = failing_call + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.resume() + + +class TestRolloutControllerAgenerate: + def test_agenerate_chooses_worker(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + req = ModelRequest(input_ids=[1, 2, 3, 4, 5]) + + async def test_agenerate(): + result = await controller.agenerate(req) + return result + + asyncio.run(test_agenerate()) + + agenerate_calls = [ + call for call in scheduler.engine_calls if call[1] == "agenerate" + ] + assert len(agenerate_calls) == 1 + assert agenerate_calls[0][3]["req"] == req + + controller.destroy() + + def test_agenerate_round_robin(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d3") + controller.initialize(alloc_mode=alloc_mode) + + async def test_multiple_agenerate(): + for _ in range(6): + req = ModelRequest(input_ids=[1, 2, 3]) + await controller.agenerate(req) + + asyncio.run(test_multiple_agenerate()) + + agenerate_calls = [ + call for call in scheduler.engine_calls if call[1] == "agenerate" + ] + worker_ids = [call[0] for call in agenerate_calls] + + assert worker_ids[0] == "rollout/0" + assert worker_ids[1] == "rollout/1" + assert worker_ids[2] == "rollout/2" + assert worker_ids[3] == "rollout/0" + + controller.destroy() + + +class TestRolloutControllerErrorHandling: + def test_commit_raises_on_queue_full(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + with patch.object( + controller.runner, "submit", side_effect=TaskQueueFullError("Queue full") + ): + controller.submit( + {"id": 0}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + with pytest.raises(queue.Full): + controller._commit_one_to_runner() + + controller.destroy() + + def test_wait_returns_empty_batch_on_no_results(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=50 + ) + scheduler = MockScheduler() + + async def reject_all(*args, **kwargs): + await asyncio.sleep(0.01) + return None + + scheduler.async_call_engine = reject_all + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + with pytest.raises(TimeoutError): + controller.wait(count=1, timeout=0.5) + + controller.destroy() + + +class TestRolloutControllerIntegration: + def test_end_to_end_workflow(self): + config = InferenceEngineConfig( + consumer_batch_size=8, + max_concurrent_rollouts=20, + max_head_offpolicyness=2, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + capacity = controller.get_capacity() + assert capacity == 20 + + for i in range(5): + controller.submit( + {"id": i}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + batch = controller.wait(count=5, timeout=5.0) + assert len(batch) == 5 + + controller.set_version(1) + assert controller.get_version() == 1 + + controller.destroy() + + def test_multiple_batch_cycles(self): + config = InferenceEngineConfig( + consumer_batch_size=4, + max_concurrent_rollouts=50, + max_head_offpolicyness=5, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + for cycle in range(3): + batch_data = [{"id": i, "cycle": cycle} for i in range(4)] + batch = controller.rollout_batch( + batch_data, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + assert len(batch) == 4 + + controller.destroy() + + +class TestRolloutControllerNotImplemented: + def test_register_callback_not_implemented(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + with pytest.raises(NotImplementedError): + controller.register_callback_to_all_worker("test", lambda: None) + + def test_abort_all_requests_not_implemented(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + with pytest.raises(NotImplementedError): + controller.abort_all_requests() + + +@pytest.mark.parametrize("num_workers", [1, 2, 4]) +def test_parametrized_worker_count(num_workers): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str(f"sglang.d{num_workers}") + controller.initialize(alloc_mode=alloc_mode) + + assert len(controller.workers) == num_workers + + controller.destroy() + + +@pytest.mark.parametrize( + "consumer_batch_size,max_concurrent_rollouts,expected_capacity", + [(16, 32, 32), (32, 64, 64), (8, 100, 24)], +) +def test_parametrized_capacity_settings( + consumer_batch_size, max_concurrent_rollouts, expected_capacity +): + config = InferenceEngineConfig( + consumer_batch_size=consumer_batch_size, + max_concurrent_rollouts=max_concurrent_rollouts, + max_head_offpolicyness=2, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + capacity = controller.get_capacity() + assert capacity == expected_capacity + + controller.destroy() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/pyproject.toml b/pyproject.toml index 76bad35eb..7f74f6697 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -208,6 +208,9 @@ filterwarnings = [ "ignore::UserWarning:torch.*", "ignore::UserWarning:transformers.*", ] +markers = [ + "integration: marks tests as integration tests (real processes, slower)", +] [tool.ruff] line-length = 88 From b3a3e5304d34fbe4cd9752607da9144371aa76ef Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 14:43:31 +0800 Subject: [PATCH 13/52] fix docstring and type annotations --- areal/api/controller_api.py | 304 +------------------------ areal/controller/rollout_controller.py | 217 +++++++++++++++--- 2 files changed, 185 insertions(+), 336 deletions(-) diff --git a/areal/api/controller_api.py b/areal/api/controller_api.py index 366940e3c..0bf02b10a 100644 --- a/areal/api/controller_api.py +++ b/areal/api/controller_api.py @@ -1,18 +1,13 @@ import abc from collections.abc import Callable -from concurrent.futures import Future -from typing import TYPE_CHECKING, Any, Optional +from typing import Any import torch -from torchdata.stateful_dataloader import StatefulDataLoader from areal.api.alloc_mode import ParallelStrategy -from areal.api.cli_args import InferenceEngineConfig, TrainEngineConfig +from areal.api.cli_args import TrainEngineConfig from areal.api.engine_api import InferenceEngine, TrainEngine from areal.api.io_struct import ( - ModelRequest, - ModelResponse, - ParamSpec, SaveLoadMeta, WeightUpdateMeta, ) @@ -199,10 +194,6 @@ def __setstate__(self, state): raise NotImplementedError() -if TYPE_CHECKING: - from areal.api.workflow_api import RolloutWorkflow - - class TrainController(abc.ABC): """A centralized controller that manages multiple distributed TrainEngine workers. @@ -449,294 +440,3 @@ def forward( The result produced by `post_hook` and `aggregate_fn`. """ raise NotImplementedError() - - -class RolloutController(abc.ABC): - """A centralized controller that manages multiple distributed InferenceEngine workers for rollout generation. - - RolloutController orchestrates distributed inference workloads by scheduling and - dispatching requests across multiple concurrent InferenceEngine instances. It provides - intelligent load balancing, staleness control, and capacity management to optimize - rollout generation efficiency. - - Key features: - - Distributed request scheduling and load balancing across workers - - Centralized staleness and capacity control for consistent performance - - Asynchronous rollout generation with configurable acceptance criteria - - Data aggregation from heterogeneously loaded workers - - The controller handles workload imbalances inherent in rollout generation, where - different workers may produce varying amounts of data depending on the complexity - of their assigned tasks. Generated data is stored locally on workers and aggregated - into `DistributedBatch` objects for seamless integration with TrainController. - """ - - def __init__( - self, - inf_engine: InferenceEngine, - config: InferenceEngineConfig, - scheduler: Scheduler, - ): - self.inf_engine = inf_engine - self.config = config - self.scheduler = scheduler - - def initialize(self, *args, **kwargs): - """Initialize environments and launch the background thread for asynchronous distributed inference. - - For remote inference engines, this serves as a client and connects to the inference servers. - For local inference engines, this creates an LLM engine on the local GPU. - - Parameters - ---------- - *args - Variable length argument list - **kwargs - Arbitrary keyword arguments - """ - raise NotImplementedError() - - def destroy(self): - """Destroy the engine and release GPU memory for the local inference engine.""" - raise NotImplementedError() - - async def agenerate(self, req: ModelRequest) -> ModelResponse: - """Asynchronously generate a response for the given request. - - Parameters - ---------- - req : ModelRequest - The model request containing input data and generation parameters - - Returns - ------- - ModelResponse - The generated response from the model - """ - raise NotImplementedError() - - def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: - """Initialize the weight update process group for distributed weight updates. - - This method should be called before performing any weight updates to ensure - that the necessary communication groups are set up correctly. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update, such as the - type of communication backend and allocation mode. - - Raises - ------ - NotImplementedError - If the method is not implemented by a subclass. - - Returns - ------- - Future[None] - A future object representing the asynchronous initialization operation. - """ - raise NotImplementedError() - - def update_weights_from_distributed( - self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] - ) -> Future[None]: - """Update weights in the inference engine in a non-blocking manner. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - param_specs : List[ParamSpec] - A list of parameter specifications for the weights to be updated - - Returns - ------- - Future[None] - A future object representing the asynchronous weight update operation - """ - raise NotImplementedError() - - def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: - """Update weights in the inference engine from disk in a non-blocking manner. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - - Returns - ------- - Future[None] - A future object representing the asynchronous weight update operation - """ - raise NotImplementedError() - - def set_version(self, version: int) -> None: - """Set the current weight version in the inference engine. - - Parameters - ---------- - version : int - The weight version number to set - """ - raise NotImplementedError() - - def get_version(self) -> int: - """Get the current weight version in the inference engine. - - Returns - ------- - int - The current weight version number - """ - raise NotImplementedError() - - def submit( - self, - data: dict[str, Any], - workflow: Optional["RolloutWorkflow"] = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> None: - """Submit a request to the inference engine and return immediately. - - Should be used together with subsequent `wait`. - - Parameters - ---------- - data : Dict[str, Any] - The input data for rollout. Used by the user's customized workflow implementation. - workflow : RolloutWorkflow, optional - The workflow instance to run. Note that a single workflow instance can run multiple data. - Use `workflow` when you want to share some resources between different rollouts. - Either `workflow` or `workflow_builder` should be specified, by default None. - workflow_builder : Callable, optional - A builder to create a workflow instance to run, guaranteed for source separation. - Either `workflow` or `workflow_builder` should be specified, by default None. - should_accept : Callable, optional - A function used to decide whether to accept a specific trajectory, i.e., dynamic filtering. - It takes a complete trajectory output by the workflow, and returns a bool, by default None. - """ - raise NotImplementedError() - - def wait(self, count: int, timeout: float | None = None) -> DistributedBatch: - """Wait for a specified number of requests to complete, with a timeout. - - Should be used together with preceding `submit`. - - Parameters - ---------- - count : int - The number of accepted trajectories to wait for - timeout : float, optional - Timeout in seconds. Exceeding the timeout will raise a `TimeoutError`, by default None - - Returns - ------- - DistributedBatch - A concatenated batch of trajectories - - Raises - ------ - TimeoutError - If the timeout is exceeded before enough trajectories are collected - """ - raise NotImplementedError() - - def rollout_batch( - self, - data: list[dict[str, Any]], - workflow: Optional["RolloutWorkflow"] = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> DistributedBatch: - """Submit a batch of requests to the inference engine and wait for the results. - - See `workflow_api.py` for concrete implementation. - - Parameters - ---------- - data : List[Dict[str, Any]] - A list of input data dictionaries for rollout - workflow : RolloutWorkflow, optional - The workflow instance to run, by default None - workflow_builder : Callable, optional - A builder to create a workflow instance, by default None - should_accept : Callable, optional - A function to decide whether to accept a trajectory, by default None - - Returns - ------- - DistributedBatch - A concatenated batch of trajectory results - """ - raise NotImplementedError() - - def prepare_batch( - self, - dataloader: StatefulDataLoader, - workflow: Optional["RolloutWorkflow"] = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> DistributedBatch: - """Asynchronously submit and wait until a full batch is ready with controlled staleness. - - See `workflow_api.py` for concrete implementation. - - Parameters - ---------- - dataloader : StatefulDataLoader - The data loader to pull data from for batch preparation - workflow : RolloutWorkflow, optional - The workflow instance to run, by default None - workflow_builder : Callable, optional - A builder to create a workflow instance, by default None - should_accept : Callable, optional - A function to decide whether to accept a trajectory, by default None - - Returns - ------- - DistributedBatch - A full batch of trajectory results with controlled staleness - """ - raise NotImplementedError() - - def pause(self): - """Pause request submission for async rollout. - - Used during evaluation to prevent data over-generation. - """ - raise NotImplementedError() - - def resume(self): - """Resume request submission for async rollout.""" - raise NotImplementedError() - - def register_callback_to_all_worker( - self, method: str, callback: Callable, **kwargs - ): - """Register a callback function for the specified method across all workers. - - Partial rollout API. After successful registration, the controller will poll - and call the specified method in a background thread. When the return value - is obtained, it will be used as a parameter to call the `callback` function. - - Parameters - ---------- - method : str - The name of the method to register the callback for - callback : Callable - The callback function to be called with the method's return value - **kwargs - Additional keyword arguments for the callback registration - """ - raise NotImplementedError() - - def abort_all_requests(self) -> None: - """Abort all ongoing requests in the inference engine. - - Partial rollout API for canceling all queued and in-progress requests. - """ - raise NotImplementedError() diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index 9157d4b01..fcc9db72d 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -16,7 +16,6 @@ from areal.api.alloc_mode import AllocationMode from areal.api.cli_args import InferenceEngineConfig from areal.api.controller_api import DistributedBatch -from areal.api.controller_api import RolloutController as RolloutControllerAPI from areal.api.engine_api import InferenceEngine from areal.api.io_struct import ModelRequest, ModelResponse, ParamSpec, WeightUpdateMeta from areal.api.scheduler_api import Scheduler, SchedulingConfig, Worker @@ -38,19 +37,35 @@ class _RemoteRolloutTaskInput: should_accept_path: str | None = None -class RolloutController(RolloutControllerAPI): - """A centralized controller managing multiple InferenceEngine workers for rollout generation. +class RolloutController: + """A centralized controller that manages multiple distributed InferenceEngine workers for rollout generation. - This controller orchestrates distributed inference by: - 1. Launching local inference engines on workers via scheduler - 2. Scheduling requests to specific engines via round-robin - 3. Delegating actual execution to AsyncTaskRunner - 4. Aggregating results from workers into DistributedBatch + RolloutController orchestrates distributed inference workloads by scheduling and + dispatching requests across multiple concurrent InferenceEngine instances. It provides + intelligent load balancing, staleness control, and capacity management to optimize + rollout generation efficiency. + + Key features: + - Distributed request scheduling and load balancing across workers + - Centralized staleness and capacity control for consistent performance + - Asynchronous rollout generation with configurable acceptance criteria + - Data aggregation from heterogeneously loaded workers + + The controller handles workload imbalances inherent in rollout generation, where + different workers may produce varying amounts of data depending on the complexity + of their assigned tasks. Generated data is stored locally on workers and aggregated + into `DistributedBatch` objects for seamless integration with TrainController. + + Implementation details: + - Launches local inference engines on workers via scheduler + - Schedules requests to specific engines via round-robin + - Delegates actual execution to AsyncTaskRunner + - Aggregates results from workers into DistributedBatch Parameters ---------- - inf_engine : InferenceEngine - The inference engine class to instantiate on each worker + inf_engine : type[InferenceEngine] + The inference engine class (not instance) to instantiate on each worker config : InferenceEngineConfig Configuration for inference engines scheduler : Scheduler @@ -59,7 +74,7 @@ class RolloutController(RolloutControllerAPI): def __init__( self, - inf_engine: InferenceEngine, + inf_engine: type[InferenceEngine], config: InferenceEngineConfig, scheduler: Scheduler, ): @@ -67,14 +82,16 @@ def __init__( Parameters ---------- - inf_engine : InferenceEngine + inf_engine : type[InferenceEngine] The inference engine class (not instance) to create on workers config : InferenceEngineConfig Configuration for the inference engines scheduler : Scheduler Scheduler for managing workers """ - super().__init__(inf_engine, config, scheduler) + self.inf_engine = inf_engine + self.config = config + self.scheduler = scheduler # Worker management self.workers: list[Worker] = [] # List of Worker objects from scheduler @@ -107,6 +124,20 @@ def initialize( *args, **kwargs, ): + """Initialize environments and launch the background thread for asynchronous distributed inference. + + For remote inference engines, this serves as a client and connects to the inference servers. + For local inference engines, this creates an LLM engine on the local GPU. + + Parameters + ---------- + alloc_mode : AllocationMode + The allocation mode configuration for distributed setup + *args + Variable length argument list passed to engine initialization + **kwargs + Arbitrary keyword arguments passed to engine initialization + """ self.logger = logging.getLogger("[RolloutController]") # Get scheduling config from kwargs or use defaults @@ -191,7 +222,10 @@ async def _async_initialize( self.logger.info("All engines are initialized...") def destroy(self): - """Destroy the controller and clean up resources.""" + """Destroy the engine and release GPU memory for the local inference engine. + + This method cleans up all resources including workers, task runner, and thread pool. + """ self.logger.info("Destroying RolloutController...") # Destroy task runner @@ -291,6 +325,24 @@ def submit( workflow_kwargs: dict[str, Any], should_accept_path: str | None = None, ) -> None: + """Submit a request to the inference engine and return immediately. + + Should be used together with subsequent `wait`. + + Parameters + ---------- + data : dict[str, Any] + The input data for rollout. Used by the user's customized workflow implementation. + workflow_path : str + The fully qualified path to the workflow class (e.g., "module.submodule.WorkflowClass"). + The workflow will be dynamically imported on the worker. + workflow_kwargs : dict[str, Any] + Keyword arguments to pass to the workflow constructor. + should_accept_path : str | None, optional + The fully qualified path to a function used to decide whether to accept a specific + trajectory (dynamic filtering). The function should take a complete trajectory + output by the workflow and return a bool, by default None. + """ # Add to pending queue (will be submitted when capacity allows) self._pending_inputs.append( _RemoteRolloutTaskInput( @@ -333,6 +385,27 @@ def _commit_one_to_runner(self): ) def wait(self, count: int, timeout: float | None = None) -> DistributedBatch: + """Wait for a specified number of requests to complete, with a timeout. + + Should be used together with preceding `submit`. + + Parameters + ---------- + count : int + The number of accepted trajectories to wait for + timeout : float | None, optional + Timeout in seconds. Exceeding the timeout will raise a `TimeoutError`, by default None + + Returns + ------- + DistributedBatch + A concatenated batch of trajectories + + Raises + ------ + TimeoutError + If the timeout is exceeded before enough trajectories are collected + """ ####################################################### # The following logic is copied from WorkflowExecutor # ####################################################### @@ -409,6 +482,24 @@ def rollout_batch( workflow_kwargs: dict[str, Any], should_accept_path: str | None = None, ) -> DistributedBatch: + """Submit a batch of requests to the inference engine and wait for the results. + + Parameters + ---------- + data : list[dict[str, Any]] + A list of input data dictionaries for rollout + workflow_path : str + The fully qualified path to the workflow class (e.g., "module.submodule.WorkflowClass") + workflow_kwargs : dict[str, Any] + Keyword arguments to pass to the workflow constructor + should_accept_path : str | None, optional + The fully qualified path to a function to decide whether to accept a trajectory, by default None + + Returns + ------- + DistributedBatch + A concatenated batch of trajectory results + """ # Submit all requests for item in data: self.submit( @@ -428,6 +519,24 @@ def prepare_batch( workflow_kwargs: dict[str, Any], should_accept_path: str | None = None, ) -> DistributedBatch: + """Asynchronously submit and wait until a full batch is ready with controlled staleness. + + Parameters + ---------- + dataloader : StatefulDataLoader + The data loader to pull data from for batch preparation + workflow_path : str + The fully qualified path to the workflow class (e.g., "module.submodule.WorkflowClass") + workflow_kwargs : dict[str, Any] + Keyword arguments to pass to the workflow constructor + should_accept_path : str | None, optional + The fully qualified path to a function to decide whether to accept a trajectory, by default None + + Returns + ------- + DistributedBatch + A full batch of trajectory results with controlled staleness + """ ####################################################### # The following logic is copied from WorkflowExecutor # ####################################################### @@ -461,17 +570,19 @@ def prepare_batch( async def agenerate(self, req: ModelRequest) -> ModelResponse: """Asynchronously generate a response for the given request. + This method provides direct access to the inference engine's generation capabilities + for single requests, bypassing the workflow system. + Parameters ---------- req : ModelRequest - Model request containing input data and generation parameters + The model request containing input data and generation parameters Returns ------- ModelResponse - Generated response from the model + The generated response from the model """ - # Choose worker and delegate worker = self._choose_worker() @@ -483,17 +594,21 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: ) def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: - """Initialize weight update process group for all workers. + """Initialize the weight update process group for distributed weight updates. + + This method should be called before performing any weight updates to ensure + that the necessary communication groups are set up correctly across all workers. Parameters ---------- meta : WeightUpdateMeta - Metadata containing weight update information + Metadata containing information about the weight update, such as the + type of communication backend and allocation mode. Returns ------- Future[None] - Future representing the async initialization operation + A future object representing the asynchronous initialization operation. """ async def _init_all_workers(): @@ -515,19 +630,19 @@ def init_all_workers(): def update_weights_from_distributed( self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ) -> Future[None]: - """Update weights from distributed memory for all workers. + """Update weights in the inference engine in a non-blocking manner from distributed memory. Parameters ---------- meta : WeightUpdateMeta - Metadata containing weight update information + Metadata containing information about the weight update param_specs : list[ParamSpec] - Parameter specifications for weights to update + A list of parameter specifications for the weights to be updated Returns ------- Future[None] - Future representing the async update operation + A future object representing the asynchronous weight update operation """ async def _update_all_workers(): @@ -548,17 +663,17 @@ def update_all_workers(): return self.executor.submit(update_all_workers) def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: - """Update weights from disk for all workers. + """Update weights in the inference engine from disk in a non-blocking manner. Parameters ---------- meta : WeightUpdateMeta - Metadata containing weight update information + Metadata containing information about the weight update Returns ------- Future[None] - Future representing the async update operation + A future object representing the asynchronous weight update operation """ async def _update_all_workers(): @@ -578,14 +693,16 @@ def update_all_workers(): return self.executor.submit(update_all_workers) def set_version(self, version: int) -> None: - """Set the current weight version for all workers. + """Set the current weight version in the inference engine. + + This updates the version number across all workers, which is used for + staleness tracking in online training scenarios. Parameters ---------- version : int - Weight version number to set + The weight version number to set """ - self._version = version for worker in self.workers: try: @@ -598,18 +715,20 @@ def set_version(self, version: int) -> None: self.logger.error(f"Error setting version for worker {worker.id}: {e}") def get_version(self) -> int: - """Get the current weight version. + """Get the current weight version in the inference engine. Returns ------- int - Current weight version number + The current weight version number """ return self._version def pause(self): - """Pause request submission for async rollout on all workers.""" + """Pause request submission for async rollout. + Used during evaluation to prevent data over-generation across all workers. + """ for worker in self.workers: try: self.scheduler.call_engine( @@ -620,7 +739,7 @@ def pause(self): self.logger.error(f"Error pausing worker {worker.id}: {e}") def resume(self): - """Resume request submission for async rollout on all workers.""" + """Resume request submission for async rollout across all workers.""" for worker in self.workers: try: self.scheduler.call_engine( @@ -633,7 +752,37 @@ def resume(self): def register_callback_to_all_worker( self, method: str, callback: Callable, **kwargs ): + """Register a callback function for the specified method across all workers. + + Partial rollout API. After successful registration, the controller will poll + and call the specified method in a background thread. When the return value + is obtained, it will be used as a parameter to call the `callback` function. + + Parameters + ---------- + method : str + The name of the method to register the callback for + callback : Callable + The callback function to be called with the method's return value + **kwargs + Additional keyword arguments for the callback registration + + Raises + ------ + NotImplementedError + This method is not yet implemented + """ raise NotImplementedError() def abort_all_requests(self) -> None: + """Abort all ongoing requests in the inference engine. + + Partial rollout API for canceling all queued and in-progress requests across + all workers. + + Raises + ------ + NotImplementedError + This method is not yet implemented + """ raise NotImplementedError() From f223db1cac7e250841a94dac1446695ad7527e9c Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 14:49:46 +0800 Subject: [PATCH 14/52] merge train controller commit --- areal/api/cli_args.py | 33 +++++ areal/api/controller_api.py | 32 +---- areal/api/engine_api.py | 21 ++- areal/api/scheduler_api.py | 33 ++--- areal/controller/batch.py | 36 +++++ areal/engine/base_hf_engine.py | 8 +- areal/engine/fsdp_engine.py | 2 + areal/engine/ppo/actor.py | 8 +- areal/scheduler/rpc/rpc_client.py | 27 +++- areal/scheduler/rpc/rpc_server.py | 215 +++++++++++++++++++++++------- areal/tests/test_rpc.py | 125 ++--------------- areal/utils/http.py | 33 +++++ areal/utils/recover.py | 7 +- areal/utils/saver.py | 3 +- areal/utils/stats_logger.py | 6 +- areal/workflow/rlvr.py | 40 +++++- docs/cli_reference.md | 49 ++++--- 17 files changed, 434 insertions(+), 244 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 1fd3f6bbe..f65bb0f98 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -311,6 +311,31 @@ class MegatronEngineConfig: recompute_modules: list[str] | None = None +@dataclass +class SchedulingSpec: + cpu: int = field(default=0, metadata={"help": "Number of CPU cores required"}) + gpu: int = field(default=0, metadata={"help": "Number of GPU units required"}) + mem: int = field(default=0, metadata={"help": "Amount of memory (GB) required"}) + port_count: int = field(default=2, metadata={"help": "Number of ports to expose"}) + image: str = field( + default="", metadata={"help": "Docker/Singularity container image to use"} + ) + type: str = field( + default="worker", + metadata={ + "help": "Task type (e.g., worker, engine)", + "choices": ["worker", "engine"], + }, + ) + env_vars: Dict[str, str] = field( + default_factory=dict, + metadata={"help": "Environment variables for the container"}, + ) + cmd: str = field( + default="", metadata={"help": "Command to execute inside the container"} + ) + + @dataclass class TrainEngineConfig: """Core configuration for model training, including optimization and backend settings.""" @@ -384,6 +409,10 @@ class TrainEngineConfig: default="lora", metadata={"help": "peft method type. Only LoRA is supported for now."}, ) + scheduling_specs: List[SchedulingSpec] = field( + default_factory=list, + metadata={"help": "train engine schedule specs"}, + ) @dataclass @@ -844,6 +873,10 @@ class InferenceEngineConfig: "help": "The grace period after calling /pause_generation. Wait until all requests have been dropped." }, ) + scheduling_specs: List[SchedulingSpec] = field( + default_factory=list, + metadata={"help": "inference engine schedule specs"}, + ) @dataclass diff --git a/areal/api/controller_api.py b/areal/api/controller_api.py index 366940e3c..05303b671 100644 --- a/areal/api/controller_api.py +++ b/areal/api/controller_api.py @@ -4,14 +4,11 @@ from typing import TYPE_CHECKING, Any, Optional import torch -from torchdata.stateful_dataloader import StatefulDataLoader from areal.api.alloc_mode import ParallelStrategy from areal.api.cli_args import InferenceEngineConfig, TrainEngineConfig from areal.api.engine_api import InferenceEngine, TrainEngine from areal.api.io_struct import ( - ModelRequest, - ModelResponse, ParamSpec, SaveLoadMeta, WeightUpdateMeta, @@ -676,7 +673,7 @@ def rollout_batch( def prepare_batch( self, - dataloader: StatefulDataLoader, + dataloader: DistributedBatch, workflow: Optional["RolloutWorkflow"] = None, workflow_builder: Callable | None = None, should_accept: Callable | None = None, @@ -713,30 +710,3 @@ def pause(self): def resume(self): """Resume request submission for async rollout.""" raise NotImplementedError() - - def register_callback_to_all_worker( - self, method: str, callback: Callable, **kwargs - ): - """Register a callback function for the specified method across all workers. - - Partial rollout API. After successful registration, the controller will poll - and call the specified method in a background thread. When the return value - is obtained, it will be used as a parameter to call the `callback` function. - - Parameters - ---------- - method : str - The name of the method to register the callback for - callback : Callable - The callback function to be called with the method's return value - **kwargs - Additional keyword arguments for the callback registration - """ - raise NotImplementedError() - - def abort_all_requests(self) -> None: - """Abort all ongoing requests in the inference engine. - - Partial rollout API for canceling all queued and in-progress requests. - """ - raise NotImplementedError() diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 216d5c02c..5a4ed6336 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -3,6 +3,7 @@ from concurrent.futures import Future from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional import torch import torch.distributed as dist @@ -26,12 +27,14 @@ class Scheduling: cpu: int gpu: int mem: int + port_count: int + cmd: str | None = None nodelist: str | None = None exclude: str | None = None partition: str | None = None container_image: str | None = None - type: str | None = None - env_vars: dict[str, str] = field(default_factory=dict) + type: Literal["worker", "engine"] = None + env_vars: Dict[str, str] = field(default_factory=dict) # time utils from "https://slurm.schedmd.com/sbatch.html" time_limit: str | None = None # see "--time" option for format begin: str | None = None # see "--begin" option for format @@ -138,7 +141,7 @@ def parallelism_group(self) -> dist.ProcessGroup: """ raise NotImplementedError() - def get_scheduling_config(self) -> Scheduling: + def get_scheduling_config(self) -> List[Scheduling]: """Get the scheduling configuration for the engine. This includes configuration such as container image, CPU/GPU/memory size. @@ -588,3 +591,15 @@ def pause(self): def resume(self): """Resume request submission for async rollout.""" raise NotImplementedError() + + def get_scheduling_config(self) -> List[Scheduling]: + """Get the scheduling configuration for the engine. + + This includes configuration such as container image, CPU/GPU/memory size. + + Returns + ------- + List[Scheduling] + A list of scheduling configurations for the engine + """ + raise NotImplementedError() diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index 5d2076212..de2c12621 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -1,46 +1,41 @@ import abc from dataclasses import dataclass, field +from typing import List, Literal + +from areal.api.engine_api import Scheduling @dataclass class Worker: id: str + # worker and engine deploy on the same machine, so ip are the same ip: str - ports: list[str] = field(default_factory=list) - - -@dataclass -class ContainerSpec: - cpu: int = 0 - gpu: int = 0 - mem: int = 0 - container_image: str = "" - cmd: str = "" - env_vars: dict[str, str] = field(default_factory=dict) - port_count: int = 2 + worker_ports: List[str] = field(default_factory=list) + engine_ports: List[str] = field(default_factory=list) @dataclass class ScheduleStrategy: - type: str = "" - uid: str = "" + type: Literal["colocation", "separation"] = "separation" + target: str = "" @dataclass -class SchedulingConfig: +class Job: replicas: int = 0 - specs: list[ContainerSpec] = field(default_factory=list) + tasks: List[Scheduling] = field(default_factory=list) schedule_strategy: ScheduleStrategy | None = None role: str = "" class Scheduler(abc.ABC): - def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str: + def create_workers(self, job: Job, *args, **kwargs) -> None: """ - Start workers, return job id + Start workers """ + raise NotImplementedError() - def get_workers(self, worker_key, timeout=None) -> list[Worker]: + def get_workers(self, role: str, timeout=None) -> List[Worker]: """ Wait and return worker list, including scheduling results such as ip and engine ports (worker id, ip, ports) diff --git a/areal/controller/batch.py b/areal/controller/batch.py index 66df87573..0a65a3bf9 100644 --- a/areal/controller/batch.py +++ b/areal/controller/batch.py @@ -234,6 +234,42 @@ def get_data(self) -> Dict[str, Union[torch.Tensor, Any]]: return batch_data + def to_list(self) -> List[Dict[str, Any]]: + """Convert the dataset to a list format. + + Returns a list where each element is a dictionary representing + a single sample from the dataset. + + Returns + ------- + List[Dict[str, Any]] + List of dictionaries, where each dictionary contains the data + for one sample with keys as field names and values as the + corresponding field values for that sample. + """ + if not self.dataset: + return [] + + total_size = self._get_total_size() + if total_size == 0: + return [] + + # Build list of individual samples + result = [] + for i in range(total_size): + sample = {} + for key, values in self.dataset.items(): + if isinstance(values, torch.Tensor): + sample[key] = values[i] + elif isinstance(values, list): + sample[key] = values[i] if i < len(values) else None + else: + # For scalar values, use the same value for all samples + sample[key] = values + result.append(sample) + + return result + @staticmethod def concat(data: list["DistributedBatchMemory"]) -> "DistributedBatchMemory": """Concatenate multiple DistributedBatchMemory objects diff --git a/areal/engine/base_hf_engine.py b/areal/engine/base_hf_engine.py index 4adfb8950..f0af9651d 100644 --- a/areal/engine/base_hf_engine.py +++ b/areal/engine/base_hf_engine.py @@ -19,6 +19,8 @@ from areal.api.alloc_mode import ParallelStrategy from areal.api.cli_args import TrainEngineConfig from areal.api.engine_api import TrainEngine +from areal.api.engine_api import Scheduling, TrainEngine +from areal.api.io_struct import FinetuneSpec from areal.platforms import current_platform from areal.utils import logging from areal.utils.data import ( @@ -41,6 +43,7 @@ is_valid_vision_model, ) from areal.utils.nccl import NCCL_DEFAULT_TIMEOUT +from areal.utils.scheduler import scheduling_specs_to_schedulings class BaseHFEngine(TrainEngine): @@ -69,7 +72,7 @@ def __init__(self, config: TrainEngineConfig): ) self.is_vision_model = is_valid_vision_model(self.model_config.model_type) - self.world_size = int(os.environ["WORLD_SIZE"]) + self.world_size: int def set_version(self, version: int): self._version = version @@ -545,3 +548,6 @@ def forward( unpacked = unpack_sequence(res, lens=output_seqlens, dim=0) reordered = reorder_list(unpacked, mb_list.backward_indices) return pad_and_stack_tensors_along_first_dim(reordered) + + def get_scheduling_config(self) -> List[Scheduling]: + return scheduling_specs_to_schedulings(self.config.scheduling_specs) diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 44c562a4f..e883a6077 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -139,6 +139,8 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None self.dp_head = int(self.world_mesh["sp_tp"].mesh[0].item()) self.dp_rank = dist.get_rank(self.dp_group) + self.world_size = int(os.environ["WORLD_SIZE"]) + self.logger.info(f"Data parallel head {self.dp_head} and rank {self.dp_rank}") def initialize( diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index ec84b6ef2..c35527a04 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -69,7 +69,7 @@ def calc_logprobs(logits, input_data): aggregate_fn=lambda xs: torch.cat(xs, dim=-1), ) - def compute_advantages(self, data: Dict[str, Any]) -> None: + def compute_advantages(self, data: Dict[str, Any]) -> Dict[str, Any]: bs = data["input_ids"].shape[0] max_seqlen = data["input_ids"].shape[1] batch_indices = torch.arange( @@ -163,6 +163,8 @@ def compute_advantages(self, data: Dict[str, Any]) -> None: # because we have rolled old_logp by -1 data["logprobs"] = old_logp + return data + def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: data, sampling_stat = dynamic_sampling(data, self.group_size) @@ -285,8 +287,8 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None: return self.actor.compute_logp(*args, **kwargs) @torch.no_grad() - def compute_advantages(self, *args, **kwargs) -> None: - self.actor.compute_advantages(*args, **kwargs) + def compute_advantages(self, *args, **kwargs): + return self.actor.compute_advantages(*args, **kwargs) def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]: return self.actor.ppo_update(*args, **kwargs) diff --git a/areal/scheduler/rpc/rpc_client.py b/areal/scheduler/rpc/rpc_client.py index 28f4b8082..c7004df7f 100644 --- a/areal/scheduler/rpc/rpc_client.py +++ b/areal/scheduler/rpc/rpc_client.py @@ -6,7 +6,6 @@ import cloudpickle import requests -from areal.api.cli_args import InferenceEngineConfig, TrainEngineConfig from areal.api.engine_api import InferenceEngine, TrainEngine from areal.utils import logging from areal.utils.http import response_ok, response_retryable @@ -22,16 +21,20 @@ def register(self, worker_id: str, ip: str, port: int) -> None: self._addrs[worker_id] = (ip, port) logger.info(f"Registered worker {worker_id} at {ip}:{port}") + def get_info(self, worker_id: str) -> tuple[str, int]: + return self._addrs[worker_id] + def create_engine( self, worker_id: str, engine_obj: Union[InferenceEngine, TrainEngine], - init_config: Union[InferenceEngineConfig, TrainEngineConfig], + *args, + **kwargs, ) -> None: ip, port = self._addrs[worker_id] url = f"http://{ip}:{port}/create_engine" logger.info(f"send create_engine to {worker_id} ({ip}:{port})") - payload = (engine_obj, init_config) + payload = (engine_obj, args, kwargs) serialized_data = cloudpickle.dumps(payload) serialized_obj = gzip.compress(serialized_data) resp = requests.post(url, data=serialized_obj) @@ -47,6 +50,24 @@ def create_engine( f"Failed to create engine, {resp.status_code}, {resp.content}" ) + def check_health(self, worker_id: str, timeout: int = 20) -> bool: + ip, port = self._addrs[worker_id] + url = f"http://{ip}:{port}/health" + + start_time = time.time() + while time.time() - start_time < timeout: + remain_timeout = timeout - (time.time() - start_time) + try: + resp = requests.post(url, timeout=remain_timeout) + resp.raise_for_status() + return True + except Exception as e: + logger.warning(f"Health check exception for {worker_id}: {e}") + time.sleep(2) + + logger.error(f"Health check failed for {worker_id} after {timeout} seconds") + return False + def call_engine( self, worker_id: str, method: str, max_retries: int = 3, *args, **kwargs ) -> Any: diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index b2bc3d612..56713ab80 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -1,44 +1,143 @@ import argparse import gzip +import inspect import os import traceback +from asyncio import Future +from concurrent import futures from http import HTTPStatus -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from typing import AnyStr +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, AnyStr, Dict, List import cloudpickle +import torch from tensordict import TensorDict from areal.api.controller_api import DistributedBatch +from areal.api.engine_api import InferenceEngine from areal.controller.batch import DistributedBatchMemory from areal.utils import logging logger = logging.getLogger("RPCServer") -def process_input_to_distributed_batch(*args, **kwargs): - for i in range(len(args)): - if isinstance(args[i], DistributedBatch): - args = list(args) - args[i] = args[i].get_data() - args = tuple(args) +def tensor_container_to_safe( + d: Dict[str, Any] | torch.Tensor | List[torch.Tensor], *args, **kwargs +): + """Apply `t.to(*args, **kwargs)` to all tensors in the dictionary. + Support nested dictionaries. + """ + new_dict = {} + if torch.is_tensor(d): + return d.to(*args, **kwargs) + elif isinstance(d, list): + return [tensor_container_to_safe(v, *args, **kwargs) for v in d] + elif isinstance(d, dict): + for key, value in d.items(): + if isinstance(value, dict) or isinstance(value, list): + new_dict[key] = tensor_container_to_safe(value, *args, **kwargs) + elif torch.is_tensor(value): + new_dict[key] = value.to(*args, **kwargs) + else: + new_dict[key] = value + return new_dict + else: + return d + + +def process_input_to_distributed_batch(to_device, method, *args, **kwargs): + """Process input arguments, converting DistributedBatch based on method signature. + + This function inspects the method signature to determine whether each parameter + expects a dict or list format, then converts DistributedBatch instances accordingly. + """ + # Get method signature + try: + sig = inspect.signature(method) + parameters = sig.parameters + except (ValueError, TypeError): + # Fallback to list if signature inspection fails + parameters = {} + + def convert_distributed_batch(obj, param_name=None): + """Convert DistributedBatch based on expected parameter type.""" + if not isinstance(obj, DistributedBatch): + return obj + + # Determine expected type from parameter annotation + expected_type = None + if param_name and param_name in parameters: + param = parameters[param_name] + if param.annotation != inspect.Parameter.empty: + annotation = param.annotation + if annotation == dict or ( + hasattr(annotation, "__origin__") and annotation.__origin__ is dict + ): + expected_type = "dict" + elif annotation == list or ( + hasattr(annotation, "__origin__") and annotation.__origin__ is list + ): + expected_type = "list" - for k in list(kwargs.keys()): - if isinstance(kwargs[k], DistributedBatch): - kwargs[k] = kwargs[k].get_data() + # Convert based on expected type or fallback to list + if expected_type == "list": + return obj.to_list() + else: + return obj.get_data() - return args, kwargs + # Process args + new_args = list(args) + for i, arg in enumerate(new_args): + if isinstance(arg, DistributedBatch): + # Try to get parameter name for positional arguments + param_names = list(parameters.keys()) + param_name = param_names[i] if i < len(param_names) else None + new_args[i] = convert_distributed_batch(arg, param_name) + + # Process kwargs + new_kwargs = {} + for key, value in kwargs.items(): + if isinstance(value, DistributedBatch): + new_kwargs[key] = convert_distributed_batch(value, key) + else: + new_kwargs[key] = value + + # Apply device transfer + new_args = tuple(tensor_container_to_safe(new_args, to_device)) + new_kwargs = tensor_container_to_safe(new_kwargs, to_device) + + return new_args, new_kwargs def process_output_to_distributed_batch(result): - if isinstance(result, dict): - return DistributedBatchMemory.from_dict(result) - elif isinstance(result, TensorDict): + result = tensor_container_to_safe(result, "cpu") + + if isinstance(result, TensorDict): return DistributedBatchMemory.from_dict(result.to_dict()) - elif isinstance(result, (list, tuple)): - return DistributedBatchMemory.from_list(list(result)) - else: - return result + + if isinstance(result, (Future, futures.Future)): + return result.result() + + if isinstance(result, list) and result: + if all(isinstance(item, dict) for item in result): + is_list_of_dict_str_tensor = True + for item in result: + for key, value in item.items(): + if not isinstance(key, str) or not isinstance(value, torch.Tensor): + is_list_of_dict_str_tensor = False + break + if is_list_of_dict_str_tensor: + DistributedBatchMemory.from_list(result) + + if isinstance(result, dict) and result: + is_dict_of_tensor = all( + isinstance(key, str) and isinstance(value, torch.Tensor) + for key, value in result.items() + ) + if is_dict_of_tensor: + return DistributedBatchMemory.from_dict(result) + + return result class EngineRPCServer(BaseHTTPRequestHandler): @@ -49,7 +148,6 @@ def _read_body(self, timeout=120.0) -> AnyStr: try: length = int(self.headers["Content-Length"]) old_timeout = self.request.gettimeout() - logger.info(f"Receive rpc call, path: {self.path}, timeout: {old_timeout}") # set max read timeout = 120s here, if read hang raise exception self.request.settimeout(timeout) return self.rfile.read(length) @@ -76,9 +174,9 @@ def do_POST(self): try: if self.path == "/create_engine": decompressed_data = gzip.decompress(data) - engine_obj, init_args = cloudpickle.loads(decompressed_data) + engine_obj, args, kwargs = cloudpickle.loads(decompressed_data) EngineRPCServer.engine = engine_obj - result = EngineRPCServer.engine.initialize(init_args) + result = EngineRPCServer.engine.initialize(*args, **kwargs) logger.info(f"Engine created and initialized on RPC server: {result}") self.send_response(HTTPStatus.OK) self.end_headers() @@ -91,15 +189,34 @@ def do_POST(self): logger.error("Call received but engine is none.") return action, args, kwargs = cloudpickle.loads(data) + logger.info(f"Received call for action: {action}") + method = getattr(EngineRPCServer.engine, action) + # NOTE: DO NOT print args here, args may be a very huge tensor - logger.info(f"RPC server calling engine method: {action}") - args, kwargs = process_input_to_distributed_batch(*args, **kwargs) - result = method(*args, **kwargs) - result = process_output_to_distributed_batch(result) + if isinstance(EngineRPCServer.engine, InferenceEngine): + device = "cpu" + else: # actor + device = EngineRPCServer.engine.device + + args, kwargs = process_input_to_distributed_batch( + device, method, *args, **kwargs + ) + if ( + check_attribute_type(type(EngineRPCServer.engine), action) + == "method" + ): + result = method(*args, **kwargs) + result = process_output_to_distributed_batch(result) + else: + result = method self.send_response(HTTPStatus.OK) self.end_headers() self.wfile.write(cloudpickle.dumps(result)) + elif self.path == "/health": + self.send_response(HTTPStatus.OK) + self.end_headers() + self.wfile.write(b"OK") else: self.send_response(HTTPStatus.NOT_FOUND) self.end_headers() @@ -113,37 +230,41 @@ def do_POST(self): def start_rpc_server(port): - server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer) + # NOTE: We must use HTTPServer rather than ThreadingHTTPServer here, since the rank and device info + # of pytorch is thread level, if use ThreadingHTTPServer, the device set by create_engine thread + # will not be seen by call_engine thread. + # server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer) + server = HTTPServer(("0.0.0.0", port), EngineRPCServer) server.serve_forever() -def get_serve_port(args): - port = args.port - port_str = os.environ.get("PORT_LIST", "").strip() - - # Check if PORT_LIST is set - if port_str: - # Split by comma and strip whitespace - ports = [p.strip() for p in port_str.split(",")] - # Use the first valid port from the list - if ports and ports[0]: - try: - return int(ports[0]) - except ValueError: - logger.warning( - f"Invalid port '{ports[0]}' in PORT_LIST. Falling back to --port argument." - ) - return port +def check_attribute_type(cls, attr_name): + if hasattr(cls, attr_name): + attr = getattr(cls, attr_name) # 从类获取 + if isinstance(attr, property): + return "property" + elif callable(attr): + return "method" + else: + raise f"unsupported attr, type: {type(attr)}, name: {attr_name}" + raise f"attr not found, name: {attr_name}" + + +def get_server_port(port: int) -> int: + if port: + return port + port_list = os.environ.get("PORT_LIST", "").strip() + ports = [p.strip() for p in port_list.split(",")] + # use the first port as serve port + return int(ports[0]) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, required=False) args, unknown = parser.parse_known_args() - port = get_serve_port(args) + port = get_server_port(args.port) logger.info(f"About to start RPC server on {port}") - start_rpc_server(port) diff --git a/areal/tests/test_rpc.py b/areal/tests/test_rpc.py index 2f5ab493a..d78a0c9e3 100644 --- a/areal/tests/test_rpc.py +++ b/areal/tests/test_rpc.py @@ -16,8 +16,7 @@ from areal.scheduler.rpc.rpc_client import RPCClient from areal.scheduler.rpc.rpc_server import ( EngineRPCServer, - get_serve_port, - process_input_to_distributed_batch, + get_server_port, process_output_to_distributed_batch, start_rpc_server, ) @@ -29,6 +28,7 @@ class MockEngine: def __init__(self): self.initialized = False self.call_count = 0 + self.device = "cpu" def initialize(self, config): self.initialized = True @@ -53,55 +53,7 @@ def return_origin_batch(self, batch): return batch -# Test RPC data processing functions - - -def test_process_input_to_distributed_batch_with_memory_batch(): - """Test processing input parameters containing DistributedBatchMemory""" - # Create DistributedBatchMemory instance - data = { - "input_ids": torch.tensor([1, 2, 3, 4]), - "labels": torch.tensor([5, 6, 7, 8]), - "metadata": ["text1", "text2", "text3", "text4"], - } - batch = DistributedBatchMemory.from_dict(data) - - # Test args and kwargs containing DistributedBatchMemory - args = (batch, "other_arg") - kwargs = {"batch_param": batch, "other_param": "value"} - - processed_args, processed_kwargs = process_input_to_distributed_batch( - *args, **kwargs - ) - - # Verify DistributedBatchMemory is converted to dictionary - assert isinstance(processed_args[0], dict) - assert processed_args[1] == "other_arg" - assert isinstance(processed_kwargs["batch_param"], dict) - assert processed_kwargs["other_param"] == "value" - - # Verify converted dictionary contains original data - converted_data = processed_args[0] - torch.testing.assert_close(converted_data["input_ids"], data["input_ids"]) - torch.testing.assert_close(converted_data["labels"], data["labels"]) - assert converted_data["metadata"] == data["metadata"] - - -def test_process_input_no_distributed_batch(): - """Test processing input that does not contain DistributedBatch""" - args = ("arg1", "arg2", torch.tensor([1, 2, 3])) - kwargs = {"param1": "value1", "param2": torch.tensor([4, 5, 6])} - - processed_args, processed_kwargs = process_input_to_distributed_batch( - *args, **kwargs - ) - - # Should remain unchanged - assert processed_args == args - assert processed_kwargs == kwargs - - -def test_process_output_to_distributed_batch_dict(): +def test_process_output_to_dict(): """Test converting dictionary output to DistributedBatch""" result = { "output_ids": torch.tensor([1, 2, 3]), @@ -112,13 +64,7 @@ def test_process_output_to_distributed_batch_dict(): processed = process_output_to_distributed_batch(result) # Should be converted to DistributedBatchMemory - assert isinstance(processed, DistributedBatchMemory) - - # Verify data integrity - processed_data = processed.get_data() - torch.testing.assert_close(processed_data["output_ids"], result["output_ids"]) - torch.testing.assert_close(processed_data["scores"], result["scores"]) - assert processed_data["texts"] == result["texts"] + assert isinstance(processed, dict) def test_process_output_to_distributed_batch_tensordict(): @@ -139,8 +85,8 @@ def test_process_output_to_distributed_batch_tensordict(): torch.testing.assert_close(processed_data["tensor2"], tensor_dict["tensor2"]) -def test_process_output_to_distributed_batch_list(): - """Test converting list/tuple output to DistributedBatch""" +def test_process_output_to_list(): + """Test converting list/tuple output to list""" result_list = [ {"id": 1, "value": torch.tensor([0.1])}, {"id": 2, "value": torch.tensor([0.2])}, @@ -148,7 +94,7 @@ def test_process_output_to_distributed_batch_list(): ] processed_list = process_output_to_distributed_batch(result_list) - assert isinstance(processed_list, DistributedBatchMemory) + assert isinstance(processed_list, list) def test_process_output_to_distributed_batch_other_types(): @@ -175,60 +121,17 @@ def test_process_output_to_distributed_batch_other_types(): def test_get_serve_port_from_args(): """Test getting port from command line arguments""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = 8080 with patch.dict("os.environ", {}, clear=True): - port = get_serve_port(mock_args) + port = get_server_port(mock_args.rpc_port) assert port == 8080 -def test_get_serve_port_from_env_single_port(): +def test_get_server_ports_default_from_multi_ports(): """Test getting single port from PORT_LIST environment variable""" - mock_args = Mock() - mock_args.port = 8080 - - with patch.dict("os.environ", {"PORT_LIST": "9000"}): - port = get_serve_port(mock_args) - assert port == 9000 - - -def test_get_serve_port_from_env_multiple_ports(): - """Test getting first port from multiple ports in PORT_LIST environment variable""" - mock_args = Mock() - mock_args.port = 8080 - - with patch.dict("os.environ", {"PORT_LIST": "9000, 9001, 9002"}): - port = get_serve_port(mock_args) - assert port == 9000 - - -def test_get_serve_port_invalid_env_port(): - """Test fallback when PORT_LIST contains invalid ports""" - mock_args = Mock() - mock_args.port = 8080 - - with patch.dict("os.environ", {"PORT_LIST": "invalid_port, 9001"}): - port = get_serve_port(mock_args) - assert port == 8080 - - -def test_get_serve_port_empty_env(): - """Test fallback when PORT_LIST is empty""" - mock_args = Mock() - mock_args.port = 8080 - - with patch.dict("os.environ", {"PORT_LIST": ""}): - port = get_serve_port(mock_args) - assert port == 8080 - - -def test_get_serve_port_whitespace_env(): - """Test fallback when PORT_LIST contains only whitespace""" - mock_args = Mock() - mock_args.port = 8080 - - with patch.dict("os.environ", {"PORT_LIST": " "}): - port = get_serve_port(mock_args) + with patch.dict("os.environ", {"PORT_LIST": "8080,8081,8082,8083"}, clear=True): + port = get_server_port(0) assert port == 8080 @@ -314,7 +217,6 @@ def test_end_to_end_with_distributed_batch_memory(setup_rpc_server): batch_data = { "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 0], [1, 0, 0]]), - "metadata": ["sample1", "sample2", "sample3"], } batch = DistributedBatchMemory.from_dict(batch_data) @@ -323,7 +225,7 @@ def test_end_to_end_with_distributed_batch_memory(setup_rpc_server): # Verify batch processing success assert process_result["processed"] - assert process_result["batch_size"] == 3 + assert process_result["batch_size"] == 2 # Test tensor processing distrubuted_batch_result = client.call_engine( @@ -333,4 +235,3 @@ def test_end_to_end_with_distributed_batch_memory(setup_rpc_server): tensor_result = distrubuted_batch_result.get_data() assert torch.equal(tensor_result["input_ids"], batch_data["input_ids"]) assert torch.equal(tensor_result["attention_mask"], batch_data["attention_mask"]) - assert tensor_result["metadata"] == batch_data["metadata"] diff --git a/areal/utils/http.py b/areal/utils/http.py index d2bb9b96f..559035db2 100644 --- a/areal/utils/http.py +++ b/areal/utils/http.py @@ -1,6 +1,11 @@ import asyncio +import os +import signal +import traceback +from concurrent.futures import Future, as_completed from http import HTTPStatus from typing import Any +from typing import Any, Dict, List, Optional import aiohttp @@ -115,3 +120,31 @@ def response_ok(http_code: int) -> bool: def response_retryable(http_code: int) -> bool: return http_code == HTTPStatus.REQUEST_TIMEOUT + + +def wait_future_ordered( + futures: List[Future], exit_on_exception: bool = False +) -> List[Any]: + """ + Waits for a list of futures to complete and returns the results in the order the futures were submitted. + :param futures: List of Future objects to wait for. + :param exit_on_exception: If True, terminate the process upon an exception in any future. + If False, raise the exception. + :return: List of results in the same order as the input futures. + :raises Exception: If exit_on_exception is False and any future raises an exception. + """ + results = [None] * len(futures) + future_index_map = {future: i for i, future in enumerate(futures)} + for future in as_completed(futures): + index = future_index_map[future] + try: + results[index] = future.result() + except Exception as e: + logger.warning(f"Exception caught when waiting for future: {e}") + logger.warning(traceback.format_exc()) + if exit_on_exception: + logger.info("Exiting due to exception in future.") + os.kill(os.getpid(), signal.SIGTERM) + else: + raise e + return results diff --git a/areal/utils/recover.py b/areal/utils/recover.py index 7c6ee78a9..c936a26fd 100644 --- a/areal/utils/recover.py +++ b/areal/utils/recover.py @@ -9,6 +9,7 @@ from transformers import AutoProcessor, PreTrainedTokenizerFast from areal.api.cli_args import RecoverConfig +from areal.api.controller_api import TrainController from areal.api.engine_api import InferenceEngine, TrainEngine from areal.api.io_struct import FinetuneSpec, SaveLoadMeta, StepInfo, WeightUpdateMeta from areal.utils import logging, timeutil @@ -213,7 +214,7 @@ def dump( def load( self, - engine: TrainEngine | Dict[str, TrainEngine], + engine: TrainEngine | Dict[str, TrainEngine] | TrainController, saver: Saver, evaluator: Evaluator, stats_logger: "StatsLogger", @@ -231,7 +232,7 @@ def load( weight_update_meta is not None ), "Inference engine requires weight update meta for recovery." - if isinstance(engine, TrainEngine): + if isinstance(engine, (TrainEngine, TrainController)): engine = {"default": engine} recover_info_path = self.recover_info_path( @@ -304,7 +305,7 @@ def _save_checkpoint( def _load_checkpoint( self, - engine: TrainEngine, + engine: TrainEngine | TrainController, name: str = "default", tokenizer: PreTrainedTokenizerFast | None = None, base_model_path: str | None = None, diff --git a/areal/utils/saver.py b/areal/utils/saver.py index a64e55387..ec0d9c71a 100644 --- a/areal/utils/saver.py +++ b/areal/utils/saver.py @@ -4,6 +4,7 @@ from transformers import AutoProcessor, PreTrainedTokenizerFast from areal.api.cli_args import SaverConfig +from areal.api.controller_api import TrainController from areal.api.engine_api import TrainEngine from areal.api.io_struct import FinetuneSpec, SaveLoadMeta from areal.utils import timeutil @@ -85,7 +86,7 @@ def load_state_dict(self, state_dict): def save( self, - engine: TrainEngine, + engine: TrainEngine | TrainController, epoch: int, step: int, global_step: int, diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index 1ac0009b8..530f5761e 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -38,14 +38,14 @@ def init(self): self.start_time = time.perf_counter() # wandb init, connect to remote wandb host - if self.config.wandb.mode != "disabled": - wandb.login() - if self.config.wandb.wandb_base_url: os.environ["WANDB_API_KEY"] = self.config.wandb.wandb_api_key if self.config.wandb.wandb_api_key: os.environ["WANDB_BASE_URL"] = self.config.wandb.wandb_base_url + if self.config.wandb.mode != "disabled": + wandb.login() + suffix = self.config.wandb.id_suffix if suffix == "timestamp": suffix = time.strftime("%Y_%m_%d_%H_%M_%S") diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index e4ecf6860..2e8f8cccb 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -1,5 +1,6 @@ import asyncio import os +import threading import uuid from typing import Callable @@ -16,6 +17,7 @@ from areal.api.workflow_api import RolloutWorkflow from areal.utils import logging, stats_tracker from areal.utils.data import concat_padded_tensors +from areal.utils.hf_utils import load_hf_tokenizer logger = logging.getLogger("RLVR workflow") @@ -39,7 +41,7 @@ def __init__( self, reward_fn, gconfig: GenerationHyperparameters, - tokenizer: PreTrainedTokenizerFast, + tokenizer: str | PreTrainedTokenizerFast | None = None, enable_thinking: bool = False, rollout_stat_scope: str = "rollout", dump_dir: str | None = None, @@ -48,17 +50,39 @@ def __init__( ): self.reward_fn = reward_fn self.gconfig = gconfig - self.tokenizer = tokenizer + + self.tokenizer = None + self._initialized = False + self.tokenizer_path = "" + self._init_lock = threading.Lock() + # Handle tokenizer parameter + if isinstance(tokenizer, str): + self.tokenizer_path = tokenizer + elif isinstance(tokenizer, PreTrainedTokenizerFast): + self.tokenizer = tokenizer + self.enable_thinking = enable_thinking self.dump_dir = dump_dir self.rollout_stat_scope = rollout_stat_scope - self.async_reward_fn = AsyncRewardWrapper(reward_fn) + self.async_reward_fn = None self.get_input_ids_fn = get_input_ids_fn self.data_extract_prompt_fn = data_extract_prompt_fn if self.dump_dir is not None and not os.path.exists(self.dump_dir): os.makedirs(self.dump_dir, exist_ok=True) + def initialize(self): + if self.async_reward_fn is None: + self.async_reward_fn = AsyncRewardWrapper(self.reward_fn) + if self.tokenizer is None: + self.tokenizer = load_hf_tokenizer(self.tokenizer_path) + async def arun_episode(self, engine: InferenceEngine, data): + if not self._initialized: + with self._init_lock: + if not self._initialized: + self.initialize() + self._initialized = True + input_ids = self.get_input_ids_fn( self.data_extract_prompt_fn(data), self.tokenizer, self.enable_thinking ) @@ -142,3 +166,13 @@ async def arun_episode(self, engine: InferenceEngine, data): await f.write(info + "\n") return concat_padded_tensors(results) + + def __getstate__(self): + # pickle时不保存锁对象 + state = self.__dict__.copy() + del state["_init_lock"] + return state + + def __setstate__(self, state): + self.__dict__ = state + self._init_lock = threading.Lock() diff --git a/docs/cli_reference.md b/docs/cli_reference.md index ac15e5ab5..318cb1ca3 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -74,6 +74,7 @@ For detailed examples, see the experiment configurations in the `examples/` dire - [DistributedDataParallel Configuration](section-distributed-data-parallel) - [MegatronEngine Configuration](section-megatron-engine) - [Scheduler Configuration](section-scheduler) +- [Scheduling Specification](section-scheduling) ______________________________________________________________________ @@ -452,21 +453,22 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | --------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `setup_timeout` | float | `120.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| Parameter | Type | Default | Description | +| ------------------------- | ---------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `setup_timeout` | float | `120.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_specs` | list of [`SchedulingSpec`](section-scheduling) | **Required** | inference engine schedule specs | (section-sg-lang)= @@ -811,3 +813,20 @@ Configuration for worker scheduling. Used in the single-controller mode. Experim | `reward_functioncall_config` | `dict` | **Required** | - | | `reward_model_path` | string | `""` | - | | `reward_model_service_url` | string | `"http://localhost:30000/classify"` | - | + +(section-scheduling)= + +## Scheduling Specification + +Configuration class: SchedulingSpec + +| Parameter | Type | Default | Description | +| ------------ | ------- | ------------ | ---------------------------------------------------------------- | +| `cpu` | integer | `0` | Number of CPU cores required | +| `gpu` | integer | `0` | Number of GPU units required | +| `mem` | integer | `0` | Amount of memory (GB) required | +| `port_count` | integer | `2` | Number of ports to expose | +| `image` | string | `""` | Docker/Singularity container image to use | +| `type` | string | `"worker"` | Task type (e.g., worker, engine) **Choices:** `worker`, `engine` | +| `env_vars` | `Dict` | **Required** | Environment variables for the container | +| `cmd` | string | `""` | Command to execute inside the container | From a58d0cc3fa0bd3ae7823c7ec8eaa0203b4a9baef Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 15:05:18 +0800 Subject: [PATCH 15/52] add train controller --- areal/controller/train_controller.py | 256 ++++++++++++++++++++++ areal/controller/utils.py | 95 ++++++++ areal/scheduler/__init__.py | 0 areal/scheduler/local.py | 184 ++++++++++++++++ areal/utils/scheduler.py | 19 ++ examples/single-controller/gsm8k_sft.py | 185 ++++++++++++++++ examples/single-controller/gsm8k_sft.yaml | 96 ++++++++ 7 files changed, 835 insertions(+) create mode 100644 areal/controller/train_controller.py create mode 100644 areal/controller/utils.py create mode 100644 areal/scheduler/__init__.py create mode 100644 areal/scheduler/local.py create mode 100644 areal/utils/scheduler.py create mode 100644 examples/single-controller/gsm8k_sft.py create mode 100644 examples/single-controller/gsm8k_sft.yaml diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py new file mode 100644 index 000000000..c482e7c39 --- /dev/null +++ b/areal/controller/train_controller.py @@ -0,0 +1,256 @@ +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor +from functools import partial +from typing import Any + +import torch + +from areal.api.alloc_mode import ParallelStrategy +from areal.api.cli_args import TrainEngineConfig +from areal.api.controller_api import DistributedBatch, TrainController +from areal.api.engine_api import TrainEngine +from areal.api.io_struct import ( + AllocationMode, + FinetuneSpec, + ParamSpec, + SaveLoadMeta, + WeightUpdateMeta, +) +from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker +from areal.controller.batch import DistributedBatchMemory +from areal.controller.utils import create_engine_with_retry, rpc_call +from areal.utils import logging +from areal.utils.http import wait_future_ordered + +logger = logging.getLogger("DistributedTrainController") + + +class DistributedTrainController(TrainController): + def __init__( + self, train_engine: TrainEngine, config: TrainEngineConfig, scheduler: Scheduler + ): + super().__init__(train_engine, config, scheduler) + + self.role: str = "train" + self.group_size: int + self.alloc_mode: AllocationMode + self.workers: list[Worker] + self.engine_dp_ranks: list[int] + + def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): + assert self.workers is not None, "Workers are not created" + self.custom_function_call("create_process_group", parallel_strategy) + + def initialize( + self, + alloc_mode_str: str, + ft_spec: FinetuneSpec, + schedule_strategy: ScheduleStrategy, + **kwargs, + ): + """Initialize environments for distributed training and load models.""" + self.alloc_mode = AllocationMode.from_str(alloc_mode_str) + self.ft_spec = ft_spec + + # todo: group size is a sampling parameter and an attribute of the data, should be moved to DistributedBatch + self.group_size = kwargs.get("group_size", 1) + + job = Job( + replicas=self.alloc_mode.train.world_size, + tasks=self.train_engine.get_scheduling_config(), + schedule_strategy=schedule_strategy, + role=self.role, + ) + logger.info(f"Start to create job: {job}") + self.scheduler.create_workers(job) + # after get workers, all rpc server is ready + self.workers = self.scheduler.get_workers(self.role, timeout=1800) + + logger.info("Start to initialize engine") + with ThreadPoolExecutor(max_workers=len(self.workers)) as executor: + create_engine: Callable[..., Any] = partial( + create_engine_with_retry, + self.scheduler.create_engine, + 60, # max_retries + 10, # retry_delay + ) + futures: list[Future] = [ + executor.submit( + create_engine, + worker.id, + self.train_engine, + None, + self.ft_spec, + self.alloc_mode.train, + ) + for worker in self.workers + ] + try: + wait_future_ordered(futures, exit_on_exception=True) + except Exception as e: + logger.error(f"Failed to initialize engine: {e}") + raise + + logger.info("Start to get rank info from engine") + self.engine_dp_ranks = rpc_call( + self.scheduler, self.workers, "data_parallel_rank" + ) + logger.info("Initialize train engines succeeded!") + + def destroy(self): + self.scheduler.delete_workers() + + def train(self, mode: bool = True): + self.custom_function_call("train", mode) + + def upload_weights(self, meta: WeightUpdateMeta): + self.custom_function_call("upload_weights", meta) + + def get_param_specs( + self, weight_chunked_mem_mb: int = 1024 + ) -> list[list[ParamSpec]]: + ret: list[list[list[ParamSpec]]] = self.custom_function_call( + "get_param_specs", weight_chunked_mem_mb + ) + return ret[0] + + def set_version(self, version: int): + return self.custom_function_call("set_version", version) + + def get_version(self) -> int: + results = self.custom_function_call("get_version") + return results[0] + + def save(self, meta: SaveLoadMeta): + self.custom_function_call("save", meta) + + def load(self, meta: SaveLoadMeta): + self.custom_function_call("load", meta) + + def step_lr_scheduler(self): + self.custom_function_call("step_lr_scheduler") + + def custom_function_call(self, method: str, *args, **kwargs): + return rpc_call(self.scheduler, self.workers, method, None, *args, **kwargs) + + def custom_function_call_with_data( + self, method: str, input_: DistributedBatch, strict_order=False, *args, **kwargs + ): + if strict_order: + batches = self._align_batches_with_dp(input_, False) + else: + batches = self._align_batches_with_dp(input_, True) + + stats = rpc_call( + self.scheduler, + self.workers, + method, + batches, + *args, + **kwargs, + ) + return stats + + def _align_batches_with_dp( + self, input_: DistributedBatch, rebalance=True + ) -> list[DistributedBatch]: + if rebalance: + inputs = input_.chunk_by_ffd(self.group_size, self.alloc_mode.train.dp_size) + else: + inputs = input_.chunk(self.alloc_mode.train.dp_size) + + batches = [] + for dp_rank in self.engine_dp_ranks: + batches.append(inputs[dp_rank]) + + return batches + + def train_batch( + self, + input_: DistributedBatch, + loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + ) -> dict[str, float]: + stats = self.custom_function_call_with_data( + "train_batch", input_, False, loss_fn, loss_weight_fn + ) + + return stats[0] + + def ppo_update( + self, + input_: DistributedBatch, + ) -> dict[str, float]: + stats = self.custom_function_call_with_data("ppo_update", input_, False) + + return stats[0] + + def eval_batch( + self, + input_: DistributedBatch, + loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + ) -> torch.Tensor | None: + stats = self.custom_function_call_with_data( + "eval_batch", input_, False, loss_fn, loss_weight_fn + ) + + return stats[0] + + def compute_logp( + self, + input_: DistributedBatch, + *args, + **kwargs, + ): + logps = self.custom_function_call_with_data( + "compute_logp", input_, True, *args, **kwargs + ) + return logps + + def compute_advantages( + self, + input_: DistributedBatch, + *args, + **kwargs, + ): + advantages = self.custom_function_call_with_data( + "compute_advantages", input_, True, *args, **kwargs + ) + + return DistributedBatchMemory.concat(advantages) + + def train_lm( + self, + input_: DistributedBatch, + *args, + **kwargs, + ) -> dict[str, float]: + stats = self.custom_function_call_with_data( + "train_lm", input_, False, *args, **kwargs + ) + + return stats[0] + + def evaluate_lm( + self, + input_: DistributedBatch, + *args, + **kwargs, + ) -> torch.Tensor | None: + stats = self.custom_function_call_with_data( + "evaluate_lm", input_, False, *args, **kwargs + ) + return stats[0] + + def forward( + self, + input_: DistributedBatch, + output_seqlens: list[int] | None = None, + post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None, + aggregate_fn: Callable[[list[Any]], Any] = torch.cat, + ) -> list[Any]: + stats = self.custom_function_call_with_data( + "forward", input_, True, output_seqlens, post_hook, aggregate_fn + ) + return stats[0] diff --git a/areal/controller/utils.py b/areal/controller/utils.py new file mode 100644 index 000000000..77064ed38 --- /dev/null +++ b/areal/controller/utils.py @@ -0,0 +1,95 @@ +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from requests.exceptions import ConnectionError + +from areal.api.scheduler_api import Scheduler, Worker +from areal.utils import logging +from areal.utils.http import wait_future_ordered + +logger = logging.getLogger("ControllerUtil") + + +def create_engine_with_retry( + create_engine_func, max_retries=60, retry_delay=10, *args, **kwargs +): + """ + Attempts to create an engine with retry logic. + :param create_engine_func: Callable function for creating the engine. + :param max_retries: Maximum number of retries before giving up. + :param retry_delay: Seconds to wait between retries. + :param args: Positional arguments to pass to create_engine_func. + :param kwargs: Keyword arguments to pass to create_engine_func. + :return: Engine instance created by create_engine_func. + :raises RuntimeError: If maximum retries are reached and connection still fails. + """ + logger.info( + f"Create engine with retry: {max_retries}, {retry_delay}, {args}, {kwargs}" + ) + retries = 0 + while retries < max_retries: + try: + return create_engine_func(*args, **kwargs) + except (ConnectionError, OSError) as e: + logger.info( + f"Worker is not ready, exception: {e}, retrying in {retry_delay} seconds..." + ) + time.sleep(retry_delay) + retries += 1 + except Exception as e: + logger.error(f"Connection failed: {e}. unknown exception") + raise e + + raise RuntimeError("Failed to connect to remote service after maximum retries.") + + +def rpc_call( + scheduler: Scheduler, + workers: list[Worker], + method: str, + batches: list[Any] | None = None, + *args, + **kwargs, +) -> list[Any]: + """ + Utility method: Perform concurrent RPC calls to multiple workers. + :param scheduler: Scheduler object with a call_engine(worker_id, method, *args, **kwargs) method. + :param workers: List of worker instances. Each worker must have an 'id' attribute. + :param method: Name of the method to invoke on each worker. + :param batches: Optional list of batches, each batch is passed to the corresponding worker. + If provided, its length must match the number of workers. + :param args: Positional arguments to pass to call_engine. + :param kwargs: Keyword arguments to pass to call_engine. + :return: List of results returned in the order of workers. + :raises ValueError: If the batches parameter is provided but its length does not match the number of workers. + :raises RuntimeError: If any exception occurs during RPC execution. + """ + + if batches is not None and len(batches) != len(workers): + raise ValueError( + f"Batches length ({len(batches)}) must match workers count ({len(workers)})" + ) + logger.info( + f"Start to rpc call, method: {method}, batches: {batches}, args: {args}, kwargs: {kwargs}" + ) + + with ThreadPoolExecutor(max_workers=len(workers)) as executor: + futures = [] + for i, worker in enumerate(workers): + if batches is not None: + worker_args = (batches[i], *args) + future = executor.submit( + scheduler.call_engine, worker.id, method, *worker_args, **kwargs + ) + else: + future = executor.submit( + scheduler.call_engine, worker.id, method, *args, **kwargs + ) + futures.append(future) + try: + results = wait_future_ordered(futures, exit_on_exception=True) + except Exception as e: + raise RuntimeError(f"{method} failed, error: {e}") + + return results diff --git a/areal/scheduler/__init__.py b/areal/scheduler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py new file mode 100644 index 000000000..38ee0e64a --- /dev/null +++ b/areal/scheduler/local.py @@ -0,0 +1,184 @@ +import uuid +from collections import defaultdict + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import ( + BaseExperimentConfig, + ClusterSpecConfig, + LauncherConfig, + RecoverConfig, + SGLangConfig, + to_structured_cfg, +) +from areal.api.scheduler_api import Job, Scheduler, Worker +from areal.launcher.local import LocalLauncher +from areal.scheduler.rpc.rpc_client import RPCClient +from areal.utils import logging, name_resolve, names +from areal.utils.launcher import get_env_vars +from areal.utils.network import find_free_ports, gethostip +from areal.utils.recover import check_if_recover + +logger = logging.getLogger("LocalScheduler") + + +class LocalScheduler(Scheduler): + def __init__(self, config: BaseExperimentConfig): + self.procs = [] # Store subprocess objects + self.engine_workers: dict[str, list[str]] = defaultdict( + list + ) # role -> [worker_id] + self.rpc_client = RPCClient() + self.launcher = LocalLauncher( + config.experiment_name, config.trial_name, config.cluster.fileroot + ) + self.config = config + + def create_workers(self, job: Job, *args, **kwargs) -> None: + config = kwargs.get("config") + if job.role == "rollout": + self._create_rollout_workers(job, config) + return None + + replicas = job.replicas + master_port = find_free_ports(1, port_range=(10000, 50000))[0] + + for index in range(replicas): + for task in job.tasks: + ports = find_free_ports(task.port_count, port_range=(10000, 50000)) + envs = get_env_vars( + self.config.cluster.cluster_name, + ) + extra_envs = task.env_vars if task.env_vars else {} + extra_envs["PORT_LIST"] = ",".join(map(str, ports)) + envs.update(extra_envs) + + if job.role != "rollout": + # For non-rollout workers, set RANK and WORLD_SIZE + envs.update( + { + "RANK": index, + "LOCAL_RANK": 0, + "WORLD_SIZE": replicas, + "MASTER_ADDR": "localhost", + "MASTER_PORT": master_port, + "NCCL_CUMEM_ENABLE": "0", + "NCCL_NVLS_ENABLE": "0", + } + ) + self.launcher.submit( + job_name=f"{job.role}_worker", + cmd=f"{task.cmd} --role {job.role} --index {index}", + gpu=task.gpu, + env_vars=envs, + ) + + if task.type == "worker": + worker_id = f"worker_{uuid.uuid4().hex[:8]}" + self.rpc_client.register(worker_id, "localhost", ports[0]) + self.engine_workers.setdefault(job.role, []).append(worker_id) + + logger.info(f"Submitted {job.replicas} tasks for command: {task.cmd}") + return None + + def _create_rollout_workers(self, job: Job, config): + config.launcher = to_structured_cfg(config.launcher, LauncherConfig) + config.recover = to_structured_cfg(config.recover, RecoverConfig) + config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig) + is_recover_run = check_if_recover(config.recover, run_id=0) + name_resolve.reconfigure(config.cluster.name_resolve) + name_resolve.clear_subtree( + names.trial_root( + experiment_name=config.experiment_name, trial_name=config.trial_name + ) + ) + alloc_mode = AllocationMode.from_str(config.allocation_mode) + + if job.role == "rollout": + if alloc_mode.gen_backend == "sglang": + server_cmd = [] + server_addrs = [] + base_seed = config.sglang.random_seed + config.sglang = to_structured_cfg(config.sglang, SGLangConfig) + # each sglang need 2 ports + ports = find_free_ports( + alloc_mode.gen.dp_size * 2, port_range=(10000, 50000) + ) + host_ip = gethostip() + host = "localhost" if not config.sglang.enable_metrics else host_ip + for i in range(alloc_mode.gen.dp_size): + config.sglang.random_seed = base_seed + i + cmd = SGLangConfig.build_cmd( + config.sglang, + host=host, + tp_size=alloc_mode.gen.tp_size, + base_gpu_id=0, + port=ports[i * 2], + dist_init_addr=f"localhost:{ports[i * 2 + 1]}", + ) + server_cmd.append(cmd) + server_addrs.append(f"{host}:{ports[i * 2]}") + + # Launch inference servers. + self.launcher.submit_array( + job_name="rollout_server", + cmd=server_cmd, + count=alloc_mode.gen.dp_size, + gpu=alloc_mode.gen.pp_size * alloc_mode.gen.tp_size, + env_vars=get_env_vars( + config.cluster.cluster_name, + config.launcher.inference_server_env_vars, + ), + ) + logger.info( + f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}" + ) + + task = next((task for task in job.tasks if task.type == "worker"), None) + if task is not None: + for i in range(job.replicas): + ports = find_free_ports(task.port_count, port_range=(10000, 50000)) + envs = get_env_vars( + config.cluster.cluster_name, + ) + extra_envs = task.env_vars if task.env_vars else {} + extra_envs["PORT_LIST"] = ",".join(map(str, ports)) + extra_envs["AREAL_LLM_SERVER_ADDRS"] = ",".join(server_addrs) + extra_envs["AREAL_RECOVER_RUN"] = str(int(is_recover_run)) + envs.update(extra_envs) + self.launcher.submit( + job_name="rollout_worker", + cmd=f"{task.cmd} --role {job.role} --index {i}", + gpu=task.gpu, + env_vars=envs, + ) + worker_id = f"worker_{uuid.uuid4().hex[:8]}" + self.rpc_client.register(worker_id, "localhost", ports[0]) + self.engine_workers.setdefault(job.role, []).append(worker_id) + logger.info(f"Submitted {job.replicas} tasks for command: {task.cmd}") + + def get_workers(self, worker_role, timeout: float = 60.0) -> list[Worker]: + workers = [] + for worker_id in self.engine_workers.get(worker_role, []): + if not self.rpc_client.check_health(worker_id, timeout): + raise TimeoutError(f"Worker {worker_id} check health timeout") + ip, port = self.rpc_client.get_info(worker_id) + worker = Worker(id=worker_id, ip=ip, worker_ports=[str(port)]) + workers.append(worker) + return workers + + def delete_workers(self): + # TODO: Implement proper worker cleanup. This might involve calling a method + # on self.launcher to terminate the processes it has started. + # For now, this is a no-op to prevent crashes. + logger.warning( + "LocalScheduler.delete_workers is not implemented and is a no-op. Worker processes might not be cleaned up properly." + ) + + # Other methods remain the same + def create_engine(self, worker_id, engine_obj, *args, **kwargs): + # launch engine rpc server on the worker + self.rpc_client.create_engine(worker_id, engine_obj, *args, **kwargs) + + def call_engine(self, worker_id, method, *args, **kwargs): + ret = self.rpc_client.call_engine(worker_id, method, 3, *args, **kwargs) + return ret diff --git a/areal/utils/scheduler.py b/areal/utils/scheduler.py new file mode 100644 index 000000000..9fff0b4d9 --- /dev/null +++ b/areal/utils/scheduler.py @@ -0,0 +1,19 @@ +from areal.api.cli_args import SchedulingSpec +from areal.api.engine_api import Scheduling + + +def scheduling_specs_to_schedulings(specs: list[SchedulingSpec]) -> list[Scheduling]: + result = [] + for spec in specs: + sch = Scheduling( + cpu=spec.cpu, + gpu=spec.gpu, + mem=spec.mem, + port_count=spec.port_count, + container_image=spec.image, + type=spec.type, + env_vars=spec.env_vars, + cmd=spec.cmd, + ) + result.append(sch) + return result diff --git a/examples/single-controller/gsm8k_sft.py b/examples/single-controller/gsm8k_sft.py new file mode 100644 index 000000000..066dc576f --- /dev/null +++ b/examples/single-controller/gsm8k_sft.py @@ -0,0 +1,185 @@ +import sys + +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import SFTConfig, load_expr_config +from areal.api.io_struct import FinetuneSpec, StepInfo +from areal.api.scheduler_api import ScheduleStrategy +from areal.controller.batch import DistributedBatchMemory +from areal.controller.train_controller import DistributedTrainController +from areal.dataset import get_custom_dataset +from areal.engine.sft.lm_engine import FSDPLMEngine +from areal.scheduler.local import LocalScheduler +from areal.utils import logging, stats_tracker +from areal.utils.data import ( + pad_sequences_to_tensors, + tensor_container_to, +) +from areal.utils.evaluator import Evaluator +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.recover import RecoverHandler +from areal.utils.saver import Saver +from areal.utils.stats_logger import StatsLogger + +logger = logging.getLogger("Trainer") + + +def main(args): + config, _ = load_expr_config(args, SFTConfig) + config: SFTConfig + + AllocationMode.from_str(config.allocation_mode) + + engine = FSDPLMEngine(config=config.model) + + tokenizer = load_hf_tokenizer(config.tokenizer_path) + train_dataset = get_custom_dataset( + path=config.train_dataset.path, + rank=0, + world_size=1, + split="train", + max_length=config.train_dataset.max_length, + type=config.train_dataset.type, + tokenizer=tokenizer, + ) + valid_dataset = get_custom_dataset( + path=config.valid_dataset.path, + rank=0, + world_size=1, + split="test", + max_length=config.valid_dataset.max_length, + type=config.valid_dataset.type, + tokenizer=tokenizer, + ) + + # Create dataset and dataloaders + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size, + shuffle=config.train_dataset.shuffle, + num_workers=config.train_dataset.num_workers, + collate_fn=pad_sequences_to_tensors, + drop_last=config.train_dataset.drop_last, + ) + valid_dataloader = StatefulDataLoader( + valid_dataset, + batch_size=config.valid_dataset.batch_size, + shuffle=config.valid_dataset.shuffle, + num_workers=config.valid_dataset.num_workers, + collate_fn=pad_sequences_to_tensors, + drop_last=config.valid_dataset.drop_last, + ) + + # Initialize engine + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + # Initialize scheduler + scheduler = LocalScheduler(config) + # Initialize train controller + train_controller = DistributedTrainController(engine, config.model, scheduler) + train_controller.initialize( + config.allocation_mode, + ft_spec, + ScheduleStrategy(), + ) + + # Run training. + saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + recover_info = recover_handler.load( + engine, + saver, + evaluator, + stats_logger, + train_dataloader, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + + global_step = 0 + for epoch in range(total_epochs): + for step, data in enumerate(train_dataloader): + if global_step < start_step: + global_step += 1 + continue + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=len(train_dataloader), + ) + + with stats_tracker.record_timing("to_device"): + data = tensor_container_to(data, "cpu") + data = DistributedBatchMemory.from_dict(data) + + with ( + stats_tracker.record_timing("train_step"), + stats_tracker.scope("sft"), + ): + stat = train_controller.train_lm(data) + train_controller.step_lr_scheduler() + logger.info(f"train stat: {stat}") + + with stats_tracker.record_timing("save"): + saver.save( + train_controller, epoch, step, global_step, tokenizer=tokenizer + ) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + engine, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + with stats_tracker.record_timing("eval"): + + def evaluate_fn(): + with stats_tracker.scope("sft-eval"): + for data in valid_dataloader: + data = tensor_container_to(data, "cpu") + data = DistributedBatchMemory.from_dict(data) + train_controller.evaluate_lm(data) + + evaluator.evaluate( + evaluate_fn, + epoch, + step, + global_step, + ) + + stats = list() + # todo: gather stats from all ranks + stats.append(stat) + stats.append(stats_tracker.export_all()) + stats_logger.commit( + epoch, + step, + global_step, + stats, + ) + global_step += 1 + + stats_logger.close() + train_controller.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/single-controller/gsm8k_sft.yaml b/examples/single-controller/gsm8k_sft.yaml new file mode 100644 index 000000000..f28d65f2a --- /dev/null +++ b/examples/single-controller/gsm8k_sft.yaml @@ -0,0 +1,96 @@ +experiment_name: gsm8k-sft +trial_name: trial0 + +seed: 1 +total_train_epochs: 1 +tokenizer_path: ${model.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: d8p1t1 + +model: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-1.7B + init_from_scratch: false + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 4096 + optimizer: + type: adam + lr: 2e-5 + weight_decay: 0.05 + beta1: 0.9 + beta2: 0.95 + eps: 1e-5 + lr_scheduler_type: cosine + gradient_clipping: 1.0 + backend: fsdp + scheduling_specs: + - type: worker + port_count: 1 + gpu: 1 + cmd: python3 -m areal.scheduler.rpc.rpc_server + +train_dataset: + batch_size: 128 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: sft + +valid_dataset: + batch_size: 128 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: sft + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +launcher: + inference_server_cpus_per_gpu: 4 + inference_server_mem_per_gpu: 32768 + trainer_cpus_per_gpu: 4 + trainer_mem_per_gpu: 32768 From b4c4eb654445791210f73c1588b8b7f78b107625 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 17:53:23 +0800 Subject: [PATCH 16/52] refactor train controller --- areal/api/cli_args.py | 33 +- areal/api/controller_api.py | 258 ------- areal/api/engine_api.py | 46 +- areal/api/scheduler_api.py | 15 +- areal/controller/rollout_controller.py | 34 +- areal/controller/train_controller.py | 758 +++++++++++++++---- areal/scheduler/local_scheduler.py | 85 +-- areal/tests/test_local_scheduler.py | 142 ++-- areal/tests/test_rollout_controller.py | 8 +- areal/utils/recover.py | 36 +- areal/utils/saver.py | 3 +- areal/utils/scheduler.py | 19 - docs/cli_reference.md | 80 +- realhf/api/core/system_api.py | 6 +- realhf/apps/main.py | 2 +- realhf/experiments/async_exp/async_rl_exp.py | 12 +- realhf/experiments/common/common.py | 6 +- realhf/system/controller.py | 6 +- 18 files changed, 869 insertions(+), 680 deletions(-) delete mode 100644 areal/utils/scheduler.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 2def9bc62..18aa2cb52 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -3,7 +3,7 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Literal import uvloop import yaml @@ -312,6 +312,12 @@ class MegatronEngineConfig: recompute_modules: list[str] | None = None +@dataclass +class ScheduleStrategy: + type: Literal["colocation", "separation"] = "separation" + target: str = "" + + @dataclass class SchedulingSpec: cpu: int = field(default=0, metadata={"help": "Number of CPU cores required"}) @@ -332,9 +338,20 @@ class SchedulingSpec: default_factory=dict, metadata={"help": "Environment variables for the container"}, ) - cmd: str = field( - default="", metadata={"help": "Command to execute inside the container"} + # cmd + cmd: str | None = field( + default=None, + metadata={ + "help": "Command to execute inside the container. Defaults to AReaL's RPC server." + }, ) + # slurm configurations from "https://slurm.schedmd.com/sbatch.html" + nodelist: str | None = None + exclude: str | None = None + partition: str | None = None + time_limit: str | None = None # see "--time" option for format + begin: str | None = None # see "--begin" option for format + deadline: str | None = None # see "--deadline" option for format @dataclass @@ -410,10 +427,11 @@ class TrainEngineConfig: default="lora", metadata={"help": "peft method type. Only LoRA is supported for now."}, ) - scheduling_specs: list[SchedulingSpec] = field( - default_factory=list, + scheduling_spec: SchedulingSpec = field( + default_factory=SchedulingSpec, metadata={"help": "train engine schedule specs"}, ) + scheduling_strategy: ScheduleStrategy = field(default_factory=ScheduleStrategy) @dataclass @@ -882,10 +900,11 @@ class InferenceEngineConfig: "help": "The grace period after calling /pause_generation. Wait until all requests have been dropped." }, ) - scheduling_specs: list[SchedulingSpec] = field( - default_factory=list, + scheduling_spec: SchedulingSpec = field( + default_factory=SchedulingSpec, metadata={"help": "inference engine schedule specs"}, ) + scheduling_strategy: ScheduleStrategy = field(default_factory=ScheduleStrategy) @dataclass diff --git a/areal/api/controller_api.py b/areal/api/controller_api.py index 0bf02b10a..efe1c0e64 100644 --- a/areal/api/controller_api.py +++ b/areal/api/controller_api.py @@ -1,18 +1,8 @@ import abc -from collections.abc import Callable from typing import Any import torch -from areal.api.alloc_mode import ParallelStrategy -from areal.api.cli_args import TrainEngineConfig -from areal.api.engine_api import InferenceEngine, TrainEngine -from areal.api.io_struct import ( - SaveLoadMeta, - WeightUpdateMeta, -) -from areal.api.scheduler_api import Scheduler - class DistributedBatch(abc.ABC): """Abstract base class for data exchange between controller and engine. @@ -192,251 +182,3 @@ def __setstate__(self, state): Dictionary containing the serialized state """ raise NotImplementedError() - - -class TrainController(abc.ABC): - """A centralized controller that manages multiple distributed TrainEngine workers. - - TrainController serves as a high-level orchestrator for distributed training across - multiple concurrent workers, each running TrainEngine instances. It provides a - unified interface for coordinating training operations while abstracting away the - complexities of inter-worker communication and data distribution. - - Key differences from TrainEngine: - - Operates at a higher abstraction level, managing multiple engine instances - - Does not directly perform collective communications (no rank and process group APIs) - - Uses `DistributedBatch` for data that spans multiple workers - - Provides centralized coordination for distributed training workflows - - The controller handles workload distribution, synchronization, and aggregation - of results from the underlying TrainEngine workers, enabling scalable and - efficient distributed training. - """ - - def __init__( - self, - train_engine: TrainEngine, - config: TrainEngineConfig, - scheduler: Scheduler, - ): - self.train_engine = train_engine - self.config = config - self.scheduler = scheduler - - def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): - """Initialize PyTorch distributed communication groups. - - Parameters - ---------- - parallel_strategy : ParallelStrategy, optional - The parallel strategy configuration for distributed training, by default None - """ - raise NotImplementedError() - - def initialize(self, *args, **kwargs): - """Initialize environments for distributed training and load models. - - This method should be called after `create_process_group`. - - Parameters - ---------- - *args - Variable length argument list - **kwargs - Arbitrary keyword arguments - """ - raise NotImplementedError() - - def destroy(self): - """Destroy the engine and release GPU memory of models.""" - raise NotImplementedError() - - def train(self, mode: bool = True): - """Set the engine to training mode. - - Parameters - ---------- - mode : bool, optional - Whether to set the engine to training mode, by default True - """ - raise NotImplementedError() - - def eval(self): - """Set the engine to evaluation mode. - - This is a convenience method that calls `self.train(False)`. - """ - return self.train(False) - - def update_weights(self, meta: WeightUpdateMeta): - """Update weights to the inference engine in a blocking manner. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - """ - raise NotImplementedError() - - def connect_engine(self, engine: "InferenceEngine", meta: WeightUpdateMeta): - """Connect to an inference engine for online training. - - Parameters - ---------- - engine : InferenceEngine - The inference engine to connect to - """ - raise NotImplementedError() - - def set_version(self, version: int): - """Set the current weight version in the training engine. - - Parameters - ---------- - version : int - The weight version number to set - """ - raise NotImplementedError() - - def get_version(self) -> int: - """Get the current weight version in the training engine. - - Returns - ------- - int - The current weight version number - """ - raise NotImplementedError() - - def save(self, meta: SaveLoadMeta): - """Save model weights and optimizer states for later use. - - Parameters - ---------- - meta : SaveLoadMeta - Metadata containing information about where and how to save - """ - raise NotImplementedError() - - def load(self, meta: SaveLoadMeta): - """Load model weights and optimizer states from a file. - - Parameters - ---------- - meta : SaveLoadMeta - Metadata containing information about where and how to load - """ - raise NotImplementedError() - - def step_lr_scheduler(self): - """Step the learning rate scheduler. - - Since PPO uses minibatch updates, this method should be called periodically - (e.g., once per PPO step). It is separated from train_batch to allow - for more flexible learning rate scheduling. - """ - raise NotImplementedError() - - def train_batch( - self, - input_: DistributedBatch, - loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], - loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], - ) -> dict[str, float]: - """Update the model with a batch of data and a loss function. - - Note - ---- - The loss_fn should process packed 1D inputs, instead of 2D inputs. - - Parameters - ---------- - input_ : DistributedBatch - The distributed input data for model forward pass and the loss function. - Redundant entries are allowed. - loss_fn : Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor] - The loss function that takes the model's forward output and input_, - and outputs a scalar normalized loss. - loss_weight_fn : Callable[[Dict[str, Any]], torch.Tensor] - A function used to calculate the weight of each micro-batch. Since - loss_fn normalizes the loss for a micro-batch, we need a corresponding - weight for each micro-batch to normalize the loss globally. The weight - is usually the number of response tokens in the batch. - - Returns - ------- - Dict[str, float] - Scalar statistics after training, e.g., the current learning rate, - gradient norm, etc. - """ - raise NotImplementedError() - - @torch.no_grad() - def eval_batch( - self, - input_: DistributedBatch, - loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], - loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], - ) -> torch.Tensor | None: - """Evaluate the model using the forward pass and loss function. - - Note - ---- - The loss_fn should process packed 1D inputs, instead of 2D inputs. - - Parameters - ---------- - input_ : DistributedBatch - The distributed input data for model forward pass and the loss function. - Redundant entries are allowed. - loss_fn : Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor] - The loss function that takes the model's forward output and input_, - and outputs a scalar normalized loss. - loss_weight_fn : Callable[[Dict[str, Any]], torch.Tensor] - A function used to calculate the weight of each micro-batch. Since - loss_fn normalizes the loss for a micro-batch, we need a corresponding - weight for each micro-batch to normalize the loss globally. The weight - is usually the number of response tokens in the batch. - - Returns - ------- - torch.Tensor or None - A scalar loss or None. The evaluation statistics should be aggregated - with `stats_tracker`. - """ - raise NotImplementedError() - - @torch.no_grad() - def forward( - self, - input_: DistributedBatch, - output_seqlens: list[int] | None = None, - post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None, - aggregate_fn: Callable[[list[Any]], Any] = torch.cat, - ) -> Any | None: - """Run the forward pass or inference on the model. - - Note - ---- - This operation is gradient-free. - - Parameters - ---------- - input_ : DistributedBatch - The distributed input data for model forward pass. Redundant entries are allowed. - output_seqlens : List[int], optional - The desired output sequence lengths. If None, assumes that the output - has the same lengths as inputs, by default None. - post_hook : Callable[[torch.Tensor, Dict[str, Any]], Any], optional - The post-processing function for micro-batched outputs. Post-processing - the output on-the-fly during micro-batched forward can reduce peak - memory usage, by default None. - aggregate_fn : Callable[[List[Any]], Any], optional - A function to aggregate micro-batched outputs, by default torch.cat. - - Returns - ------- - Any or None - The result produced by `post_hook` and `aggregate_fn`. - """ - raise NotImplementedError() diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 322e83dd2..82575906e 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -1,8 +1,7 @@ import abc from collections.abc import Callable from concurrent.futures import Future -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Optional import torch import torch.distributed as dist @@ -21,25 +20,6 @@ from areal.api.workflow_api import RolloutWorkflow -@dataclass -class Scheduling: - cpu: int - gpu: int - mem: int - port_count: int - cmd: str | None = None - nodelist: str | None = None - exclude: str | None = None - partition: str | None = None - container_image: str | None = None - type: Literal["worker", "engine"] = None - env_vars: dict[str, str] = field(default_factory=dict) - # time utils from "https://slurm.schedmd.com/sbatch.html" - time_limit: str | None = None # see "--time" option for format - begin: str | None = None # see "--begin" option for format - deadline: str | None = None # see "--deadline" option for format - - class TrainEngine(abc.ABC): def configure(self, config): raise NotImplementedError() @@ -146,18 +126,6 @@ def parallelism_group(self) -> dist.ProcessGroup: """ raise NotImplementedError() - def get_scheduling_config(self) -> list[Scheduling]: - """Get the scheduling configuration for the engine. - - This includes configuration such as container image, CPU/GPU/memory size. - - Returns - ------- - Scheduling - The scheduling configuration for the engine - """ - raise NotImplementedError() - def destroy(self): """Destroy the engine and release GPU memory of models.""" @@ -605,15 +573,3 @@ def pause(self): def resume(self): """Resume request submission for async rollout.""" raise NotImplementedError() - - def get_scheduling_config(self) -> list[Scheduling]: - """Get the scheduling configuration for the engine. - - This includes configuration such as container image, CPU/GPU/memory size. - - Returns - ------- - List[Scheduling] - A list of scheduling configurations for the engine - """ - raise NotImplementedError() diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index 93b4bb77d..cd7685f94 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -1,8 +1,8 @@ import abc from dataclasses import dataclass, field -from typing import Any, Literal +from typing import Any -from areal.api.engine_api import Scheduling +from areal.api.cli_args import ScheduleStrategy, SchedulingSpec @dataclass @@ -23,16 +23,10 @@ class Worker: engine_ports: list[str] = field(default_factory=list) -@dataclass -class ScheduleStrategy: - type: Literal["colocation", "separation"] = "separation" - target: str = "" - - @dataclass class Job: replicas: int = 0 - tasks: list[Scheduling] = field(default_factory=list) + tasks: list[SchedulingSpec] = field(default_factory=list) schedule_strategy: ScheduleStrategy | None = None role: str = "" @@ -49,12 +43,11 @@ class Scheduler(abc.ABC): """ @abc.abstractmethod - def create_workers(self, role: str, job: Job, *args, **kwargs) -> list[str]: + def create_workers(self, job: Job, *args, **kwargs) -> list[str]: """ Create and start worker processes for a specific role. Args: - role: Role name for this group of workers (e.g., "rollout", "actor", "critic"). scheduler_config: Configuration specifying replicas, resources, and scheduling strategy. *args: Additional positional arguments (implementation-specific). **kwargs: Additional keyword arguments (implementation-specific). diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index fcc9db72d..31101643e 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -18,7 +18,7 @@ from areal.api.controller_api import DistributedBatch from areal.api.engine_api import InferenceEngine from areal.api.io_struct import ModelRequest, ModelResponse, ParamSpec, WeightUpdateMeta -from areal.api.scheduler_api import Scheduler, SchedulingConfig, Worker +from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker from areal.controller.batch import DistributedBatchMemory from areal.core.async_task_runner import AsyncTaskRunner, TaskQueueFullError from areal.core.staleness_manager import StalenessManager @@ -95,7 +95,7 @@ def __init__( # Worker management self.workers: list[Worker] = [] # List of Worker objects from scheduler - self._worker_role = "rollout" # Role name for workers + self._worker_role: str # Round-robin scheduling self._current_worker_idx = 0 @@ -120,7 +120,9 @@ def __init__( def initialize( self, + role: str, alloc_mode: AllocationMode, + schedule_strategy: ScheduleStrategy | None = None, *args, **kwargs, ): @@ -141,13 +143,21 @@ def initialize( self.logger = logging.getLogger("[RolloutController]") # Get scheduling config from kwargs or use defaults - # FIXME: Should get scheduling config in a more strategical way - scheduling_config = SchedulingConfig(replicas=alloc_mode.gen.dp_size) + self._worker_role = role + self.config.scheduling_spec.cpu *= alloc_mode.gen_instance_size + self.config.scheduling_spec.mem *= alloc_mode.gen_instance_size + self.config.scheduling_spec.gpu = alloc_mode.gen_instance_size + job = Job( + replicas=alloc_mode.gen.dp_size, + tasks=[self.config.scheduling_spec for _ in range(alloc_mode.gen.dp_size)], + schedule_strategy=schedule_strategy, + role=self._worker_role, + ) # Use asyncio.run to call async scheduler methods synchronously asyncio.run( self._async_initialize( - scheduling_config, + job, *args, **kwargs, ) @@ -175,23 +185,15 @@ def initialize( max_staleness=self.config.max_head_offpolicyness, ) - async def _async_initialize( - self, scheduling_config: SchedulingConfig, *args, **kwargs - ): + async def _async_initialize(self, job: Job, *args, **kwargs): # Create workers via scheduler self.logger.info("Creating workers via scheduler...") - worker_ids = self.scheduler.create_workers( - role=self._worker_role, - scheduler_config=scheduling_config, - ) + worker_ids = self.scheduler.create_workers(job=job) self.logger.info(f"Workers created: {worker_ids}") # Wait for workers to be ready self.logger.info("Waiting for workers to be ready...") - self.workers = self.scheduler.get_workers( - role=self._worker_role, - timeout=CREATE_WORKER_TIMEOUT, - ) + self.workers = self.scheduler.get_workers(role=job.role) self.logger.info(f"Workers ready: {[w.id for w in self.workers]}") # Get engine class path for dynamic import on workers diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index c482e7c39..7054a7acd 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -1,189 +1,575 @@ +import asyncio from collections.abc import Callable -from concurrent.futures import Future, ThreadPoolExecutor -from functools import partial from typing import Any import torch from areal.api.alloc_mode import ParallelStrategy from areal.api.cli_args import TrainEngineConfig -from areal.api.controller_api import DistributedBatch, TrainController -from areal.api.engine_api import TrainEngine +from areal.api.controller_api import DistributedBatch +from areal.api.engine_api import InferenceEngine, TrainEngine from areal.api.io_struct import ( AllocationMode, FinetuneSpec, - ParamSpec, SaveLoadMeta, WeightUpdateMeta, ) from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker from areal.controller.batch import DistributedBatchMemory -from areal.controller.utils import create_engine_with_retry, rpc_call from areal.utils import logging -from areal.utils.http import wait_future_ordered -logger = logging.getLogger("DistributedTrainController") +logger = logging.getLogger("TrainController") -class DistributedTrainController(TrainController): +class TrainController: + """A centralized controller that manages multiple distributed TrainEngine workers. + + TrainController serves as a high-level orchestrator for distributed training across + multiple concurrent workers, each running TrainEngine instances. It provides a + unified interface for coordinating training operations while abstracting away the + complexities of inter-worker communication and data distribution. + + Key differences from TrainEngine: + - Operates at a higher abstraction level, managing multiple engine instances + - Does not directly perform collective communications (no rank and process group APIs) + - Uses `DistributedBatch` for data that spans multiple workers + - Provides centralized coordination for distributed training workflows + + The controller handles workload distribution, synchronization, and aggregation + of results from the underlying TrainEngine workers, enabling scalable and + efficient distributed training. + + Parameters + ---------- + train_engine : type[TrainEngine] + The engine class (not instance) to instantiate on each worker + config : TrainEngineConfig + Configuration for training engines + scheduler : Scheduler + Scheduler for worker management + """ + def __init__( - self, train_engine: TrainEngine, config: TrainEngineConfig, scheduler: Scheduler + self, + train_engine: type[TrainEngine], + config: TrainEngineConfig, + scheduler: Scheduler, ): - super().__init__(train_engine, config, scheduler) + self.train_engine = train_engine + self.config = config + self.scheduler = scheduler - self.role: str = "train" self.group_size: int self.alloc_mode: AllocationMode - self.workers: list[Worker] - self.engine_dp_ranks: list[int] + self.workers: list[Worker] = [] + self.dp_head_workers: list[Worker] = [] # Only DP head workers + self.engine_dp_ranks: list[int] = [] # DP rank of each DP head worker + + self._worker_role: str + self.logger = None def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): + """Initialize PyTorch distributed communication groups. + + Parameters + ---------- + parallel_strategy : ParallelStrategy, optional + The parallel strategy configuration for distributed training, by default None + """ assert self.workers is not None, "Workers are not created" self.custom_function_call("create_process_group", parallel_strategy) def initialize( self, - alloc_mode_str: str, + role: str, + alloc_mode: AllocationMode, ft_spec: FinetuneSpec, schedule_strategy: ScheduleStrategy, **kwargs, ): - """Initialize environments for distributed training and load models.""" - self.alloc_mode = AllocationMode.from_str(alloc_mode_str) - self.ft_spec = ft_spec - + """Initialize environments for distributed training and load models. + + This method should be called after `create_process_group`. + + Parameters + ---------- + role : str + Role identifier for the workers + alloc_mode : AllocationMode + Allocation mode configuration for distributed setup + ft_spec : FinetuneSpec + Finetune specification for model initialization + schedule_strategy : ScheduleStrategy + Strategy for scheduling workers + **kwargs + Additional keyword arguments passed to engine initialization + """ + self.logger = logging.getLogger("[TrainController]") + + # Store configuration + self._worker_role = role + self.alloc_mode = alloc_mode # todo: group size is a sampling parameter and an attribute of the data, should be moved to DistributedBatch self.group_size = kwargs.get("group_size", 1) + # Create job for scheduler job = Job( - replicas=self.alloc_mode.train.world_size, - tasks=self.train_engine.get_scheduling_config(), + replicas=alloc_mode.train.world_size, + tasks=[ + self.config.scheduling_spec for _ in range(alloc_mode.train.world_size) + ], schedule_strategy=schedule_strategy, - role=self.role, + role=self._worker_role, ) - logger.info(f"Start to create job: {job}") - self.scheduler.create_workers(job) - # after get workers, all rpc server is ready - self.workers = self.scheduler.get_workers(self.role, timeout=1800) - - logger.info("Start to initialize engine") - with ThreadPoolExecutor(max_workers=len(self.workers)) as executor: - create_engine: Callable[..., Any] = partial( - create_engine_with_retry, - self.scheduler.create_engine, - 60, # max_retries - 10, # retry_delay + + # Create workers via scheduler + self.logger.info("Creating workers via scheduler...") + worker_ids = self.scheduler.create_workers(job=job) + self.logger.info(f"Workers created: {worker_ids}") + + # Wait for workers to be ready + self.logger.info("Waiting for workers to be ready...") + self.workers = self.scheduler.get_workers(role=job.role) + self.logger.info(f"Workers ready: {[w.id for w in self.workers]}") + + # Get engine class path for dynamic import on workers + engine_class = self.train_engine + engine_path = f"{engine_class.__module__}.{engine_class.__name__}" + + # Create and initialize engines on workers + asyncio.run( + self._async_create_and_initialize_engines(engine_path, ft_spec, **kwargs) + ) + + # Identify DP head workers + self._identify_dp_heads() + + self.logger.info("TrainController initialization complete") + + async def _async_create_and_initialize_engines( + self, engine_path: str, ft_spec: FinetuneSpec, **kwargs + ): + """Create and initialize engines on all workers.""" + # Create engines on workers + self.logger.info("Creating engines on workers...") + tasks = [ + self.scheduler.create_engine( + worker_id=worker.id, + engine=engine_path, + config=self.config, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + self.logger.info("Engines created on all workers!") + + # Initialize engines + self.logger.info("Calling engine initialization...") + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="initialize", + addr=None, + ft_spec=ft_spec, + **kwargs, ) - futures: list[Future] = [ - executor.submit( - create_engine, - worker.id, - self.train_engine, - None, - self.ft_spec, - self.alloc_mode.train, + for worker in self.workers + ] + await asyncio.gather(*tasks) + self.logger.info("All engines are initialized!") + + def _identify_dp_heads(self): + """Identify which workers are DP heads by querying their DP rank.""" + self.logger.info("Identifying DP head workers...") + + # Query all workers for their DP rank + async def _get_dp_ranks(): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, method="data_parallel_rank" ) for worker in self.workers ] - try: - wait_future_ordered(futures, exit_on_exception=True) - except Exception as e: - logger.error(f"Failed to initialize engine: {e}") - raise - - logger.info("Start to get rank info from engine") - self.engine_dp_ranks = rpc_call( - self.scheduler, self.workers, "data_parallel_rank" - ) - logger.info("Initialize train engines succeeded!") + return await asyncio.gather(*tasks) - def destroy(self): - self.scheduler.delete_workers() + dp_ranks = asyncio.run(_get_dp_ranks()) - def train(self, mode: bool = True): - self.custom_function_call("train", mode) + # Find unique DP ranks and corresponding head workers + seen_dp_ranks = set() + self.dp_head_workers = [] + self.engine_dp_ranks = [] - def upload_weights(self, meta: WeightUpdateMeta): - self.custom_function_call("upload_weights", meta) + for worker, dp_rank in zip(self.workers, dp_ranks): + if dp_rank not in seen_dp_ranks: + self.dp_head_workers.append(worker) + self.engine_dp_ranks.append(dp_rank) + seen_dp_ranks.add(dp_rank) - def get_param_specs( - self, weight_chunked_mem_mb: int = 1024 - ) -> list[list[ParamSpec]]: - ret: list[list[list[ParamSpec]]] = self.custom_function_call( - "get_param_specs", weight_chunked_mem_mb + self.logger.info( + f"Identified {len(self.dp_head_workers)} DP head workers " + f"from {len(self.workers)} total workers. " + f"DP ranks: {self.engine_dp_ranks}" ) - return ret[0] - def set_version(self, version: int): - return self.custom_function_call("set_version", version) + def destroy(self): + """Destroy the controller and release GPU memory of models. - def get_version(self) -> int: - results = self.custom_function_call("get_version") - return results[0] + Cleans up all resources including workers, engines, and internal state. + """ + self.logger.info("Destroying TrainController...") - def save(self, meta: SaveLoadMeta): - self.custom_function_call("save", meta) + # Delete workers via scheduler + try: + self.scheduler.delete_workers(role=self._worker_role) + self.logger.info("Workers deleted") + except Exception as e: + self.logger.error(f"Error deleting workers: {e}") - def load(self, meta: SaveLoadMeta): - self.custom_function_call("load", meta) + # Clear worker lists + self.workers.clear() + self.dp_head_workers.clear() + self.engine_dp_ranks.clear() - def step_lr_scheduler(self): - self.custom_function_call("step_lr_scheduler") + self.logger.info("TrainController destroyed") def custom_function_call(self, method: str, *args, **kwargs): - return rpc_call(self.scheduler, self.workers, method, None, *args, **kwargs) + """Dispatch method call to appropriate workers based on input type. + + If any argument is a DistributedBatch, split data and call only DP heads. + Otherwise, call all workers with the same arguments. + """ + # Check if any argument is a DistributedBatch + has_distributed_batch = any( + isinstance(arg, DistributedBatch) for arg in args + ) or any(isinstance(v, DistributedBatch) for v in kwargs.values()) + + if has_distributed_batch: + # Call ONLY DP heads with split data + return self._call_dp_heads_with_data_split(method, *args, **kwargs) + else: + # Call ALL workers (no data splitting needed) + return self._call_all_workers(method, *args, **kwargs) + + def _call_dp_heads_with_data_split(self, method: str, *args, **kwargs): + """Call only DP head workers with data split across DP groups.""" + # Find and split DistributedBatch arguments + split_args = [] + for arg in args: + if isinstance(arg, DistributedBatch): + # Split across DP groups + split_args.append(self._align_batches_with_dp(arg, rebalance=True)) + else: + # Replicate to all DP heads + split_args.append([arg] * len(self.dp_head_workers)) + + split_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, DistributedBatch): + split_kwargs[k] = self._align_batches_with_dp(v, rebalance=True) + else: + split_kwargs[k] = [v] * len(self.dp_head_workers) + + # Call ONLY DP head workers with their data slice + async def _call_all(): + tasks = [] + for idx, worker in enumerate(self.dp_head_workers): + # Get this worker's slice of each argument + worker_args = [splits[idx] for splits in split_args] + worker_kwargs = {k: splits[idx] for k, splits in split_kwargs.items()} + + # Convert DistributedBatch to dict for RPC + worker_args = [ + arg.get_data() if isinstance(arg, DistributedBatch) else arg + for arg in worker_args + ] + worker_kwargs = { + k: v.get_data() if isinstance(v, DistributedBatch) else v + for k, v in worker_kwargs.items() + } + + tasks.append( + self.scheduler.async_call_engine( + worker_id=worker.id, + method=method, + *worker_args, + **worker_kwargs, + ) + ) + return await asyncio.gather(*tasks) - def custom_function_call_with_data( - self, method: str, input_: DistributedBatch, strict_order=False, *args, **kwargs - ): - if strict_order: - batches = self._align_batches_with_dp(input_, False) + results = asyncio.run(_call_all()) + return self._merge_results(results, method) + + def _call_all_workers(self, method: str, *args, **kwargs): + """Call all workers with the same arguments (no data splitting).""" + + async def _call_all(): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, method=method, *args, **kwargs + ) + for worker in self.workers + ] + return await asyncio.gather(*tasks) + + results = asyncio.run(_call_all()) + return self._merge_results(results, method) + + def _merge_results(self, results, method: str): + """Merge results from workers based on result type. + + - For None: return None + - For dict with scalar values: return first (already synchronized) + - For dict with tensor/batch values: concat as DistributedBatch + - For tensors/lists: concat as DistributedBatch + - For scalars: return first (already synchronized) + """ + # Filter out None results + non_none_results = [r for r in results if r is not None] + + if len(non_none_results) == 0: + return None + + first_result = non_none_results[0] + + # If all results are dicts + if isinstance(first_result, dict): + # Check if it's a dict of scalars (like train_batch stats) + if all(isinstance(v, (int, float)) for v in first_result.values()): + # Stats are already synchronized within engines - return first + return first_result + else: + # Dict of tensors/batches - concat as DistributedBatch + return DistributedBatchMemory.concat( + [DistributedBatchMemory.from_dict(r) for r in non_none_results] + ) + + # If result is a tensor or torch.Tensor + elif isinstance(first_result, torch.Tensor): + # Single tensor, likely already reduced - return first + return first_result + + # If result is a list/iterable (but not string) + elif hasattr(first_result, "__iter__") and not isinstance(first_result, str): + try: + # Try to concat as DistributedBatch + return DistributedBatchMemory.concat( + [ + DistributedBatchMemory.from_dict(r) + if isinstance(r, dict) + else r + for r in non_none_results + ] + ) + except Exception: + # If concat fails, return list of results + return non_none_results + + # For scalars (int, float, bool, etc.) else: - batches = self._align_batches_with_dp(input_, True) - - stats = rpc_call( - self.scheduler, - self.workers, - method, - batches, - *args, - **kwargs, - ) - return stats + # Return first (already synchronized) + return first_result def _align_batches_with_dp( self, input_: DistributedBatch, rebalance=True ) -> list[DistributedBatch]: + """Split DistributedBatch across DP groups. + + Returns a list of batches, one for each DP head worker. + """ if rebalance: inputs = input_.chunk_by_ffd(self.group_size, self.alloc_mode.train.dp_size) else: inputs = input_.chunk(self.alloc_mode.train.dp_size) + # Return batches corresponding to DP head ranks batches = [] for dp_rank in self.engine_dp_ranks: batches.append(inputs[dp_rank]) return batches - def train_batch( + # ==================== ENGINE RPC WRAPPERS ==================== + def train(self, mode: bool = True): + """Set the engine to training mode. + + Parameters + ---------- + mode : bool, optional + Whether to set the engine to training mode, by default True + + Returns + ------- + TrainController + Returns self for method chaining + """ + self.custom_function_call("train", mode) + return self + + def eval(self): + """Set the engine to evaluation mode. + + This is a convenience method that calls `self.train(False)`. + + Returns + ------- + TrainController + Returns self for method chaining + """ + return self.train(False) + + def update_weights(self, meta: WeightUpdateMeta): + """Update weights to the inference engine in a blocking manner. + + Parameters + ---------- + meta : WeightUpdateMeta + Metadata containing information about the weight update + """ + self.custom_function_call("update_weights", meta) + + def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta): + """Connect to an inference engine for online training. + + Parameters + ---------- + engine : InferenceEngine + The inference engine to connect to + meta : WeightUpdateMeta + Metadata for weight update configuration + + Raises + ------ + NotImplementedError + This method is not implemented for TrainController + """ + raise NotImplementedError( + "connect_engine is not implemented for TrainController. " + "Use RolloutController for online training workflows." + ) + + def set_version(self, version: int): + """Set the current weight version in the training engine. + + Parameters + ---------- + version : int + The weight version number to set + """ + self.custom_function_call("set_version", version) + + def get_version(self) -> int: + """Get the current weight version in the training engine. + + Returns + ------- + int + The current weight version number + """ + return self.custom_function_call("get_version") + + def save(self, meta: SaveLoadMeta): + """Save model weights and optimizer states for later use. + + Parameters + ---------- + meta : SaveLoadMeta + Metadata containing information about where and how to save + """ + self.custom_function_call("save", meta) + + def load(self, meta: SaveLoadMeta): + """Load model weights and optimizer states from a file. + + Parameters + ---------- + meta : SaveLoadMeta + Metadata containing information about where and how to load + """ + self.custom_function_call("load", meta) + + def step_lr_scheduler(self): + """Step the learning rate scheduler. + + Since PPO uses minibatch updates, this method should be called periodically + (e.g., once per PPO step). It is separated from train_batch to allow + for more flexible learning rate scheduling. + """ + self.custom_function_call("step_lr_scheduler") + + def forward( self, input_: DistributedBatch, - loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], - loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], - ) -> dict[str, float]: - stats = self.custom_function_call_with_data( - "train_batch", input_, False, loss_fn, loss_weight_fn + output_seqlens: list[int] | None = None, + post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None, + aggregate_fn: Callable[[list[Any]], Any] = torch.cat, + ) -> Any | None: + """Run the forward pass or inference on the model. + + Note + ---- + This operation is gradient-free. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for model forward pass. Redundant entries are allowed. + output_seqlens : List[int], optional + The desired output sequence lengths. If None, assumes that the output + has the same lengths as inputs, by default None. + post_hook : Callable[[torch.Tensor, Dict[str, Any]], Any], optional + The post-processing function for micro-batched outputs. Post-processing + the output on-the-fly during micro-batched forward can reduce peak + memory usage, by default None. + aggregate_fn : Callable[[List[Any]], Any], optional + A function to aggregate micro-batched outputs, by default torch.cat. + + Returns + ------- + Any or None + The result produced by `post_hook` and `aggregate_fn`. + """ + return self.custom_function_call( + "forward", + input_=input_, + output_seqlens=output_seqlens, + post_hook=post_hook, + aggregate_fn=aggregate_fn, ) - return stats[0] - - def ppo_update( + def train_batch( self, input_: DistributedBatch, + loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> dict[str, float]: - stats = self.custom_function_call_with_data("ppo_update", input_, False) - - return stats[0] + """Update the model with a batch of data and a loss function. + + Note + ---- + The loss_fn should process packed 1D inputs, instead of 2D inputs. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for model forward pass and the loss function. + Redundant entries are allowed. + loss_fn : Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor] + The loss function that takes the model's forward output and input_, + and outputs a scalar normalized loss. + loss_weight_fn : Callable[[Dict[str, Any]], torch.Tensor] + A function used to calculate the weight of each micro-batch. Since + loss_fn normalizes the loss for a micro-batch, we need a corresponding + weight for each micro-batch to normalize the loss globally. The weight + is usually the number of response tokens in the batch. + + Returns + ------- + Dict[str, float] + Scalar statistics after training, e.g., the current learning rate, + gradient norm, etc. + """ + return self.custom_function_call("train_batch", input_, loss_fn, loss_weight_fn) def eval_batch( self, @@ -191,66 +577,140 @@ def eval_batch( loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], ) -> torch.Tensor | None: - stats = self.custom_function_call_with_data( - "eval_batch", input_, False, loss_fn, loss_weight_fn - ) - - return stats[0] - - def compute_logp( + """Evaluate the model using the forward pass and loss function. + + Note + ---- + The loss_fn should process packed 1D inputs, instead of 2D inputs. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for model forward pass and the loss function. + Redundant entries are allowed. + loss_fn : Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor] + The loss function that takes the model's forward output and input_, + and outputs a scalar normalized loss. + loss_weight_fn : Callable[[Dict[str, Any]], torch.Tensor] + A function used to calculate the weight of each micro-batch. Since + loss_fn normalizes the loss for a micro-batch, we need a corresponding + weight for each micro-batch to normalize the loss globally. The weight + is usually the number of response tokens in the batch. + + Returns + ------- + torch.Tensor or None + A scalar loss or None. The evaluation statistics should be aggregated + with `stats_tracker`. + """ + return self.custom_function_call("eval_batch", input_, loss_fn, loss_weight_fn) + + # ==================== SFT RPC WRAPPERS ==================== + def train_lm( self, input_: DistributedBatch, *args, **kwargs, - ): - logps = self.custom_function_call_with_data( - "compute_logp", input_, True, *args, **kwargs - ) - return logps + ) -> dict[str, float]: + """Train language model across workers. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for language model training + *args + Additional positional arguments passed to the engine + **kwargs + Additional keyword arguments passed to the engine + + Returns + ------- + Dict[str, float] + Scalar statistics after training + """ + return self.custom_function_call("train_lm", input_, *args, **kwargs) - def compute_advantages( + def evaluate_lm( self, input_: DistributedBatch, *args, **kwargs, - ): - advantages = self.custom_function_call_with_data( - "compute_advantages", input_, True, *args, **kwargs - ) - - return DistributedBatchMemory.concat(advantages) - - def train_lm( + ) -> torch.Tensor | None: + """Evaluate language model across workers. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for language model evaluation + *args + Additional positional arguments passed to the engine + **kwargs + Additional keyword arguments passed to the engine + + Returns + ------- + torch.Tensor or None + A scalar loss or None + """ + return self.custom_function_call("evaluate_lm", input_, *args, **kwargs) + + # ==================== PPO RPC WRAPPERS ==================== + def compute_logp( self, - input_: DistributedBatch, *args, **kwargs, - ) -> dict[str, float]: - stats = self.custom_function_call_with_data( - "train_lm", input_, False, *args, **kwargs - ) - - return stats[0] + ): + """Compute log probabilities across workers. + + Parameters + ---------- + *args + Positional arguments passed to the engine + **kwargs + Keyword arguments passed to the engine + + Returns + ------- + Any + Log probabilities computed by the engine + """ + return self.custom_function_call("compute_logp", *args, **kwargs) - def evaluate_lm( + def compute_advantages( self, - input_: DistributedBatch, *args, **kwargs, - ) -> torch.Tensor | None: - stats = self.custom_function_call_with_data( - "evaluate_lm", input_, False, *args, **kwargs - ) - return stats[0] + ): + """Compute advantages across workers. + + Parameters + ---------- + *args + Positional arguments passed to the engine + **kwargs + Keyword arguments passed to the engine + + Returns + ------- + Any + Advantages computed by the engine + """ + return self.custom_function_call("compute_advantages", *args, **kwargs) - def forward( + def ppo_update( self, input_: DistributedBatch, - output_seqlens: list[int] | None = None, - post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None, - aggregate_fn: Callable[[list[Any]], Any] = torch.cat, - ) -> list[Any]: - stats = self.custom_function_call_with_data( - "forward", input_, True, output_seqlens, post_hook, aggregate_fn - ) - return stats[0] + ) -> dict[str, float]: + """Perform PPO update step with the given batch. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data containing trajectories for PPO update + + Returns + ------- + Dict[str, float] + Scalar statistics after PPO update + """ + return self.custom_function_call("ppo_update", input_) diff --git a/areal/scheduler/local_scheduler.py b/areal/scheduler/local_scheduler.py index ef3159ec4..d29747a6d 100644 --- a/areal/scheduler/local_scheduler.py +++ b/areal/scheduler/local_scheduler.py @@ -13,7 +13,7 @@ import orjson import psutil -from areal.api.scheduler_api import ContainerSpec, Scheduler, SchedulingConfig, Worker +from areal.api.scheduler_api import Job, Scheduler, SchedulingSpec, Worker from areal.scheduler.exceptions import ( EngineCallError, EngineCreationError, @@ -194,50 +194,47 @@ def _allocate_ports(self, count: int) -> list[int]: raise PortAllocationError(str(e)) from e def _prepare_worker_specs( - self, role: str, num_workers: int, specs: list[ContainerSpec] | None - ) -> list[ContainerSpec]: + self, role: str, num_workers: int, schedulings: list[SchedulingSpec] | None + ) -> list[SchedulingSpec]: """ Prepare worker specs for a given number of workers. Args: role: Worker role name num_workers: Number of workers to create - specs: Optional list of specs + schedulings: Optional list of scheduling specs Returns: - List of ContainerSpec objects (one per worker) + List of SchedulingSpec objects (one per worker) Raises: - WorkerCreationError: If specs configuration is invalid + WorkerCreationError: If schedulings configuration is invalid """ - if not specs: - # Default spec: 1 GPU, 2 ports - return [ContainerSpec(gpu=1, port_count=2)] * num_workers + if not schedulings: + # Default spec: 1 CPU, 1024 MB mem, 1 GPU, 2 ports + return [SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2)] * num_workers # If a single spec is provided, use it for all workers - if len(specs) == 1: - return [specs[0]] * num_workers + if len(schedulings) == 1: + return [schedulings[0]] * num_workers # If per-worker specs, validate length matches - if len(specs) == num_workers: - return specs + if len(schedulings) == num_workers: + return schedulings # Invalid configuration raise WorkerCreationError( role, "Invalid configuration", - f"specs length ({len(specs)}) must be 1 or equal to replicas ({num_workers})", + f"schedulings length ({len(schedulings)}) must be 1 or equal to replicas ({num_workers})", ) - def create_workers( - self, role: str, scheduler_config: SchedulingConfig, *args, **kwargs - ) -> list[str]: + def create_workers(self, job: Job, *args, **kwargs) -> list[str]: """ Create worker subprocesses. Args: - role: Role name for this group of workers (e.g., "rollout", "actor", "critic") - scheduler_config: Scheduling configuration with replicas, specs, and strategy + job: Job configuration with role, replicas, tasks, and scheduling strategy *args: Additional arguments passed to worker command **kwargs: Additional keyword arguments @@ -249,6 +246,7 @@ def create_workers( GPUAllocationError: If GPU allocation fails PortAllocationError: If port allocation fails """ + role = job.role if role in self._workers: raise WorkerCreationError( role, @@ -257,23 +255,23 @@ def create_workers( ) # Extract configuration - num_workers = scheduler_config.replicas + num_workers = job.replicas if num_workers == 0: raise WorkerCreationError( role, "Invalid configuration", "replicas must be greater than 0" ) # Prepare worker specs - specs = self._prepare_worker_specs(role, num_workers, scheduler_config.specs) + schedulings = self._prepare_worker_specs(role, num_workers, job.tasks) # Determine scheduling strategy - strategy = scheduler_config.schedule_strategy + strategy = job.schedule_strategy if strategy is None: - strategy_type = "new" + strategy_type = "separation" colocate_role = None else: - strategy_type = strategy.type or "new" - colocate_role = strategy.uid if strategy_type == "colocate" else None + strategy_type = strategy.type or "separation" + colocate_role = strategy.target if strategy_type == "colocation" else None logger.info( f"Creating {num_workers} workers for role '{role}' " @@ -285,29 +283,29 @@ def create_workers( try: for idx in range(num_workers): worker_id = f"{role}/{idx}" - spec = specs[idx] + scheduling = schedulings[idx] # Allocate resources based on strategy try: # GPU allocation - if strategy_type == "colocate": + if strategy_type == "colocation": if not colocate_role: raise WorkerCreationError( role, "Invalid strategy", - "Colocate strategy requires uid (target role) to be specified", + "Colocation strategy requires target role to be specified", ) gpu_devices = self._get_colocated_gpus(colocate_role, idx) logger.debug( f"Worker {worker_id} colocated with {colocate_role}/{idx} on GPUs {gpu_devices}" ) - else: # "new" or default - gpu_devices = self._allocate_gpus(spec.gpu) + else: # "separation" or default + gpu_devices = self._allocate_gpus(scheduling.gpu) logger.debug( f"Worker {worker_id} allocated new GPUs {gpu_devices}" ) - ports = self._allocate_ports(spec.port_count) + ports = self._allocate_ports(scheduling.port_count) except ( GPUAllocationError, PortAllocationError, @@ -325,17 +323,17 @@ def create_workers( env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_devices)) env["WORKER_ID"] = worker_id - # Merge user-provided environment variables from spec - if spec.env_vars: - env.update(spec.env_vars) + # Merge user-provided environment variables from scheduling + if scheduling.env_vars: + env.update(scheduling.env_vars) # Prepare log file log_file = self.log_dir / f"{worker_id.replace('/', '_')}.log" # Build command to start RPC server - if spec.cmd: - # Use custom command from spec - cmd = shlex.split(spec.cmd) + if scheduling.cmd: + # Use custom command from scheduling + cmd = shlex.split(scheduling.cmd) else: # Default: start RPC server cmd = [ @@ -385,7 +383,8 @@ def create_workers( worker = Worker( id=worker_id, ip=gethostip(), - ports=[str(p) for p in ports], + worker_ports=[str(p) for p in ports], + engine_ports=[], ) worker_info = WorkerInfo( @@ -482,7 +481,7 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: """Check if worker's RPC server is ready via HTTP health check.""" - port = int(worker_info.worker.ports[0]) + port = int(worker_info.worker.worker_ports[0]) url = f"http://{worker_info.worker.ip}:{port}/health" try: @@ -544,7 +543,7 @@ def _cleanup_workers(self, workers: list[WorkerInfo]): for worker_info in workers: try: # Release ports - for port_str in worker_info.worker.ports: + for port_str in worker_info.worker.worker_ports: self._allocated_ports.discard(int(port_str)) # Terminate process tree @@ -645,7 +644,7 @@ async def create_engine( } # Send HTTP request to create engine - port = int(worker_info.worker.ports[0]) + port = int(worker_info.worker.worker_ports[0]) url = f"http://{worker_info.worker.ip}:{port}/create_engine" try: @@ -745,7 +744,7 @@ def call_engine( } # Retry logic with exponential backoff - port = int(worker_info.worker.ports[0]) + port = int(worker_info.worker.worker_ports[0]) url = f"http://{worker_info.worker.ip}:{port}/call" last_error = None @@ -838,7 +837,7 @@ async def async_call_engine( raise WorkerNotFoundError(worker_id) # Route to different endpoint based on method - port = int(worker_info.worker.ports[0]) + port = int(worker_info.worker.worker_ports[0]) if method == "run_workflow": # Special routing for workflow execution url = f"http://{worker_info.worker.ip}:{port}/run_workflow" diff --git a/areal/tests/test_local_scheduler.py b/areal/tests/test_local_scheduler.py index dd8d52e6d..8ecdafe25 100644 --- a/areal/tests/test_local_scheduler.py +++ b/areal/tests/test_local_scheduler.py @@ -8,9 +8,9 @@ import pytest from areal.api.scheduler_api import ( - ContainerSpec, + Job, ScheduleStrategy, - SchedulingConfig, + SchedulingSpec, Worker, ) from areal.scheduler.exceptions import ( @@ -94,7 +94,7 @@ def create_worker_info( process = create_mock_process() return WorkerInfo( - worker=Worker(id=worker_id, ip=ip, ports=ports), + worker=Worker(id=worker_id, ip=ip, worker_ports=ports, engine_ports=[]), process=process, role=role, gpu_devices=gpu_devices, @@ -335,8 +335,8 @@ def test_create_workers_with_default_spec( scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) - config = SchedulingConfig(replicas=2, role="rollout") - worker_ids = scheduler.create_workers("rollout", config) + job = Job(replicas=2, role="rollout") + worker_ids = scheduler.create_workers(job) assert worker_ids == ["rollout/0", "rollout/1"] assert "rollout" in scheduler._workers @@ -366,19 +366,19 @@ def test_create_workers_with_single_spec_for_all( scheduler = LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) - config = SchedulingConfig( + job = Job( replicas=3, role="actor", - specs=[ContainerSpec(gpu=2, port_count=3)], + tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=2, port_count=3)], ) - worker_ids = scheduler.create_workers("actor", config) + worker_ids = scheduler.create_workers(job) assert len(worker_ids) == 3 assert mock_popen.call_count == 3 # All workers should use the same spec for worker_info in scheduler._workers["actor"]: - assert len(worker_info.worker.ports) == 3 + assert len(worker_info.worker.worker_ports) == 3 @patch("areal.scheduler.local_scheduler.gethostip") @patch("areal.scheduler.local_scheduler.subprocess.Popen") @@ -401,19 +401,19 @@ def test_create_workers_with_per_worker_specs( scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) - config = SchedulingConfig( + job = Job( replicas=2, role="critic", - specs=[ - ContainerSpec(gpu=1, port_count=1), - ContainerSpec(gpu=1, port_count=2), + tasks=[ + SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=1), + SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2), ], ) - worker_ids = scheduler.create_workers("critic", config) + worker_ids = scheduler.create_workers(job) assert len(worker_ids) == 2 - assert len(scheduler._workers["critic"][0].worker.ports) == 1 - assert len(scheduler._workers["critic"][1].worker.ports) == 2 + assert len(scheduler._workers["critic"][0].worker.worker_ports) == 1 + assert len(scheduler._workers["critic"][1].worker.worker_ports) == 2 @patch("areal.scheduler.local_scheduler.gethostip") @patch("areal.scheduler.local_scheduler.subprocess.Popen") @@ -432,16 +432,20 @@ def test_create_workers_with_custom_command( scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) - config = SchedulingConfig( + job = Job( replicas=1, role="custom", - specs=[ - ContainerSpec( - gpu=1, port_count=2, cmd="python my_custom_server.py --port 8000" + tasks=[ + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + port_count=2, + cmd="python my_custom_server.py --port 8000", ) ], ) - worker_ids = scheduler.create_workers("custom", config) + worker_ids = scheduler.create_workers(job) assert len(worker_ids) == 1 @@ -467,18 +471,20 @@ def test_create_workers_with_environment_variables( scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) - config = SchedulingConfig( + job = Job( replicas=1, role="envtest", - specs=[ - ContainerSpec( + tasks=[ + SchedulingSpec( + cpu=1, + mem=1024, gpu=1, port_count=2, env_vars={"CUSTOM_VAR": "custom_value", "ANOTHER_VAR": "123"}, ) ], ) - worker_ids = scheduler.create_workers("envtest", config) + worker_ids = scheduler.create_workers(job) assert len(worker_ids) == 1 @@ -511,10 +517,12 @@ def test_create_workers_with_colocate_strategy( scheduler = LocalScheduler(gpu_devices=[0, 1, 2, 3], log_dir=str(tmp_path)) # Create target workers (actors) - actor_config = SchedulingConfig( - replicas=2, role="actor", specs=[ContainerSpec(gpu=2, port_count=2)] + actor_job = Job( + replicas=2, + role="actor", + tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=2, port_count=2)], ) - scheduler.create_workers("actor", actor_config) + scheduler.create_workers(actor_job) # Get GPU allocations for actors actor_gpus_0 = scheduler._workers["actor"][0].gpu_devices @@ -525,13 +533,13 @@ def test_create_workers_with_colocate_strategy( mock_find_ports.return_value = [8010, 8011] # Create colocated workers (critics) - critic_config = SchedulingConfig( + critic_job = Job( replicas=2, role="critic", - specs=[ContainerSpec(gpu=2, port_count=2)], - schedule_strategy=ScheduleStrategy(type="colocate", uid="actor"), + tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=2, port_count=2)], + schedule_strategy=ScheduleStrategy(type="colocation", target="actor"), ) - critic_ids = scheduler.create_workers("critic", critic_config) + critic_ids = scheduler.create_workers(critic_job) assert len(critic_ids) == 2 @@ -558,12 +566,12 @@ def test_create_workers_duplicate_role_error(self, tmp_path): mock_proc.poll.return_value = None mock_popen.return_value = mock_proc - config = SchedulingConfig(replicas=1, role="test") - scheduler.create_workers("test", config) + job = Job(replicas=1, role="test") + scheduler.create_workers(job) # Try to create again with pytest.raises(WorkerCreationError) as exc_info: - scheduler.create_workers("test", config) + scheduler.create_workers(job) assert "Worker group already exists" in str(exc_info.value) assert exc_info.value.worker_key == "test" @@ -572,30 +580,30 @@ def test_create_workers_zero_replicas_error(self, tmp_path): """Should raise WorkerCreationError when replicas is 0.""" scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) - config = SchedulingConfig(replicas=0, role="test") + job = Job(replicas=0, role="test") with pytest.raises(WorkerCreationError) as exc_info: - scheduler.create_workers("test", config) + scheduler.create_workers(job) assert "replicas must be greater than 0" in str(exc_info.value) def test_create_workers_invalid_specs_length(self, tmp_path): - """Should raise WorkerCreationError when specs length is invalid.""" + """Should raise WorkerCreationError when tasks length is invalid.""" scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) - config = SchedulingConfig( + job = Job( replicas=3, role="test", - specs=[ - ContainerSpec(gpu=1, port_count=2), - ContainerSpec(gpu=1, port_count=2), - ], # 2 specs for 3 replicas + tasks=[ + SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2), + SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2), + ], # 2 tasks for 3 replicas ) with pytest.raises(WorkerCreationError) as exc_info: - scheduler.create_workers("test", config) + scheduler.create_workers(job) - assert "specs length (2) must be 1 or equal to replicas (3)" in str( + assert "schedulings length (2) must be 1 or equal to replicas (3)" in str( exc_info.value ) @@ -622,13 +630,13 @@ def test_create_workers_subprocess_fails_immediately( scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) - config = SchedulingConfig(replicas=1, role="test") + job = Job(replicas=1, role="test") with patch.object( scheduler, "_read_log_tail", return_value="Error: Failed to start server" ): with pytest.raises(WorkerCreationError) as exc_info: - scheduler.create_workers("test", config) + scheduler.create_workers(job) assert "exited immediately with code 1" in str(exc_info.value) @@ -653,31 +661,33 @@ def test_create_workers_cleanup_on_partial_failure( scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) - config = SchedulingConfig(replicas=2, role="test") + job = Job(replicas=2, role="test") with patch.object(scheduler, "_cleanup_workers") as mock_cleanup: with pytest.raises(WorkerCreationError) as exc_info: - scheduler.create_workers("test", config) + scheduler.create_workers(job) # Verify cleanup was called assert mock_cleanup.called assert "Resource allocation failed" in str(exc_info.value) - def test_create_workers_colocate_strategy_missing_uid(self, tmp_path): - """Should raise WorkerCreationError when colocate strategy is missing target role uid.""" + def test_create_workers_colocate_strategy_missing_target(self, tmp_path): + """Should raise WorkerCreationError when colocation strategy is missing target role.""" scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) - config = SchedulingConfig( + job = Job( replicas=1, role="test", - specs=[ContainerSpec(gpu=1, port_count=2)], - schedule_strategy=ScheduleStrategy(type="colocate", uid=""), # Missing uid + tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2)], + schedule_strategy=ScheduleStrategy( + type="colocation", target="" + ), # Missing target ) with pytest.raises(WorkerCreationError) as exc_info: - scheduler.create_workers("test", config) + scheduler.create_workers(job) - assert "Colocate strategy requires uid" in str(exc_info.value) + assert "Colocation strategy requires target" in str(exc_info.value) class TestGetWorkers: @@ -1499,8 +1509,8 @@ def test_worker_id_format( scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) - config = SchedulingConfig(replicas=5, role="worker") - worker_ids = scheduler.create_workers("worker", config) + job = Job(replicas=5, role="worker") + worker_ids = scheduler.create_workers(job) assert worker_ids == [ "worker/0", @@ -1555,8 +1565,8 @@ async def test_run_workflow_endpoint_basic(self, tmp_path): scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) try: - config = SchedulingConfig(replicas=1) - worker_ids = scheduler.create_workers(role="test", scheduler_config=config) + job = Job(replicas=1, role="test") + worker_ids = scheduler.create_workers(job) assert len(worker_ids) == 1 workers = scheduler.get_workers(role="test", timeout=30.0) @@ -1594,8 +1604,8 @@ async def test_run_workflow_endpoint_multiple_calls(self, tmp_path): scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) try: - config = SchedulingConfig(replicas=1) - scheduler.create_workers(role="test", scheduler_config=config) + job = Job(replicas=1, role="test") + scheduler.create_workers(job) workers = scheduler.get_workers(role="test", timeout=30.0) worker_id = workers[0].id @@ -1629,8 +1639,8 @@ async def test_run_workflow_serialization(self, tmp_path): scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) try: - config = SchedulingConfig(replicas=1) - scheduler.create_workers(role="test", scheduler_config=config) + job = Job(replicas=1, role="test") + scheduler.create_workers(job) workers = scheduler.get_workers(role="test", timeout=30.0) worker_id = workers[0].id @@ -1670,8 +1680,8 @@ async def test_run_workflow_with_kwargs(self, tmp_path): scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) try: - config = SchedulingConfig(replicas=1) - scheduler.create_workers(role="test", scheduler_config=config) + job = Job(replicas=1, role="test") + scheduler.create_workers(job) workers = scheduler.get_workers(role="test", timeout=30.0) worker_id = workers[0].id diff --git a/areal/tests/test_rollout_controller.py b/areal/tests/test_rollout_controller.py index c02dcab41..afc028fd3 100644 --- a/areal/tests/test_rollout_controller.py +++ b/areal/tests/test_rollout_controller.py @@ -24,7 +24,13 @@ def __init__(self): def create_workers(self, role, scheduler_config, *args, **kwargs): worker_ids = [f"{role}/{i}" for i in range(scheduler_config.replicas)] self.workers = [ - Worker(id=wid, ip="127.0.0.1", ports=["8000", "8001"]) for wid in worker_ids + Worker( + id=wid, + ip="127.0.0.1", + worker_ports=["8000", "8001"], + engine_ports=["9000", "9001"], + ) + for wid in worker_ids ] return worker_ids diff --git a/areal/utils/recover.py b/areal/utils/recover.py index c936a26fd..3c62c7d89 100644 --- a/areal/utils/recover.py +++ b/areal/utils/recover.py @@ -2,16 +2,16 @@ import json import os import pickle -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoProcessor, PreTrainedTokenizerFast from areal.api.cli_args import RecoverConfig -from areal.api.controller_api import TrainController from areal.api.engine_api import InferenceEngine, TrainEngine from areal.api.io_struct import FinetuneSpec, SaveLoadMeta, StepInfo, WeightUpdateMeta +from areal.controller.train_controller import TrainController from areal.utils import logging, timeutil from areal.utils.evaluator import Evaluator from areal.utils.saver import Saver @@ -32,11 +32,11 @@ class RecoverInfo: # Recover will start from the next iteration, obtained by `last_step_info.next()`. last_step_info: StepInfo - saver_info: Dict - evaluator_info: Dict - stats_logger_info: Dict - dataloader_info: Dict | List[Dict] - checkpoint_info: Dict + saver_info: dict + evaluator_info: dict + stats_logger_info: dict + dataloader_info: dict | list[dict] + checkpoint_info: dict def dump(self, dump_dir: str): # Dumps the recover info to multiple files in `dump_dir`: @@ -91,24 +91,24 @@ def load(cls, load_dir: str): try: step_info_path = os.path.join(load_dir, "step_info.json") - with open(step_info_path, "r") as f: + with open(step_info_path) as f: step_info_dict = json.load(f) last_step_info = StepInfo(**step_info_dict) evaluator_info_path = os.path.join(load_dir, "evaluator_info.json") - with open(evaluator_info_path, "r") as f: + with open(evaluator_info_path) as f: evaluator_info = json.load(f) saver_info_path = os.path.join(load_dir, "saver_info.json") - with open(saver_info_path, "r") as f: + with open(saver_info_path) as f: saver_info = json.load(f) stats_logger_info_path = os.path.join(load_dir, "stats_logger_info.json") - with open(stats_logger_info_path, "r") as f: + with open(stats_logger_info_path) as f: stats_logger_info = json.load(f) checkpoint_info_path = os.path.join(load_dir, "checkpoint_info.json") - with open(checkpoint_info_path, "r") as f: + with open(checkpoint_info_path) as f: checkpoint_info = json.load(f) dataloader_info_path = os.path.join(load_dir, "dataloader_info.pkl") @@ -161,12 +161,12 @@ def recover_info_path( ): return os.path.join( Saver.get_save_root(experiment_name, trial_name, fileroot), - f"recover_info", + "recover_info", ) def dump( self, - engine: TrainEngine | Dict[str, TrainEngine], + engine: TrainEngine | dict[str, TrainEngine], step_info: StepInfo, saver: Saver, evaluator: Evaluator, @@ -214,7 +214,7 @@ def dump( def load( self, - engine: TrainEngine | Dict[str, TrainEngine] | TrainController, + engine: TrainEngine | dict[str, TrainEngine] | TrainController, saver: Saver, evaluator: Evaluator, stats_logger: "StatsLogger", @@ -228,9 +228,9 @@ def load( if os.environ.get("AREAL_RECOVER_RUN", "0") != "1": return if inference_engine is not None: - assert ( - weight_update_meta is not None - ), "Inference engine requires weight update meta for recovery." + assert weight_update_meta is not None, ( + "Inference engine requires weight update meta for recovery." + ) if isinstance(engine, (TrainEngine, TrainController)): engine = {"default": engine} diff --git a/areal/utils/saver.py b/areal/utils/saver.py index ec0d9c71a..02db55257 100644 --- a/areal/utils/saver.py +++ b/areal/utils/saver.py @@ -4,14 +4,13 @@ from transformers import AutoProcessor, PreTrainedTokenizerFast from areal.api.cli_args import SaverConfig -from areal.api.controller_api import TrainController from areal.api.engine_api import TrainEngine from areal.api.io_struct import FinetuneSpec, SaveLoadMeta +from areal.controller.train_controller import TrainController from areal.utils import timeutil class Saver: - def __init__(self, config: SaverConfig, ft_spec: FinetuneSpec): self.config = config self.ft_spec = ft_spec diff --git a/areal/utils/scheduler.py b/areal/utils/scheduler.py deleted file mode 100644 index 9fff0b4d9..000000000 --- a/areal/utils/scheduler.py +++ /dev/null @@ -1,19 +0,0 @@ -from areal.api.cli_args import SchedulingSpec -from areal.api.engine_api import Scheduling - - -def scheduling_specs_to_schedulings(specs: list[SchedulingSpec]) -> list[Scheduling]: - result = [] - for spec in specs: - sch = Scheduling( - cpu=spec.cpu, - gpu=spec.gpu, - mem=spec.mem, - port_count=spec.port_count, - container_image=spec.image, - type=spec.type, - env_vars=spec.env_vars, - cmd=spec.cmd, - ) - result.append(sch) - return result diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 9225c4151..69189abf2 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -73,6 +73,7 @@ For detailed examples, see the experiment configurations in the `examples/` dire - [DistributedDataParallel Configuration](section-distributed-data-parallel) - [MegatronEngine Configuration](section-megatron-engine) +- [ScheduleStrategy](section-schedule-strategy) - [Scheduler Configuration](section-scheduler) - [Scheduling Specification](section-scheduling) @@ -335,7 +336,8 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `scheduling_specs` | list of [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`ScheduleStrategy`](section-schedule-strategy) | **Required** | - | | `group_size` | integer | `1` | Number of sequences in each group | | `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | | `eps_clip` | float | `0.2` | Clipping factor for policy ratio | @@ -393,7 +395,8 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `scheduling_specs` | list of [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`ScheduleStrategy`](section-schedule-strategy) | **Required** | - | | `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | | `eps_clip` | float | `0.5` | Clipping factor for value loss | | `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | @@ -428,7 +431,8 @@ Core configuration for model training, including optimization and backend settin | `lora_alpha` | integer | `16` | lora alpha | | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `scheduling_specs` | list of [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`ScheduleStrategy`](section-schedule-strategy) | **Required** | - | (section-generation-hyperparameters)= @@ -456,22 +460,23 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | ---------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | (Deprecated) Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `setup_timeout` | float | `120.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | -| `scheduling_specs` | list of [`SchedulingSpec`](section-scheduling) | **Required** | inference engine schedule specs | +| Parameter | Type | Default | Description | +| ------------------------- | ----------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | (Deprecated) Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `setup_timeout` | float | `120.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | inference engine schedule specs | +| `scheduling_strategy` | [`ScheduleStrategy`](section-schedule-strategy) | **Required** | - | (section-sg-lang)= @@ -802,6 +807,17 @@ Refer to Megatron-LM documentation for implementation details. | `distribute_saved_activations` | boolean \| None | `None` | - | | `recompute_modules` | list of string \| None | `None` | - | +(section-schedule-strategy)= + +## ScheduleStrategy + +Configuration class: ScheduleStrategy + +| Parameter | Type | Default | Description | +| --------- | --------- | -------------- | ----------- | +| `type` | `Literal` | `"separation"` | - | +| `target` | string | `""` | - | + (section-scheduler)= ## Scheduler Configuration @@ -823,13 +839,19 @@ Configuration for worker scheduling. Used in the single-controller mode. Experim Configuration class: SchedulingSpec -| Parameter | Type | Default | Description | -| ------------ | ------- | ------------ | ---------------------------------------------------------------- | -| `cpu` | integer | `0` | Number of CPU cores required | -| `gpu` | integer | `0` | Number of GPU units required | -| `mem` | integer | `0` | Amount of memory (GB) required | -| `port_count` | integer | `2` | Number of ports to expose | -| `image` | string | `""` | Docker/Singularity container image to use | -| `type` | string | `"worker"` | Task type (e.g., worker, engine) **Choices:** `worker`, `engine` | -| `env_vars` | `dict` | **Required** | Environment variables for the container | -| `cmd` | string | `""` | Command to execute inside the container | +| Parameter | Type | Default | Description | +| ------------ | -------------- | ------------ | ------------------------------------------------------------------------ | +| `cpu` | integer | `0` | Number of CPU cores required | +| `gpu` | integer | `0` | Number of GPU units required | +| `mem` | integer | `0` | Amount of memory (GB) required | +| `port_count` | integer | `2` | Number of ports to expose | +| `image` | string | `""` | Docker/Singularity container image to use | +| `type` | string | `"worker"` | Task type (e.g., worker, engine) **Choices:** `worker`, `engine` | +| `env_vars` | `dict` | **Required** | Environment variables for the container | +| `cmd` | string \| None | `None` | Command to execute inside the container. Defaults to AReaL's RPC server. | +| `nodelist` | string \| None | `None` | - | +| `exclude` | string \| None | `None` | - | +| `partition` | string \| None | `None` | - | +| `time_limit` | string \| None | `None` | - | +| `begin` | string \| None | `None` | - | +| `deadline` | string \| None | `None` | - | diff --git a/realhf/api/core/system_api.py b/realhf/api/core/system_api.py index ea30213fe..47c695792 100644 --- a/realhf/api/core/system_api.py +++ b/realhf/api/core/system_api.py @@ -34,7 +34,7 @@ class ExpStatus(Enum): @dataclasses.dataclass -class Scheduling: +class SchedulingSpec: # TODO: add partition cpu: int gpu: int @@ -173,7 +173,7 @@ class MasterWorker: @dataclasses.dataclass class TasksGroup: count: int - scheduling: Scheduling + scheduling: SchedulingSpec @dataclasses.dataclass @@ -458,7 +458,7 @@ class Experiment: """Base class for defining the procedure of an experiment.""" def scheduling_setup(self) -> ExperimentScheduling: - """Returns the Scheduling of all workers.""" + """Returns the SchedulingSpec of all workers.""" raise NotImplementedError() def initial_setup(self) -> ExperimentConfig | List[ExperimentConfig]: diff --git a/realhf/apps/main.py b/realhf/apps/main.py index e1ca08764..416d656a2 100644 --- a/realhf/apps/main.py +++ b/realhf/apps/main.py @@ -47,7 +47,7 @@ def _submit_workers( job_environs = {**environs, **sch_cfg.scheduling.env_vars} cmd = sched_client.remote_worker_cmd(expr_name, trial_name, debug, worker_type) - logger.debug(f"Scheduling worker {worker_type}, {scheduling_configs}") + logger.debug(f"SchedulingSpec worker {worker_type}, {scheduling_configs}") nodelist = sch_cfg.scheduling.nodelist exclude = sch_cfg.scheduling.exclude diff --git a/realhf/experiments/async_exp/async_rl_exp.py b/realhf/experiments/async_exp/async_rl_exp.py index 3c23f13d1..313606400 100755 --- a/realhf/experiments/async_exp/async_rl_exp.py +++ b/realhf/experiments/async_exp/async_rl_exp.py @@ -34,7 +34,7 @@ GserverManager, ModelWorker, RolloutWorker, - Scheduling, + SchedulingSpec, TasksGroup, ) from realhf.api.quickstart.device_mesh import RPCAllocation @@ -90,7 +90,7 @@ def scheduling_setup(self) -> ExperimentScheduling: return ExperimentScheduling( master_worker=TasksGroup( count=1, - scheduling=Scheduling( + scheduling=SchedulingSpec( cpu=self.cpus_per_master_worker, gpu=0, mem=self.mem_per_master_worker, @@ -101,7 +101,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), model_worker=TasksGroup( count=train_world_size, - scheduling=Scheduling( + scheduling=SchedulingSpec( cpu=self.cpus_per_model_worker, gpu=1, mem=self.mem_per_model_worker, @@ -112,7 +112,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), generation_server=TasksGroup( count=gen_world_size // gen_tp_size, - scheduling=Scheduling( + scheduling=SchedulingSpec( cpu=self.cpus_per_generation_server, gpu=gen_tp_size, mem=self.mem_per_generation_server, @@ -123,7 +123,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), gserver_manager=TasksGroup( count=1, - scheduling=Scheduling( + scheduling=SchedulingSpec( cpu=self.cpus_per_gserver_manager, gpu=0, mem=self.mem_per_gserver_manager, @@ -134,7 +134,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), rollout_worker=TasksGroup( count=self.n_rollout_workers or train_world_size, - scheduling=Scheduling( + scheduling=SchedulingSpec( cpu=self.cpus_per_rollout_worker, gpu=0, mem=self.mem_per_rollout_worker, diff --git a/realhf/experiments/common/common.py b/realhf/experiments/common/common.py index 370f16fd1..2507fe896 100644 --- a/realhf/experiments/common/common.py +++ b/realhf/experiments/common/common.py @@ -34,7 +34,7 @@ ExperimentConfig, ExperimentScheduling, ModelWorker, - Scheduling, + SchedulingSpec, TasksGroup, ) from realhf.api.quickstart.device_mesh import ( @@ -163,7 +163,7 @@ def scheduling_setup(self) -> ExperimentScheduling: return ExperimentScheduling( master_worker=TasksGroup( count=1, - scheduling=Scheduling( + scheduling=SchedulingSpec( cpu=self.cpus_per_master_worker, gpu=0, mem=self.mem_per_master_worker, @@ -174,7 +174,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), model_worker=TasksGroup( count=self.n_nodes * self.n_gpus_per_node, - scheduling=Scheduling( + scheduling=SchedulingSpec( cpu=self.cpus_per_model_worker, gpu=1, mem=self.mem_per_model_worker, diff --git a/realhf/system/controller.py b/realhf/system/controller.py index 0e8b982ee..33bc3075e 100644 --- a/realhf/system/controller.py +++ b/realhf/system/controller.py @@ -127,7 +127,7 @@ def __check_consistent_scheduling( setup: system_api.ExperimentConfig, verbose=False, ): - # Scheduling and connecting to workers. + # SchedulingSpec and connecting to workers. workers_configs = [ (k, getattr(setup, k), getattr(scheduling, k)) for k in WORKER_TYPES @@ -142,7 +142,7 @@ def __check_consistent_scheduling( raise ValueError( f"Configuration and scheduling mismatch. " f"Number of worker configurations: {len(worker_setups)}, " - f"Scheduling configs: {schedules}." + f"SchedulingSpec configs: {schedules}." ) for name, config, schedule in workers_configs: @@ -153,7 +153,7 @@ def __check_consistent_scheduling( ) if len(config) != count: logger.error( - "Scheduling and config mismatch, interrupting all workers." + "SchedulingSpec and config mismatch, interrupting all workers." ) self.interrupt() raise IndexError( From 5a702a1524eacdc5a68181e0020d850d37ddcddd Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 19:54:34 +0800 Subject: [PATCH 17/52] add train controller tests --- areal/controller/train_controller.py | 321 +++++---- areal/scheduler/rpc/rpc_server.py | 25 + areal/tests/test_train_controller.py | 871 ++++++++++++++++++++++++ examples/single-controller/gsm8k_sft.py | 7 +- 4 files changed, 1052 insertions(+), 172 deletions(-) create mode 100644 areal/tests/test_train_controller.py diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index 7054a7acd..f2b1dd4b4 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -1,5 +1,6 @@ import asyncio from collections.abc import Callable +from datetime import datetime from typing import Any import torch @@ -7,7 +8,7 @@ from areal.api.alloc_mode import ParallelStrategy from areal.api.cli_args import TrainEngineConfig from areal.api.controller_api import DistributedBatch -from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.api.engine_api import TrainEngine from areal.api.io_struct import ( AllocationMode, FinetuneSpec, @@ -16,7 +17,9 @@ ) from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker from areal.controller.batch import DistributedBatchMemory -from areal.utils import logging +from areal.controller.rollout_controller import RolloutController +from areal.platforms import current_platform +from areal.utils import logging, name_resolve, names logger = logging.getLogger("TrainController") @@ -59,11 +62,13 @@ def __init__( self.config = config self.scheduler = scheduler - self.group_size: int self.alloc_mode: AllocationMode self.workers: list[Worker] = [] - self.dp_head_workers: list[Worker] = [] # Only DP head workers - self.engine_dp_ranks: list[int] = [] # DP rank of each DP head worker + self.workers_is_dp_head: list[bool] = [] # Only DP head workers + self.parallel_strategy: ParallelStrategy | None = None + + self.rollout: RolloutController + self.weight_update_group_initialized = False self._worker_role: str self.logger = None @@ -76,7 +81,10 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None parallel_strategy : ParallelStrategy, optional The parallel strategy configuration for distributed training, by default None """ - assert self.workers is not None, "Workers are not created" + assert len(self.workers) > 0, "Workers are not created" + if parallel_strategy is None: + parallel_strategy = ParallelStrategy() + self.parallel_strategy = parallel_strategy self.custom_function_call("create_process_group", parallel_strategy) def initialize( @@ -109,8 +117,14 @@ def initialize( # Store configuration self._worker_role = role self.alloc_mode = alloc_mode - # todo: group size is a sampling parameter and an attribute of the data, should be moved to DistributedBatch - self.group_size = kwargs.get("group_size", 1) + + # Initialize parallel_strategy from alloc_mode if not already set + if self.parallel_strategy is None: + self.parallel_strategy = ParallelStrategy( + data_parallel_size=alloc_mode.train.dp_size, + tensor_parallel_size=alloc_mode.train.tp_size, + pipeline_parallel_size=alloc_mode.train.pp_size, + ) # Create job for scheduler job = Job( @@ -169,7 +183,6 @@ async def _async_create_and_initialize_engines( self.scheduler.async_call_engine( worker_id=worker.id, method="initialize", - addr=None, ft_spec=ft_spec, **kwargs, ) @@ -183,33 +196,16 @@ def _identify_dp_heads(self): self.logger.info("Identifying DP head workers...") # Query all workers for their DP rank - async def _get_dp_ranks(): + async def _get_dp_head(): tasks = [ self.scheduler.async_call_engine( - worker_id=worker.id, method="data_parallel_rank" + worker_id=worker.id, method="is_data_parallel_head" ) for worker in self.workers ] return await asyncio.gather(*tasks) - dp_ranks = asyncio.run(_get_dp_ranks()) - - # Find unique DP ranks and corresponding head workers - seen_dp_ranks = set() - self.dp_head_workers = [] - self.engine_dp_ranks = [] - - for worker, dp_rank in zip(self.workers, dp_ranks): - if dp_rank not in seen_dp_ranks: - self.dp_head_workers.append(worker) - self.engine_dp_ranks.append(dp_rank) - seen_dp_ranks.add(dp_rank) - - self.logger.info( - f"Identified {len(self.dp_head_workers)} DP head workers " - f"from {len(self.workers)} total workers. " - f"DP ranks: {self.engine_dp_ranks}" - ) + self.workers_is_dp_head = asyncio.run(_get_dp_head()) def destroy(self): """Destroy the controller and release GPU memory of models. @@ -227,31 +223,15 @@ def destroy(self): # Clear worker lists self.workers.clear() - self.dp_head_workers.clear() - self.engine_dp_ranks.clear() + self.workers_is_dp_head.clear() self.logger.info("TrainController destroyed") def custom_function_call(self, method: str, *args, **kwargs): """Dispatch method call to appropriate workers based on input type. - If any argument is a DistributedBatch, split data and call only DP heads. - Otherwise, call all workers with the same arguments. + If any argument is a DistributedBatch, split data. Call only DP heads. """ - # Check if any argument is a DistributedBatch - has_distributed_batch = any( - isinstance(arg, DistributedBatch) for arg in args - ) or any(isinstance(v, DistributedBatch) for v in kwargs.values()) - - if has_distributed_batch: - # Call ONLY DP heads with split data - return self._call_dp_heads_with_data_split(method, *args, **kwargs) - else: - # Call ALL workers (no data splitting needed) - return self._call_all_workers(method, *args, **kwargs) - - def _call_dp_heads_with_data_split(self, method: str, *args, **kwargs): - """Call only DP head workers with data split across DP groups.""" # Find and split DistributedBatch arguments split_args = [] for arg in args: @@ -260,37 +240,48 @@ def _call_dp_heads_with_data_split(self, method: str, *args, **kwargs): split_args.append(self._align_batches_with_dp(arg, rebalance=True)) else: # Replicate to all DP heads - split_args.append([arg] * len(self.dp_head_workers)) + split_args.append([arg] * self.parallel_strategy.dp_size) split_kwargs = {} for k, v in kwargs.items(): if isinstance(v, DistributedBatch): split_kwargs[k] = self._align_batches_with_dp(v, rebalance=True) else: - split_kwargs[k] = [v] * len(self.dp_head_workers) + split_kwargs[k] = [v] * self.parallel_strategy.dp_size - # Call ONLY DP head workers with their data slice + # Call all workers. + # ONLY DP head workers get their data slice. + # Other workers will get data by broadcasting in RPC server. async def _call_all(): tasks = [] - for idx, worker in enumerate(self.dp_head_workers): - # Get this worker's slice of each argument - worker_args = [splits[idx] for splits in split_args] - worker_kwargs = {k: splits[idx] for k, splits in split_kwargs.items()} - - # Convert DistributedBatch to dict for RPC - worker_args = [ - arg.get_data() if isinstance(arg, DistributedBatch) else arg - for arg in worker_args - ] - worker_kwargs = { - k: v.get_data() if isinstance(v, DistributedBatch) else v - for k, v in worker_kwargs.items() - } + dp_idx = 0 + for idx, worker in enumerate(self.workers): + if self.workers_is_dp_head[idx]: + # Get this worker's slice of each argument + worker_args = [splits[dp_idx] for splits in split_args] + worker_kwargs = { + k: splits[dp_idx] for k, splits in split_kwargs.items() + } + + # Convert DistributedBatch to dict for RPC + # FIXME: pass metadata instead of real tensors + worker_args = [ + arg.get_data() if isinstance(arg, DistributedBatch) else arg + for arg in worker_args + ] + worker_kwargs = { + k: v.get_data() if isinstance(v, DistributedBatch) else v + for k, v in worker_kwargs.items() + } + dp_idx += 1 + else: + worker_args = [] + worker_kwargs = {} tasks.append( self.scheduler.async_call_engine( - worker_id=worker.id, - method=method, + worker.id, + method, *worker_args, **worker_kwargs, ) @@ -298,77 +289,43 @@ async def _call_all(): return await asyncio.gather(*tasks) results = asyncio.run(_call_all()) + # Only remain data from DP head. + results = [r for idx, r in enumerate(results) if self.workers_is_dp_head[idx]] return self._merge_results(results, method) - def _call_all_workers(self, method: str, *args, **kwargs): - """Call all workers with the same arguments (no data splitting).""" - - async def _call_all(): - tasks = [ - self.scheduler.async_call_engine( - worker_id=worker.id, method=method, *args, **kwargs - ) - for worker in self.workers - ] - return await asyncio.gather(*tasks) - - results = asyncio.run(_call_all()) - return self._merge_results(results, method) - - def _merge_results(self, results, method: str): - """Merge results from workers based on result type. + def _merge_results(self, results, method): + """Merge results from DP head workers based on result type. - - For None: return None - - For dict with scalar values: return first (already synchronized) - - For dict with tensor/batch values: concat as DistributedBatch - - For tensors/lists: concat as DistributedBatch - - For scalars: return first (already synchronized) + - For torch.Tensor: concat results as DistributedBatch + - For others: assume they have been synchronized and return the first """ - # Filter out None results - non_none_results = [r for r in results if r is not None] - - if len(non_none_results) == 0: - return None + first_result = results[0] - first_result = non_none_results[0] - - # If all results are dicts + # FIXME: should use a more general data conversion strategy if isinstance(first_result, dict): - # Check if it's a dict of scalars (like train_batch stats) - if all(isinstance(v, (int, float)) for v in first_result.values()): - # Stats are already synchronized within engines - return first - return first_result - else: - # Dict of tensors/batches - concat as DistributedBatch - return DistributedBatchMemory.concat( - [DistributedBatchMemory.from_dict(r) for r in non_none_results] - ) - - # If result is a tensor or torch.Tensor - elif isinstance(first_result, torch.Tensor): - # Single tensor, likely already reduced - return first - return first_result - - # If result is a list/iterable (but not string) - elif hasattr(first_result, "__iter__") and not isinstance(first_result, str): - try: - # Try to concat as DistributedBatch - return DistributedBatchMemory.concat( - [ - DistributedBatchMemory.from_dict(r) - if isinstance(r, dict) - else r - for r in non_none_results - ] - ) - except Exception: - # If concat fails, return list of results - return non_none_results - - # For scalars (int, float, bool, etc.) - else: - # Return first (already synchronized) - return first_result + if len(first_result) == 0: + return DistributedBatchMemory.from_dict({}) + + k = next(iter(first_result.keys())) + if isinstance(first_result[k], torch.Tensor): + # Check if this looks like a proper batch (has attention_mask) + # If so, use DistributedBatchMemory.concat which handles padding + if "attention_mask" in first_result: + return DistributedBatchMemory.concat( + [DistributedBatchMemory.from_dict(r) for r in results] + ) + else: + # Simple tensor dict - just concatenate tensors along batch dim + merged = {} + for key in first_result.keys(): + if isinstance(first_result[key], torch.Tensor): + merged[key] = torch.cat([r[key] for r in results], dim=0) + else: + merged[key] = first_result[key] + return DistributedBatchMemory.from_dict(merged) + + # Return first (already synchronized) + return first_result def _align_batches_with_dp( self, input_: DistributedBatch, rebalance=True @@ -377,17 +334,77 @@ def _align_batches_with_dp( Returns a list of batches, one for each DP head worker. """ + # Handle empty batch by replicating to all DP groups + if len(input_.get_data()) == 0: + return [input_] * self.alloc_mode.train.dp_size + + # NOTE: group normalization should be done in workflow if rebalance: - inputs = input_.chunk_by_ffd(self.group_size, self.alloc_mode.train.dp_size) + inputs = input_.chunk_by_ffd(1, self.alloc_mode.train.dp_size) else: inputs = input_.chunk(self.alloc_mode.train.dp_size) + return inputs + + def connect_engine(self, rollout: RolloutController, meta: WeightUpdateMeta): + if self.rollout is not None and self.rollout != rollout: + self.logger.warning( + f"Connected rollout controller changed from {self.rollout} to {rollout}." + ) + self.rollout = rollout + + if ( + meta.type == current_platform.communication_backend + and not self.weight_update_group_initialized + ): + self._init_weight_update_from_distributed(meta) + self.weight_update_group_initialized = True + + def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta): + raise NotImplementedError() + + def _update_weights_from_distributed(self, meta: WeightUpdateMeta): + raise NotImplementedError() + + def _update_weights_from_disk(self, meta: WeightUpdateMeta): + # Update all LocalInfEngine's local weight + fut = self.rollout.update_weights_from_disk(meta) + self.save( + SaveLoadMeta( + path=meta.path, + weight_format="hf", + with_optim=False, + tokenizer=None, + processor=None, + ) + ) + update_name = names.update_weights_from_disk( + self.config.experiment_name, + self.config.trial_name, + self.get_version(), + ) + name_resolve.add( + update_name, str(datetime.now().timestamp()), keepalive_ttl=120 + ) - # Return batches corresponding to DP head ranks - batches = [] - for dp_rank in self.engine_dp_ranks: - batches.append(inputs[dp_rank]) + fut.result() - return batches + def _check_rollout_engine_connected(self): + """Validate that rollout engine has been connected via connect_engine().""" + if self.rollout is None: + raise RuntimeError( + "Rollout engine not connected. Call connect_engine()" + " before using rollout/update_weight methods." + ) + + def update_weights(self, meta: WeightUpdateMeta): + self._check_rollout_engine_connected() + if meta.type == current_platform.communication_backend: + assert self.weight_update_group_initialized + self._update_weights_from_distributed(meta) + elif meta.type == "disk": + self._update_weights_from_disk(meta) + else: + raise ValueError(f"Unknown weight update type {meta.type}") # ==================== ENGINE RPC WRAPPERS ==================== def train(self, mode: bool = True): @@ -418,36 +435,6 @@ def eval(self): """ return self.train(False) - def update_weights(self, meta: WeightUpdateMeta): - """Update weights to the inference engine in a blocking manner. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - """ - self.custom_function_call("update_weights", meta) - - def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta): - """Connect to an inference engine for online training. - - Parameters - ---------- - engine : InferenceEngine - The inference engine to connect to - meta : WeightUpdateMeta - Metadata for weight update configuration - - Raises - ------ - NotImplementedError - This method is not implemented for TrainController - """ - raise NotImplementedError( - "connect_engine is not implemented for TrainController. " - "Use RolloutController for online training workflows." - ) - def set_version(self, version: int): """Set the current weight version in the training engine. diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index d90e2217f..1881b90d0 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -177,6 +177,31 @@ async def call_engine_method(request: Request): args = deserialize_value(args) kwargs = deserialize_value(kwargs) + try: + if isinstance(_engine, TrainEngine): + logger.info(f"Broadcasting data for TrainEngine method: {method_name}") + from areal.utils.data import broadcast_tensor_container + + args = broadcast_tensor_container( + args, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + kwargs = broadcast_tensor_container( + kwargs, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + logger.info("Broadcasting data done.") + except Exception as e: + logger.error( + f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=500, + detail=f"Data bcast '{method_name}' failed: {str(e)}", + ) + # Call method directly (no need for hasattr/getattr with typed engine) logger.info(f"Calling engine method: {method_name}") try: diff --git a/areal/tests/test_train_controller.py b/areal/tests/test_train_controller.py new file mode 100644 index 000000000..a89be5181 --- /dev/null +++ b/areal/tests/test_train_controller.py @@ -0,0 +1,871 @@ +"""Unit tests for TrainController. + +Tests cover initialization, worker management, batch operations, +RPC wrappers, PPO/SFT methods, weight management, and error handling. +""" + +import asyncio +from unittest.mock import Mock, patch + +import pytest +import torch + +from areal.api.alloc_mode import ParallelStrategy +from areal.api.cli_args import ScheduleStrategy, SchedulingSpec, TrainEngineConfig +from areal.api.engine_api import TrainEngine +from areal.api.io_struct import ( + AllocationMode, + FinetuneSpec, + SaveLoadMeta, + WeightUpdateMeta, +) +from areal.api.scheduler_api import Worker +from areal.controller.batch import DistributedBatchMemory +from areal.controller.train_controller import TrainController + + +class MockTrainEngine(TrainEngine): + """Mock TrainEngine for testing.""" + + @classmethod + def __module__(cls): + return "areal.tests.test_train_controller" + + @classmethod + def __name__(cls): + return "MockTrainEngine" + + +class MockScheduler: + """Mock Scheduler for testing TrainController.""" + + def __init__(self): + self.workers = [] + self.call_count = 0 + self.engine_calls = [] + self.deleted_roles = [] + + def create_workers(self, job): + """Create mock workers based on job configuration.""" + worker_ids = [f"{job.role}/{i}" for i in range(job.replicas)] + self.workers = [ + Worker( + id=wid, + ip="127.0.0.1", + worker_ports=["8000", "8001"], + engine_ports=["9000", "9001"], + ) + for wid in worker_ids + ] + return worker_ids + + def get_workers(self, role, timeout=None): + """Return list of workers for the given role.""" + return self.workers + + async def create_engine(self, worker_id, engine, config): + """Mock engine creation.""" + await asyncio.sleep(0.001) + return None + + async def async_call_engine(self, worker_id, method, *args, **kwargs): + """Mock async engine call.""" + self.engine_calls.append((worker_id, method, args, kwargs)) + self.call_count += 1 + + # Return appropriate mock results based on method + if method == "is_data_parallel_head": + # First worker in each DP group is the head + worker_idx = int(worker_id.split("/")[-1]) + return worker_idx % 2 == 0 # Every other worker is a DP head + + elif method == "get_version": + return 1 + + elif method == "train_batch": + return {"loss": 0.5, "lr": 0.001, "grad_norm": 1.0} + + elif method == "eval_batch": + return torch.tensor(0.3) + + elif method == "forward": + return {"logits": torch.randn(4, 10, 50257)} + + elif method == "compute_logp": + return {"log_probs": torch.randn(4, 10)} + + elif method == "compute_advantages": + return {"advantages": torch.randn(4)} + + elif method == "ppo_update": + return {"ppo_loss": 0.2, "kl_div": 0.01} + + elif method == "train_lm": + return {"lm_loss": 0.4, "perplexity": 1.5} + + elif method == "evaluate_lm": + return torch.tensor(0.35) + + await asyncio.sleep(0.001) + return None + + def delete_workers(self, role): + """Mock worker deletion.""" + self.deleted_roles.append(role) + self.workers.clear() + + +# ==================== FIXTURES ==================== + + +@pytest.fixture +def mock_scheduler(): + """Provide a MockScheduler instance.""" + return MockScheduler() + + +@pytest.fixture +def train_config(): + """Provide a TrainEngineConfig for testing.""" + return TrainEngineConfig( + scheduling_spec=SchedulingSpec(cpu=4, gpu=1, mem=16000, port_count=2) + ) + + +@pytest.fixture +def alloc_mode(): + """Provide an AllocationMode for testing.""" + mode = AllocationMode.from_str("d4t2p1") + return mode + + +@pytest.fixture +def parallel_strategy(): + """Provide a ParallelStrategy for testing.""" + return ParallelStrategy( + data_parallel_size=4, tensor_parallel_size=2, pipeline_parallel_size=1 + ) + + +@pytest.fixture +def ft_spec(): + """Provide a FinetuneSpec for testing.""" + return FinetuneSpec(total_train_epochs=10, dataset_size=1000, train_batch_size=32) + + +@pytest.fixture +def schedule_strategy(): + """Provide a ScheduleStrategy for testing.""" + return ScheduleStrategy(type="separation", target="") + + +@pytest.fixture +def train_controller(mock_scheduler, train_config): + """Provide a TrainController instance.""" + return TrainController( + train_engine=MockTrainEngine, config=train_config, scheduler=mock_scheduler + ) + + +def create_mock_distributed_batch(size=4, seq_len=10): + """Create a mock DistributedBatch for testing.""" + data = { + "input_ids": torch.randint(0, 100, (size, seq_len)), + "attention_mask": torch.ones(size, seq_len, dtype=torch.bool), + "loss_mask": torch.ones(size, seq_len, dtype=torch.bool), + } + return DistributedBatchMemory.from_dict(data) + + +# ==================== TEST CLASSES ==================== + + +class TestTrainControllerInitialization: + """Tests for TrainController initialization and setup.""" + + def test_constructor(self, mock_scheduler, train_config): + """Test TrainController constructor.""" + controller = TrainController( + train_engine=MockTrainEngine, config=train_config, scheduler=mock_scheduler + ) + + assert controller.train_engine == MockTrainEngine + assert controller.config == train_config + assert controller.scheduler == mock_scheduler + assert controller.workers == [] + assert controller.worker_is_dp_head == [] + assert controller.logger is None + + def test_initialize(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test initialize method creates workers and engines.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + # Verify workers were created + assert len(train_controller.workers) == alloc_mode.train.world_size + assert train_controller._worker_role == "train_worker" + assert train_controller.alloc_mode == alloc_mode + + # Verify DP heads were identified + assert len(train_controller.worker_is_dp_head) == len(train_controller.workers) + + # Verify scheduler was called + assert train_controller.scheduler.call_count > 0 + + def test_create_process_group_sets_parallel_strategy( + self, train_controller, parallel_strategy + ): + """Test that create_process_group correctly assigns parallel_strategy. + + This is a regression test for the bug at line 79 where parallel_strategy + was being assigned to itself instead of the parameter. + """ + # Setup: Add mock workers + train_controller.workers = [Mock(), Mock()] + train_controller.parallel_strategy = parallel_strategy + + # Call create_process_group with a different strategy + new_strategy = ParallelStrategy( + data_parallel_size=8, tensor_parallel_size=1, pipeline_parallel_size=1 + ) + + with patch.object(train_controller, "custom_function_call") as mock_custom_call: + train_controller.create_process_group(new_strategy) + + # Verify the parallel_strategy should be updated to new_strategy + # NOTE: This test currently fails due to the bug at line 79 + # After fixing the bug, this assertion should pass + assert train_controller.parallel_strategy == new_strategy + assert train_controller.parallel_strategy.data_parallel_size == 8 + + # Verify custom_function_call was invoked + mock_custom_call.assert_called_once_with( + "create_process_group", new_strategy + ) + + def test_identify_dp_heads( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test _identify_dp_heads correctly identifies DP head workers.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + # MockScheduler returns True for even-indexed workers + for idx, is_head in enumerate(train_controller.worker_is_dp_head): + assert is_head == (idx % 2 == 0) + + +class TestTrainControllerDestroy: + """Tests for TrainController cleanup and destruction.""" + + def test_destroy(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test destroy method cleans up resources.""" + # Initialize first + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + initial_worker_count = len(train_controller.workers) + assert initial_worker_count > 0 + + # Call destroy + train_controller.destroy() + + # Verify cleanup + assert len(train_controller.workers) == 0 + assert len(train_controller.worker_is_dp_head) == 0 + assert "train_worker" in train_controller.scheduler.deleted_roles + + def test_destroy_handles_errors( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test destroy handles errors gracefully.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + # Make delete_workers raise an exception + def raise_error(role): + raise RuntimeError("Simulated error") + + train_controller.scheduler.delete_workers = raise_error + + # Should not raise, just log the error + train_controller.destroy() + + # Workers should still be cleared + assert len(train_controller.workers) == 0 + + +class TestTrainControllerBatchOperations: + """Tests for batch splitting and alignment operations.""" + + def test_align_batches_with_dp_rebalance( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test _align_batches_with_dp with rebalance=True.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=16) + chunks = train_controller._align_batches_with_dp(batch, rebalance=True) + + # Should split into dp_size chunks + assert len(chunks) == alloc_mode.train.dp_size + + # Each chunk should be a DistributedBatch + for chunk in chunks: + assert isinstance(chunk, DistributedBatchMemory) + + def test_align_batches_with_dp_no_rebalance( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test _align_batches_with_dp with rebalance=False.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=16) + chunks = train_controller._align_batches_with_dp(batch, rebalance=False) + + # Should split into dp_size chunks + assert len(chunks) == alloc_mode.train.dp_size + + # Each chunk should be a DistributedBatch + for chunk in chunks: + assert isinstance(chunk, DistributedBatchMemory) + + +class TestTrainControllerMergeResults: + """Tests for result merging from workers.""" + + def test_merge_results_with_tensor_dict(self, train_controller): + """Test _merge_results with dictionary of tensors.""" + results = [ + {"loss": torch.tensor([0.5, 0.6])}, + {"loss": torch.tensor([0.3, 0.4])}, + ] + + merged = train_controller._merge_results(results, "train_batch") + + # Should concatenate into DistributedBatch + assert isinstance(merged, DistributedBatchMemory) + assert "loss" in merged.get_data() + + def test_merge_results_with_empty_dict(self, train_controller): + """Test _merge_results with empty dictionaries.""" + results = [{}, {}] + + merged = train_controller._merge_results(results, "some_method") + + # Should return empty DistributedBatch + assert isinstance(merged, DistributedBatchMemory) + assert len(merged.get_data()) == 0 + + def test_merge_results_with_non_tensor(self, train_controller): + """Test _merge_results with non-tensor results.""" + results = [{"status": "ok"}, {"status": "ok"}] + + merged = train_controller._merge_results(results, "some_method") + + # Should return first result (already synchronized) + assert merged == {"status": "ok"} + + def test_merge_results_accepts_method_parameter(self, train_controller): + """Test that _merge_results accepts method parameter. + + This is a regression test for the bug at line 279 where the method + parameter was missing from the signature. + """ + results = [torch.tensor(0.5), torch.tensor(0.3)] + + # This should work without TypeError + try: + result = train_controller._merge_results(results, "train_batch") + # Test passes if no exception + assert result is not None + except TypeError as e: + if "missing" in str(e) and "required positional argument" in str(e): + pytest.fail(f"_merge_results missing required parameter: {e}") + + +class TestTrainControllerRPCWrappers: + """Tests for RPC wrapper methods.""" + + def test_train_mode(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test train() method sets training mode.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + result = train_controller.train(mode=True) + + # Should return self for chaining + assert result is train_controller + + # Verify custom_function_call was invoked + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "train" in engine_calls + + def test_eval_mode(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test eval() method sets evaluation mode.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + result = train_controller.eval() + + # Should return self for chaining + assert result is train_controller + + # Verify train(False) was called + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "train" in engine_calls + + def test_forward(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test forward() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + result = train_controller.forward(batch) + + # Should return merged results from DP heads + assert result is not None + + # Verify forward was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "forward" in engine_calls + + def test_train_batch( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test train_batch() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + + def loss_fn(output, batch_data): + return torch.tensor(0.5) + + def loss_weight_fn(batch_data): + return torch.tensor(1.0) + + result = train_controller.train_batch(batch, loss_fn, loss_weight_fn) + + # Should return stats dictionary + assert isinstance(result, dict) + + # Verify train_batch was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "train_batch" in engine_calls + + def test_eval_batch(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test eval_batch() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + + def loss_fn(output, batch_data): + return torch.tensor(0.3) + + def loss_weight_fn(batch_data): + return torch.tensor(1.0) + + result = train_controller.eval_batch(batch, loss_fn, loss_weight_fn) + + # Should return loss tensor or merged results + assert result is not None + + # Verify eval_batch was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "eval_batch" in engine_calls + + def test_step_lr_scheduler( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test step_lr_scheduler() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + train_controller.step_lr_scheduler() + + # Verify step_lr_scheduler was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "step_lr_scheduler" in engine_calls + + +class TestTrainControllerPPOMethods: + """Tests for PPO-specific methods.""" + + def test_compute_logp( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test compute_logp() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + result = train_controller.compute_logp() + + # Should return merged results + assert result is not None + + # Verify compute_logp was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "compute_logp" in engine_calls + + def test_compute_advantages( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test compute_advantages() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + result = train_controller.compute_advantages() + + # Should return merged results + assert result is not None + + # Verify compute_advantages was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "compute_advantages" in engine_calls + + def test_ppo_update(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test ppo_update() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + result = train_controller.ppo_update(batch) + + # Should return stats dictionary + assert isinstance(result, dict) + + # Verify ppo_update was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "ppo_update" in engine_calls + + +class TestTrainControllerSFTMethods: + """Tests for SFT-specific methods.""" + + def test_train_lm(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test train_lm() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + result = train_controller.train_lm(batch) + + # Should return stats dictionary + assert isinstance(result, dict) + + # Verify train_lm was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "train_lm" in engine_calls + + def test_evaluate_lm( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test evaluate_lm() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + result = train_controller.evaluate_lm(batch) + + # Should return loss tensor or merged results + assert result is not None + + # Verify evaluate_lm was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "evaluate_lm" in engine_calls + + +class TestTrainControllerWeightManagement: + """Tests for weight management operations.""" + + def test_set_version( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test set_version() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + train_controller.set_version(42) + + # Verify set_version was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "set_version" in engine_calls + + def test_get_version( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test get_version() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + version = train_controller.get_version() + + # Should return version number + assert isinstance(version, int) + + # Verify get_version was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "get_version" in engine_calls + + def test_update_weights( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test update_weights() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + meta = WeightUpdateMeta(type="disk", path="/tmp/weights") + train_controller.update_weights(meta) + + # Verify update_weights was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "update_weights" in engine_calls + + def test_save(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test save() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + meta = SaveLoadMeta( + path="/tmp/checkpoint", weight_format="safetensors", with_optim=True + ) + train_controller.save(meta) + + # Verify save was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "save" in engine_calls + + def test_load(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + """Test load() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + meta = SaveLoadMeta( + path="/tmp/checkpoint", weight_format="safetensors", with_optim=True + ) + train_controller.load(meta) + + # Verify load was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "load" in engine_calls + + def test_connect_engine_raises_not_implemented(self, train_controller): + """Test connect_engine() raises NotImplementedError.""" + from areal.api.engine_api import InferenceEngine + + mock_engine = Mock(spec=InferenceEngine) + meta = WeightUpdateMeta(type="nccl", alloc_mode=AllocationMode.from_str("d4")) + + with pytest.raises(NotImplementedError) as exc_info: + train_controller.connect_engine(mock_engine, meta) + + assert "not implemented for TrainController" in str(exc_info.value) + assert "RolloutController" in str(exc_info.value) + + +class TestTrainControllerCustomFunctionCall: + """Tests for custom_function_call orchestration.""" + + def test_custom_function_call_with_distributed_batch( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test custom_function_call with DistributedBatch argument.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + # Clear previous calls from initialization + train_controller.scheduler.engine_calls = [] + + batch = create_mock_distributed_batch(size=16) + result = train_controller.custom_function_call("forward", input_=batch) + + # Should split batch across DP groups and call only DP heads + assert result is not None + + # Count how many workers were called + worker_calls = len(train_controller.scheduler.engine_calls) + + # Should call all workers (DP heads get data, others get empty) + assert worker_calls == len(train_controller.workers) + + def test_custom_function_call_with_regular_args( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test custom_function_call with non-DistributedBatch arguments.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + # Clear previous calls + train_controller.scheduler.engine_calls = [] + + result = train_controller.custom_function_call("set_version", 5) + + # set_version returns None, which is expected - just verify it doesn't crash + # The key test is that all workers were called + assert result is None + + # Verify all workers were called + assert len(train_controller.scheduler.engine_calls) == len( + train_controller.workers + ) + + def test_custom_function_call_filters_dp_heads( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test custom_function_call only returns results from DP heads.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + train_controller.custom_function_call("train_batch", input_=batch) + + # Results should only come from DP head workers + # (verified by _merge_results receiving filtered results) + + +class TestTrainControllerEdgeCases: + """Tests for edge cases and error handling.""" + + def test_empty_distributed_batch( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test handling of empty DistributedBatch.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + empty_batch = DistributedBatchMemory.from_dict({}) + result = train_controller.forward(empty_batch) + + # Should handle empty batch gracefully + assert result is not None + + def test_create_process_group_requires_workers( + self, train_controller, parallel_strategy + ): + """Test create_process_group asserts workers exist.""" + # Don't initialize, so workers list is empty + with pytest.raises(AssertionError, match="Workers are not created"): + train_controller.create_process_group(parallel_strategy) + + def test_method_chaining( + self, train_controller, alloc_mode, ft_spec, schedule_strategy + ): + """Test that train() and eval() support method chaining.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + schedule_strategy=schedule_strategy, + ) + + # Should be able to chain calls + result = train_controller.train().eval().train() + assert result is train_controller diff --git a/examples/single-controller/gsm8k_sft.py b/examples/single-controller/gsm8k_sft.py index 066dc576f..88eb4e88f 100644 --- a/examples/single-controller/gsm8k_sft.py +++ b/examples/single-controller/gsm8k_sft.py @@ -2,12 +2,11 @@ from torchdata.stateful_dataloader import StatefulDataLoader -from areal.api.alloc_mode import AllocationMode from areal.api.cli_args import SFTConfig, load_expr_config from areal.api.io_struct import FinetuneSpec, StepInfo from areal.api.scheduler_api import ScheduleStrategy from areal.controller.batch import DistributedBatchMemory -from areal.controller.train_controller import DistributedTrainController +from areal.controller.train_controller import TrainController from areal.dataset import get_custom_dataset from areal.engine.sft.lm_engine import FSDPLMEngine from areal.scheduler.local import LocalScheduler @@ -29,8 +28,6 @@ def main(args): config, _ = load_expr_config(args, SFTConfig) config: SFTConfig - AllocationMode.from_str(config.allocation_mode) - engine = FSDPLMEngine(config=config.model) tokenizer = load_hf_tokenizer(config.tokenizer_path) @@ -80,7 +77,7 @@ def main(args): # Initialize scheduler scheduler = LocalScheduler(config) # Initialize train controller - train_controller = DistributedTrainController(engine, config.model, scheduler) + train_controller = TrainController(engine, config.model, scheduler) train_controller.initialize( config.allocation_mode, ft_spec, From 54ee6fdfdcc10b81923d502b248bf8c4d554a17f Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 20:00:50 +0800 Subject: [PATCH 18/52] renaming --- .../{local_scheduler.py => local.py} | 0 areal/tests/test_local_scheduler.py | 86 +++++++++---------- test.py | 36 ++++++++ 3 files changed, 75 insertions(+), 47 deletions(-) rename areal/scheduler/{local_scheduler.py => local.py} (100%) create mode 100644 test.py diff --git a/areal/scheduler/local_scheduler.py b/areal/scheduler/local.py similarity index 100% rename from areal/scheduler/local_scheduler.py rename to areal/scheduler/local.py diff --git a/areal/tests/test_local_scheduler.py b/areal/tests/test_local_scheduler.py index 8ecdafe25..a06ca170a 100644 --- a/areal/tests/test_local_scheduler.py +++ b/areal/tests/test_local_scheduler.py @@ -25,7 +25,7 @@ WorkerNotFoundError, WorkerTimeoutError, ) -from areal.scheduler.local_scheduler import LocalScheduler, WorkerInfo +from areal.scheduler.local import LocalScheduler, WorkerInfo # ============================================================================ # Fixtures and Helper Functions @@ -258,9 +258,7 @@ class TestPortAllocation: def test_allocate_ports_success(self, tmp_path): """Should allocate requested number of free ports.""" - with patch( - "areal.scheduler.local_scheduler.find_free_ports" - ) as mock_find_ports: + with patch("areal.scheduler.local.find_free_ports") as mock_find_ports: mock_find_ports.return_value = [8000, 8001, 8002] scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) @@ -272,9 +270,7 @@ def test_allocate_ports_success(self, tmp_path): def test_allocate_ports_excludes_already_allocated(self, tmp_path): """Should exclude already allocated ports from search.""" - with patch( - "areal.scheduler.local_scheduler.find_free_ports" - ) as mock_find_ports: + with patch("areal.scheduler.local.find_free_ports") as mock_find_ports: mock_find_ports.side_effect = [ [8000, 8001], [8002, 8003], @@ -298,9 +294,7 @@ def test_allocate_ports_excludes_already_allocated(self, tmp_path): def test_allocate_ports_failure(self, tmp_path): """Should raise PortAllocationError when port allocation fails.""" - with patch( - "areal.scheduler.local_scheduler.find_free_ports" - ) as mock_find_ports: + with patch("areal.scheduler.local.find_free_ports") as mock_find_ports: mock_find_ports.side_effect = ValueError("No free ports available") scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) @@ -314,8 +308,8 @@ def test_allocate_ports_failure(self, tmp_path): class TestWorkerCreation: """Test worker creation with various configurations.""" - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_create_workers_with_default_spec( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path @@ -345,8 +339,8 @@ def test_create_workers_with_default_spec( # Verify default spec was used assert mock_popen.call_count == 2 - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_create_workers_with_single_spec_for_all( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path @@ -380,8 +374,8 @@ def test_create_workers_with_single_spec_for_all( for worker_info in scheduler._workers["actor"]: assert len(worker_info.worker.worker_ports) == 3 - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_create_workers_with_per_worker_specs( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path @@ -415,8 +409,8 @@ def test_create_workers_with_per_worker_specs( assert len(scheduler._workers["critic"][0].worker.worker_ports) == 1 assert len(scheduler._workers["critic"][1].worker.worker_ports) == 2 - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_create_workers_with_custom_command( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path @@ -454,8 +448,8 @@ def test_create_workers_with_custom_command( cmd_args = popen_call[0][0] assert cmd_args == ["python", "my_custom_server.py", "--port", "8000"] - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_create_workers_with_environment_variables( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path @@ -496,8 +490,8 @@ def test_create_workers_with_environment_variables( assert env["CUDA_VISIBLE_DEVICES"] == "0" assert env["WORKER_ID"] == "envtest/0" - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_create_workers_with_colocate_strategy( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path @@ -555,9 +549,9 @@ def test_create_workers_duplicate_role_error(self, tmp_path): scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) with ( - patch("areal.scheduler.local_scheduler.subprocess.Popen") as mock_popen, - patch("areal.scheduler.local_scheduler.find_free_ports") as mock_find_ports, - patch("areal.scheduler.local_scheduler.gethostip") as mock_gethostip, + patch("areal.scheduler.local.subprocess.Popen") as mock_popen, + patch("areal.scheduler.local.find_free_ports") as mock_find_ports, + patch("areal.scheduler.local.gethostip") as mock_gethostip, ): mock_gethostip.return_value = "127.0.0.1" mock_find_ports.return_value = [8000, 8001] @@ -607,8 +601,8 @@ def test_create_workers_invalid_specs_length(self, tmp_path): exc_info.value ) - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_create_workers_subprocess_fails_immediately( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path @@ -640,8 +634,8 @@ def test_create_workers_subprocess_fails_immediately( assert "exited immediately with code 1" in str(exc_info.value) - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_create_workers_cleanup_on_partial_failure( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path @@ -700,7 +694,7 @@ def test_get_workers_role_not_found(self, scheduler): assert exc_info.value.worker_id == "nonexistent" - @patch("areal.scheduler.local_scheduler.time.sleep") + @patch("areal.scheduler.local.time.sleep") def test_get_workers_success(self, mock_sleep, scheduler, tmp_path): """Should return workers when all are ready.""" # Create mock workers @@ -720,8 +714,8 @@ def test_get_workers_success(self, mock_sleep, scheduler, tmp_path): assert workers[0].id == "test/0" assert workers[1].id == "test/1" - @patch("areal.scheduler.local_scheduler.time.time") - @patch("areal.scheduler.local_scheduler.time.sleep") + @patch("areal.scheduler.local.time.time") + @patch("areal.scheduler.local.time.sleep") def test_get_workers_timeout(self, mock_sleep, mock_time, scheduler, tmp_path): """Should raise WorkerTimeoutError when timeout is exceeded.""" # Mock time progression - provide enough values @@ -760,7 +754,7 @@ def test_get_workers_process_died(self, scheduler, tmp_path): assert exc_info.value.worker_id == "test/0" assert exc_info.value.exit_code == 1 - @patch("areal.scheduler.local_scheduler.time.sleep") + @patch("areal.scheduler.local.time.sleep") def test_get_workers_gradual_readiness(self, mock_sleep, scheduler, tmp_path): """Should wait for all workers to become ready gradually.""" worker1 = create_worker_info( @@ -946,8 +940,8 @@ def test_cleanup_workers_handles_errors(self, scheduler, tmp_path): class TestProcessTermination: """Test process termination functionality.""" - @patch("areal.scheduler.local_scheduler.psutil.Process") - @patch("areal.scheduler.local_scheduler.psutil.wait_procs") + @patch("areal.scheduler.local.psutil.Process") + @patch("areal.scheduler.local.psutil.wait_procs") def test_terminate_process_tree_graceful( self, mock_wait_procs, mock_process_class, tmp_path ): @@ -976,8 +970,8 @@ def test_terminate_process_tree_graceful( mock_child2.kill.assert_not_called() mock_parent.kill.assert_not_called() - @patch("areal.scheduler.local_scheduler.psutil.Process") - @patch("areal.scheduler.local_scheduler.psutil.wait_procs") + @patch("areal.scheduler.local.psutil.Process") + @patch("areal.scheduler.local.psutil.wait_procs") def test_terminate_process_tree_force_kill( self, mock_wait_procs, mock_process_class, tmp_path ): @@ -1006,7 +1000,7 @@ def process_side_effect(pid): mock_child.terminate.assert_called_once() mock_child.kill.assert_called_once() - @patch("areal.scheduler.local_scheduler.psutil.Process") + @patch("areal.scheduler.local.psutil.Process") def test_terminate_process_tree_no_such_process(self, mock_process_class, tmp_path): """Should handle gracefully when process doesn't exist.""" mock_process_class.side_effect = psutil.NoSuchProcess(1234) @@ -1016,7 +1010,7 @@ def test_terminate_process_tree_no_such_process(self, mock_process_class, tmp_pa # Should not raise scheduler._terminate_process_tree(1234) - @patch("areal.scheduler.local_scheduler.psutil.Process") + @patch("areal.scheduler.local.psutil.Process") def test_terminate_process_tree_handles_child_no_such_process( self, mock_process_class, tmp_path ): @@ -1276,7 +1270,7 @@ def test_call_engine_method_error(self, scheduler, tmp_path): assert "Method 'nonexistent' not found" in str(exc_info.value) - @patch("areal.scheduler.local_scheduler.time.sleep") + @patch("areal.scheduler.local.time.sleep") def test_call_engine_retry_on_503(self, mock_sleep, scheduler, tmp_path): """Should retry on 503 Service Unavailable.""" worker = create_worker_info(log_file=str(tmp_path / "test.log")) @@ -1298,7 +1292,7 @@ def test_call_engine_retry_on_503(self, mock_sleep, scheduler, tmp_path): assert result == "success" assert mock_sleep.called - @patch("areal.scheduler.local_scheduler.time.sleep") + @patch("areal.scheduler.local.time.sleep") def test_call_engine_max_retries_exhausted(self, mock_sleep, scheduler, tmp_path): """Should raise EngineCallError after max retries.""" worker = create_worker_info(log_file=str(tmp_path / "test.log")) @@ -1315,7 +1309,7 @@ def test_call_engine_max_retries_exhausted(self, mock_sleep, scheduler, tmp_path ) or "Service unavailable" in str(exc_info.value) assert exc_info.value.attempt == 3 - @patch("areal.scheduler.local_scheduler.time.sleep") + @patch("areal.scheduler.local.time.sleep") def test_call_engine_exponential_backoff(self, mock_sleep, scheduler, tmp_path): """Should use exponential backoff for retries.""" worker = create_worker_info(log_file=str(tmp_path / "test.log")) @@ -1464,9 +1458,7 @@ def test_gpu_counter_wraps_correctly(self, tmp_path): def test_port_allocation_accumulates_correctly(self, tmp_path): """Should correctly accumulate allocated ports over multiple allocations.""" - with patch( - "areal.scheduler.local_scheduler.find_free_ports" - ) as mock_find_ports: + with patch("areal.scheduler.local.find_free_ports") as mock_find_ports: mock_find_ports.side_effect = [ [8000, 8001], [8002, 8003], @@ -1489,8 +1481,8 @@ def test_port_allocation_accumulates_correctly(self, tmp_path): 8006, } - @patch("areal.scheduler.local_scheduler.gethostip") - @patch("areal.scheduler.local_scheduler.subprocess.Popen") + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") @patch("areal.scheduler.local_scheduler.find_free_ports") def test_worker_id_format( self, mock_find_ports, mock_popen, mock_gethostip, tmp_path diff --git a/test.py b/test.py new file mode 100644 index 000000000..fb9300535 --- /dev/null +++ b/test.py @@ -0,0 +1,36 @@ +from areal.controller.rollout_controller import RolloutController +from areal.engine.sglang_local import LocalSGLangEngine +from areal.scheduler.local import LocalScheduler +from areal.api.cli_args import InferenceEngineConfig +from areal.api.alloc_mode import AllocationMode +config = InferenceEngineConfig( + experiment_name='test', + trial_name='test', + max_concurrent_rollouts=16, + consumer_batch_size=16, + max_head_offpolicyness=2, + enable_rollout_tracing=True, +) +rollout = RolloutController( + inf_engine=LocalSGLangEngine, + config=config, + scheduler=LocalScheduler(log_dir='./logs/integration') +) +rollout.initialize(alloc_mode=AllocationMode.from_str("sglang.d1")) +from areal.tests.utils import TestWorkflow +workflow = TestWorkflow() +from torchdata.stateful_dataloader import StatefulDataLoader +from datasets import Dataset +import random +dataset= Dataset.from_dict(dict(random=[random.random() for _ in range(16)])) +print(dataset) +dataloader = StatefulDataLoader(dataset=dataset, batch_size=4, collate_fn=lambda x: x) +for data in dataloader: + print(data) +result = rollout.prepare_batch( + dataloader, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + should_accept_path=None, +) +print(result) From b21e45213c6ebbce4e11d43885a81d5be303f2cf Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 20:01:07 +0800 Subject: [PATCH 19/52] . --- test.py | 36 ------------------------------------ 1 file changed, 36 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index fb9300535..000000000 --- a/test.py +++ /dev/null @@ -1,36 +0,0 @@ -from areal.controller.rollout_controller import RolloutController -from areal.engine.sglang_local import LocalSGLangEngine -from areal.scheduler.local import LocalScheduler -from areal.api.cli_args import InferenceEngineConfig -from areal.api.alloc_mode import AllocationMode -config = InferenceEngineConfig( - experiment_name='test', - trial_name='test', - max_concurrent_rollouts=16, - consumer_batch_size=16, - max_head_offpolicyness=2, - enable_rollout_tracing=True, -) -rollout = RolloutController( - inf_engine=LocalSGLangEngine, - config=config, - scheduler=LocalScheduler(log_dir='./logs/integration') -) -rollout.initialize(alloc_mode=AllocationMode.from_str("sglang.d1")) -from areal.tests.utils import TestWorkflow -workflow = TestWorkflow() -from torchdata.stateful_dataloader import StatefulDataLoader -from datasets import Dataset -import random -dataset= Dataset.from_dict(dict(random=[random.random() for _ in range(16)])) -print(dataset) -dataloader = StatefulDataLoader(dataset=dataset, batch_size=4, collate_fn=lambda x: x) -for data in dataloader: - print(data) -result = rollout.prepare_batch( - dataloader, - workflow_path="areal.tests.utils.TestWorkflow", - workflow_kwargs={}, - should_accept_path=None, -) -print(result) From 170cc75217da4e61eacc7992dc065ac2180d6eba Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 20:37:12 +0800 Subject: [PATCH 20/52] update train script --- areal/controller/train_controller.py | 12 +++ areal/engine/sft/lm_engine.py | 30 ++++--- areal/scheduler/local.py | 23 ++++- areal/scheduler/rpc/rpc_server.py | 30 ++++++- examples/single-controller/gsm8k_sft.py | 107 ++++++++---------------- 5 files changed, 114 insertions(+), 88 deletions(-) diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index f2b1dd4b4..c2547eaf3 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -406,6 +406,18 @@ def update_weights(self, meta: WeightUpdateMeta): else: raise ValueError(f"Unknown weight update type {meta.type}") + def export_stats(self): + async def _call_all(): + tasks = [ + self.scheduler.async_call_engine(worker.id, "export_stats") + for worker in self.workers + ] + return await asyncio.gather(*tasks) + + results = asyncio.run(_call_all()) + # stats have been aggregated and synchronized. + return results[0] + # ==================== ENGINE RPC WRAPPERS ==================== def train(self, mode: bool = True): """Set the engine to training mode. diff --git a/areal/engine/sft/lm_engine.py b/areal/engine/sft/lm_engine.py index da3c353e3..181029a69 100644 --- a/areal/engine/sft/lm_engine.py +++ b/areal/engine/sft/lm_engine.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any import torch @@ -14,21 +14,25 @@ class LMEngine: def __init__(self, engine: TrainEngine): self.engine = engine - def train_lm(self, data: Dict[str, Any]): + def train_lm(self, data: dict[str, Any]): self.engine.train() - return self.engine.train_batch( - input_=data, - loss_fn=compute_packed_sft_loss, - loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), - ) + with ( + stats_tracker.scope("sft"), + ): + return self.engine.train_batch( + input_=data, + loss_fn=compute_packed_sft_loss, + loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), + ) def evaluate_lm(self, data): self.engine.eval() - return self.engine.eval_batch( - input_=data, - loss_fn=compute_packed_sft_loss, - loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), - ) + with stats_tracker.scope("sft-eval"): + return self.engine.eval_batch( + input_=data, + loss_fn=compute_packed_sft_loss, + loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), + ) class FSDPLMEngine(FSDPEngine): @@ -56,7 +60,7 @@ def evaluate_lm(self, data): def compute_packed_sft_loss( - logits: torch.Tensor, input_: Dict[str, Any] + logits: torch.Tensor, input_: dict[str, Any] ) -> torch.Tensor: # Use rolled input_ids. Ulysses SP will roll input_ids in ulysses_prepare_inputs(). labels: torch.Tensor = input_.get( diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index d29747a6d..ddae9c072 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -1,5 +1,6 @@ """Local scheduler for managing worker subprocesses on a single GPU node.""" +import getpass import os import shlex import subprocess @@ -65,7 +66,10 @@ class LocalScheduler(Scheduler): def __init__( self, gpu_devices: list[int] | None = None, - log_dir: str = "./logs/workers", + fileroot: str | None = None, + experiment_name: str | None = None, + trial_name: str | None = None, + log_dir: str | None = None, startup_timeout: float = 30.0, health_check_interval: float = 1.0, ): @@ -79,7 +83,19 @@ def __init__( health_check_interval: Interval for health checks (seconds) """ self.gpu_devices = gpu_devices or self._detect_gpus() - self.log_dir = Path(log_dir) + if log_dir is not None: + self.log_dir = Path(log_dir) + else: + assert experiment_name is not None + assert trial_name is not None + assert fileroot is not None + self.log_dir = ( + Path(fileroot) + / "logs" + / getpass.getuser() + / experiment_name + / trial_name + ) self.startup_timeout = startup_timeout self.health_check_interval = health_check_interval @@ -843,6 +859,9 @@ async def async_call_engine( url = f"http://{worker_info.worker.ip}:{port}/run_workflow" # Serialize kwargs for workflow execution payload = serialize_value(kwargs) + elif method == "export_stats": + url = f"http://{worker_info.worker.ip}:{port}/export_stats" + payload = None else: # Standard engine method call url = f"http://{worker_info.worker.ip}:{port}/call" diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index 1881b90d0..078511407 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -16,7 +16,7 @@ from areal.api.engine_api import InferenceEngine, TrainEngine from areal.scheduler.rpc.serialization import deserialize_value, serialize_value -from areal.utils import logging +from areal.utils import logging, stats_tracker logger = logging.getLogger("RPCServer") @@ -375,6 +375,34 @@ async def run_workflow(request: Request): raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") +@app.post("/export_stats") +async def export_stats(request: Request): + try: + body = await request.body() + data = orjson.loads(body) + assert data is None + + global _engine + if isinstance(_engine, TrainEngine): + return { + "status": "success", + "result": stats_tracker.export( + reduce_group=_engine.data_parallel_group + ), + } + else: + assert isinstance(_engine, InferenceEngine) + # Rollout engines do not have collective communication channel. + # Return individual results and reduce in the client side. + return {"status": "success", "result": stats_tracker.export_all()} + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error in run_workflow: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + def main(): """Main entry point for the RPC server.""" parser = argparse.ArgumentParser(description="AReaL Worker RPC Server") diff --git a/examples/single-controller/gsm8k_sft.py b/examples/single-controller/gsm8k_sft.py index 88eb4e88f..1b1c00e45 100644 --- a/examples/single-controller/gsm8k_sft.py +++ b/examples/single-controller/gsm8k_sft.py @@ -1,11 +1,9 @@ import sys -from torchdata.stateful_dataloader import StatefulDataLoader - +from areal.api.alloc_mode import AllocationMode from areal.api.cli_args import SFTConfig, load_expr_config from areal.api.io_struct import FinetuneSpec, StepInfo from areal.api.scheduler_api import ScheduleStrategy -from areal.controller.batch import DistributedBatchMemory from areal.controller.train_controller import TrainController from areal.dataset import get_custom_dataset from areal.engine.sft.lm_engine import FSDPLMEngine @@ -13,8 +11,8 @@ from areal.utils import logging, stats_tracker from areal.utils.data import ( pad_sequences_to_tensors, - tensor_container_to, ) +from areal.utils.dataloader import create_dataloader from areal.utils.evaluator import Evaluator from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.recover import RecoverHandler @@ -28,44 +26,28 @@ def main(args): config, _ = load_expr_config(args, SFTConfig) config: SFTConfig - engine = FSDPLMEngine(config=config.model) - tokenizer = load_hf_tokenizer(config.tokenizer_path) + + # Create dataset and dataloaders train_dataset = get_custom_dataset( - path=config.train_dataset.path, - rank=0, - world_size=1, - split="train", - max_length=config.train_dataset.max_length, - type=config.train_dataset.type, - tokenizer=tokenizer, + split="train", dataset_config=config.train_dataset, tokenizer=tokenizer ) valid_dataset = get_custom_dataset( - path=config.valid_dataset.path, - rank=0, - world_size=1, - split="test", - max_length=config.valid_dataset.max_length, - type=config.valid_dataset.type, - tokenizer=tokenizer, + split="test", dataset_config=config.valid_dataset, tokenizer=tokenizer ) - - # Create dataset and dataloaders - train_dataloader = StatefulDataLoader( + train_dataloader = create_dataloader( train_dataset, - batch_size=config.train_dataset.batch_size, - shuffle=config.train_dataset.shuffle, - num_workers=config.train_dataset.num_workers, + rank=0, + world_size=1, + dataset_config=config.train_dataset, collate_fn=pad_sequences_to_tensors, - drop_last=config.train_dataset.drop_last, ) - valid_dataloader = StatefulDataLoader( + valid_dataloader = create_dataloader( valid_dataset, - batch_size=config.valid_dataset.batch_size, - shuffle=config.valid_dataset.shuffle, - num_workers=config.valid_dataset.num_workers, + rank=0, + world_size=1, + dataset_config=config.valid_dataset, collate_fn=pad_sequences_to_tensors, - drop_last=config.valid_dataset.drop_last, ) # Initialize engine @@ -74,14 +56,20 @@ def main(args): dataset_size=len(train_dataloader) * config.train_dataset.batch_size, train_batch_size=config.train_dataset.batch_size, ) + # Initialize scheduler - scheduler = LocalScheduler(config) + scheduler = LocalScheduler( + fileroot=config.cluster.fileroot, + experiment_name=config.experiment_name, + trial_name=config.trial_name, + ) # Initialize train controller - train_controller = TrainController(engine, config.model, scheduler) - train_controller.initialize( - config.allocation_mode, - ft_spec, - ScheduleStrategy(), + allocation_mode = AllocationMode.from_str(config.allocation_mode) + engine = TrainController(FSDPLMEngine, config=config.model, scheduler=scheduler) + engine.initialize( + alloc_mode=allocation_mode, + ft_spec=ft_spec, + schedule_strategy=ScheduleStrategy(), ) # Run training. @@ -118,22 +106,14 @@ def main(args): steps_per_epoch=len(train_dataloader), ) - with stats_tracker.record_timing("to_device"): - data = tensor_container_to(data, "cpu") - data = DistributedBatchMemory.from_dict(data) - with ( stats_tracker.record_timing("train_step"), - stats_tracker.scope("sft"), ): - stat = train_controller.train_lm(data) - train_controller.step_lr_scheduler() - logger.info(f"train stat: {stat}") + engine.train_lm(data) + engine.step_lr_scheduler() with stats_tracker.record_timing("save"): - saver.save( - train_controller, epoch, step, global_step, tokenizer=tokenizer - ) + saver.save(engine, epoch, step, global_step, tokenizer=tokenizer) with stats_tracker.record_timing("checkpoint_for_recover"): recover_handler.dump( @@ -149,33 +129,16 @@ def main(args): with stats_tracker.record_timing("eval"): def evaluate_fn(): - with stats_tracker.scope("sft-eval"): - for data in valid_dataloader: - data = tensor_container_to(data, "cpu") - data = DistributedBatchMemory.from_dict(data) - train_controller.evaluate_lm(data) - - evaluator.evaluate( - evaluate_fn, - epoch, - step, - global_step, - ) + for data in valid_dataloader: + engine.evaluate_lm(data) - stats = list() - # todo: gather stats from all ranks - stats.append(stat) - stats.append(stats_tracker.export_all()) - stats_logger.commit( - epoch, - step, - global_step, - stats, - ) + evaluator.evaluate(evaluate_fn, epoch, step, global_step) + + stats_logger.commit(epoch, step, global_step, engine.export_stats()) global_step += 1 stats_logger.close() - train_controller.destroy() + engine.destroy() if __name__ == "__main__": From 157b0b09536c865ec431c9ed286e98828194706a Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 20:54:59 +0800 Subject: [PATCH 21/52] implement rollout stats --- areal/controller/rollout_controller.py | 28 ++++++++++++++++++++++++-- areal/scheduler/rpc/rpc_server.py | 12 ++++++++--- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index 31101643e..680fa9e4d 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -1,5 +1,3 @@ -"""RolloutController implementation using LocalScheduler and RPC workers.""" - from __future__ import annotations import asyncio @@ -751,6 +749,32 @@ def resume(self): except Exception as e: self.logger.error(f"Error resuming worker {worker.id}: {e}") + def export_stats(self): + async def _call_all(): + tasks = [ + self.scheduler.async_call_engine( + worker=worker, + method="export_stats", + ) + for worker in self.workers + ] + return await asyncio.gather(*tasks) + + # Stats + all_raw_stats = asyncio.run(_call_all()) + stats = {} + exported = set() + for raw_stats in all_raw_stats: + for k in raw_stats: + if k in exported: + continue + data = sum([s[1].get(k, []) for s in all_raw_stats], []) + if len(data) == 0: + continue + stats[k] = sum(data) / len(data) + exported.add(k) + return stats + def register_callback_to_all_worker( self, method: str, callback: Callable, **kwargs ): diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index 078511407..2b590e433 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -392,9 +392,15 @@ async def export_stats(request: Request): } else: assert isinstance(_engine, InferenceEngine) - # Rollout engines do not have collective communication channel. - # Return individual results and reduce in the client side. - return {"status": "success", "result": stats_tracker.export_all()} + # Rollout engines do not have the collective communication channel. + # Return individual results and reduce them in the client side. + raw_stats = {} + for name, tracker in stats_tracker.TRACKERS.items(): + s = {name.strip("/") + k: v for k, v in tracker.stats.items()} + raw_stats.update(s) + # clear stats tracker + stats_tracker.export_all() + return {"status": "success", "result": raw_stats} except HTTPException: raise From 7475004e6f0bb5deb19832c308cef819676f4376 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Wed, 29 Oct 2025 21:57:23 +0800 Subject: [PATCH 22/52] . --- areal/api/cli_args.py | 10 ++-- areal/controller/train_controller.py | 61 +++++++++++++---------- areal/scheduler/local.py | 32 ++++++++---- areal/scheduler/rpc/rpc_server.py | 3 +- docs/cli_reference.md | 8 +-- examples/single-controller/gsm8k_sft.py | 1 + examples/single-controller/gsm8k_sft.yaml | 24 ++++----- 7 files changed, 81 insertions(+), 58 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 18aa2cb52..30d2cf075 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -3,7 +3,7 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any, Literal +from typing import Any import uvloop import yaml @@ -314,8 +314,12 @@ class MegatronEngineConfig: @dataclass class ScheduleStrategy: - type: Literal["colocation", "separation"] = "separation" - target: str = "" + type: str = field( + default="separation", metadata={"choices": ["separation", "colocation"]} + ) + target: str | None = field( + default=None, metadata={"help": "The target role to be colocated with"} + ) @dataclass diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index c2547eaf3..c7196b47e 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -1,5 +1,6 @@ import asyncio from collections.abc import Callable +from copy import deepcopy from datetime import datetime from typing import Any @@ -20,6 +21,7 @@ from areal.controller.rollout_controller import RolloutController from areal.platforms import current_platform from areal.utils import logging, name_resolve, names +from areal.utils.network import find_free_ports logger = logging.getLogger("TrainController") @@ -74,18 +76,8 @@ def __init__( self.logger = None def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): - """Initialize PyTorch distributed communication groups. - - Parameters - ---------- - parallel_strategy : ParallelStrategy, optional - The parallel strategy configuration for distributed training, by default None - """ - assert len(self.workers) > 0, "Workers are not created" - if parallel_strategy is None: - parallel_strategy = ParallelStrategy() - self.parallel_strategy = parallel_strategy - self.custom_function_call("create_process_group", parallel_strategy) + # A dummy method. Process group will be created during `initialize` + pass def initialize( self, @@ -118,23 +110,30 @@ def initialize( self._worker_role = role self.alloc_mode = alloc_mode - # Initialize parallel_strategy from alloc_mode if not already set - if self.parallel_strategy is None: - self.parallel_strategy = ParallelStrategy( - data_parallel_size=alloc_mode.train.dp_size, - tensor_parallel_size=alloc_mode.train.tp_size, - pipeline_parallel_size=alloc_mode.train.pp_size, - ) + if alloc_mode.gen_backend == "sglang": + self.config.scheduling_spec.env_vars["NCCL_CUMEM_ENABLE"] = "0" + self.config.scheduling_spec.env_vars["NCCL_NVLS_ENABLE"] = "0" + + self.parallel_strategy = alloc_mode.train # Create job for scheduler job = Job( replicas=alloc_mode.train.world_size, tasks=[ - self.config.scheduling_spec for _ in range(alloc_mode.train.world_size) + deepcopy(self.config.scheduling_spec) + for _ in range(alloc_mode.train.world_size) ], schedule_strategy=schedule_strategy, role=self._worker_role, ) + # Create environment variables to mimic torchrun + for i, task in enumerate(job.tasks): + task.env_vars["RANK"] = str(i) + task.env_vars["WORLD_SIZE"] = str(alloc_mode.train.world_size) + task.env_vars["LOCAL_RANK"] = str(i) + # TODO: find a real master addr with scheduler + task.env_vars["MASTER_ADDR"] = "localhost" + task.env_vars["MASTER_PORT"] = str(find_free_ports(1)[0]) # Create workers via scheduler self.logger.info("Creating workers via scheduler...") @@ -151,19 +150,15 @@ def initialize( engine_path = f"{engine_class.__module__}.{engine_class.__name__}" # Create and initialize engines on workers - asyncio.run( - self._async_create_and_initialize_engines(engine_path, ft_spec, **kwargs) - ) + asyncio.run(self._async_create_engines(engine_path)) + asyncio.run(self._async_initialize_engines(ft_spec, **kwargs)) # Identify DP head workers self._identify_dp_heads() self.logger.info("TrainController initialization complete") - async def _async_create_and_initialize_engines( - self, engine_path: str, ft_spec: FinetuneSpec, **kwargs - ): - """Create and initialize engines on all workers.""" + async def _async_create_engines(self, engine_path: str): # Create engines on workers self.logger.info("Creating engines on workers...") tasks = [ @@ -177,13 +172,25 @@ async def _async_create_and_initialize_engines( await asyncio.gather(*tasks) self.logger.info("Engines created on all workers!") + async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): # Initialize engines self.logger.info("Calling engine initialization...") + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="create_process_group", + parallel_strategy=self.parallel_strategy, + _should_bcast=False, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) tasks = [ self.scheduler.async_call_engine( worker_id=worker.id, method="initialize", ft_spec=ft_spec, + _should_bcast=False, **kwargs, ) for worker in self.workers diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index ddae9c072..fe0241dd1 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -15,6 +15,7 @@ import psutil from areal.api.scheduler_api import Job, Scheduler, SchedulingSpec, Worker +from areal.platforms import current_platform from areal.scheduler.exceptions import ( EngineCallError, EngineCreationError, @@ -30,6 +31,9 @@ ) from areal.scheduler.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging +from areal.utils.launcher import ( + get_env_vars, +) from areal.utils.network import find_free_ports, gethostip logger = logging.getLogger("LocalScheduler") @@ -69,6 +73,7 @@ def __init__( fileroot: str | None = None, experiment_name: str | None = None, trial_name: str | None = None, + cluster_name: str | None = None, log_dir: str | None = None, startup_timeout: float = 30.0, health_check_interval: float = 1.0, @@ -96,6 +101,7 @@ def __init__( / experiment_name / trial_name ) + self.cluster_name = cluster_name self.startup_timeout = startup_timeout self.health_check_interval = health_check_interval @@ -122,13 +128,13 @@ def __init__( def _detect_gpus(self) -> list[int]: """Detect available GPU devices.""" - cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + cuda_visible = os.environ.get(current_platform.device_control_env_var) if cuda_visible: try: return [int(x) for x in cuda_visible.split(",")] except ValueError: logger.warning( - f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using default [0]" + f"Invalid {current_platform.device_control_env_var}: {cuda_visible}, using default [0]" ) return [0] # Default to single GPU @@ -335,9 +341,13 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: ) from e # Prepare environment - env = os.environ.copy() - env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_devices)) - env["WORKER_ID"] = worker_id + env = get_env_vars( + self.cluster_name, + ",".join([f"{k}={v}" for k, v in scheduling.env_vars.items()]), + ) + env[current_platform.device_control_env_var] = ",".join( + map(str, gpu_devices) + ) # Merge user-provided environment variables from scheduling if scheduling.env_vars: @@ -364,7 +374,7 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: if args: cmd.extend(args) - logger.debug(f"Starting worker {worker_id}: {' '.join(cmd)}") + logger.info(f"Starting worker {worker_id}: {' '.join(cmd)}") # Spawn subprocess try: @@ -603,8 +613,12 @@ def _terminate_process_tree(self, pid: int): except psutil.NoSuchProcess: # Process already gone pass - except Exception as e: - logger.warning(f"Error terminating process tree {pid}: {e}") + except Exception: + import traceback + + logger.warning( + f"Error terminating process tree {pid}: {traceback.print_exec()}" + ) def _read_log_tail(self, log_file: str, lines: int = 50) -> str: """Read the last N lines from a log file.""" @@ -666,7 +680,7 @@ async def create_engine( try: logger.info(f"Creating engine '{engine}' on worker '{worker_id}'") - response = self._http_client.post( + response = await self._async_http_client.post( url, content=orjson.dumps(payload), headers={"Content-Type": "application/json"}, diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index 2b590e433..e336827b7 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -178,7 +178,8 @@ async def call_engine_method(request: Request): kwargs = deserialize_value(kwargs) try: - if isinstance(_engine, TrainEngine): + should_bcast = kwargs.pop("_should_bcast", True) + if isinstance(_engine, TrainEngine) and should_bcast: logger.info(f"Broadcasting data for TrainEngine method: {method_name}") from areal.utils.data import broadcast_tensor_container diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 69189abf2..22b66b775 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -813,10 +813,10 @@ Refer to Megatron-LM documentation for implementation details. Configuration class: ScheduleStrategy -| Parameter | Type | Default | Description | -| --------- | --------- | -------------- | ----------- | -| `type` | `Literal` | `"separation"` | - | -| `target` | string | `""` | - | +| Parameter | Type | Default | Description | +| --------- | -------------- | -------------- | ----------------------------------------- | +| `type` | string | `"separation"` | - **Choices:** `separation`, `colocation` | +| `target` | string \| None | `None` | The target role to be colocated with | (section-scheduler)= diff --git a/examples/single-controller/gsm8k_sft.py b/examples/single-controller/gsm8k_sft.py index 1b1c00e45..b7906750f 100644 --- a/examples/single-controller/gsm8k_sft.py +++ b/examples/single-controller/gsm8k_sft.py @@ -67,6 +67,7 @@ def main(args): allocation_mode = AllocationMode.from_str(config.allocation_mode) engine = TrainController(FSDPLMEngine, config=config.model, scheduler=scheduler) engine.initialize( + role="default", alloc_mode=allocation_mode, ft_spec=ft_spec, schedule_strategy=ScheduleStrategy(), diff --git a/examples/single-controller/gsm8k_sft.yaml b/examples/single-controller/gsm8k_sft.yaml index f28d65f2a..936435ddd 100644 --- a/examples/single-controller/gsm8k_sft.yaml +++ b/examples/single-controller/gsm8k_sft.yaml @@ -8,17 +8,17 @@ tokenizer_path: ${model.path} cluster: n_nodes: 1 n_gpus_per_node: 8 - fileroot: /tmp/areal/experiments + fileroot: /storage/openpsi/experiments name_resolve: type: nfs - nfs_record_root: /tmp/areal/name_resolve + nfs_record_root: /storage/openpsi/name_resolve allocation_mode: d8p1t1 model: experiment_name: ${experiment_name} trial_name: ${trial_name} - path: Qwen/Qwen3-1.7B + path: /storage/openpsi/models/Qwen__Qwen3-1.7B init_from_scratch: false gradient_checkpointing: false dtype: bfloat16 @@ -34,18 +34,20 @@ model: lr_scheduler_type: cosine gradient_clipping: 1.0 backend: fsdp - scheduling_specs: - - type: worker + scheduling_spec: + type: worker port_count: 1 gpu: 1 - cmd: python3 -m areal.scheduler.rpc.rpc_server + # AReaL will by default uses `python3 -m areal.scheduler.rpc.rpc_server --port {PORT}` + # where ${PORT} is dynamically allocated + # cmd: python3 -m areal.scheduler.rpc.rpc_server train_dataset: batch_size: 128 shuffle: true pin_memory: true num_workers: 4 - path: openai/gsm8k + path: /storage/openpsi/data/gsm8k type: sft valid_dataset: @@ -53,7 +55,7 @@ valid_dataset: shuffle: true pin_memory: true num_workers: 4 - path: openai/gsm8k + path: /storage/openpsi/data/gsm8k type: sft # Utilities @@ -88,9 +90,3 @@ stats_logger: fileroot: ${cluster.fileroot} wandb: mode: disabled - -launcher: - inference_server_cpus_per_gpu: 4 - inference_server_mem_per_gpu: 32768 - trainer_cpus_per_gpu: 4 - trainer_mem_per_gpu: 32768 From 6e54a589b0b6f0266549fed5b468b51e912f3b82 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 22:07:37 +0800 Subject: [PATCH 23/52] fix --- areal/scheduler/rpc/rpc_server.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index e336827b7..bee6dddbe 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -8,10 +8,11 @@ import importlib import traceback from contextlib import asynccontextmanager +from typing import Any import orjson import uvicorn -from fastapi import FastAPI, HTTPException, Request +from fastapi import Body, FastAPI, HTTPException, Request from fastapi.responses import ORJSONResponse from areal.api.engine_api import InferenceEngine, TrainEngine @@ -60,7 +61,7 @@ async def health_check(): @app.post("/create_engine") -async def create_engine(request: Request): +def create_engine(data: dict[str, Any] = Body(...)): """ Create and initialize an engine instance on this worker. @@ -74,9 +75,6 @@ async def create_engine(request: Request): global _engine try: - body = await request.body() - data = orjson.loads(body) - engine_path = data.get("engine") # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) init_args = deserialize_value(data.get("init_args", [])) @@ -141,7 +139,7 @@ async def create_engine(request: Request): @app.post("/call") -async def call_engine_method(request: Request): +def call_engine_method(data: dict[str, Any] = Body(...)): """ Call a method on the engine instance. @@ -161,9 +159,6 @@ async def call_engine_method(request: Request): ) try: - body = await request.body() - data = orjson.loads(body) - method_name = data.get("method") args = data.get("args", []) kwargs = data.get("kwargs", {}) @@ -377,10 +372,8 @@ async def run_workflow(request: Request): @app.post("/export_stats") -async def export_stats(request: Request): +def export_stats(data: dict[str, Any] | None = Body(None)): try: - body = await request.body() - data = orjson.loads(body) assert data is None global _engine From deee027d5ea3d5dbf26f09dacada07edbd45c912 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 22:40:47 +0800 Subject: [PATCH 24/52] add sync rpc server --- areal/api/cli_args.py | 8 +- areal/scheduler/local.py | 29 +- areal/scheduler/rpc/async_rpc_server.py | 439 ++++++++++++++++++++++++ areal/scheduler/rpc/rpc_server.py | 68 +--- areal/scheduler/rpc/sync_rpc_server.py | 249 ++++++++++++++ 5 files changed, 721 insertions(+), 72 deletions(-) create mode 100644 areal/scheduler/rpc/async_rpc_server.py create mode 100644 areal/scheduler/rpc/sync_rpc_server.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 30d2cf075..51ae0c492 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -432,7 +432,9 @@ class TrainEngineConfig: metadata={"help": "peft method type. Only LoRA is supported for now."}, ) scheduling_spec: SchedulingSpec = field( - default_factory=SchedulingSpec, + default_factory=lambda: SchedulingSpec( + cmd="python -m areal.scheduler.rpc.sync_rpc_server" + ), metadata={"help": "train engine schedule specs"}, ) scheduling_strategy: ScheduleStrategy = field(default_factory=ScheduleStrategy) @@ -905,7 +907,9 @@ class InferenceEngineConfig: }, ) scheduling_spec: SchedulingSpec = field( - default_factory=SchedulingSpec, + default_factory=lambda: SchedulingSpec( + cmd="python -m areal.scheduler.rpc.async_rpc_server" + ), metadata={"help": "inference engine schedule specs"}, ) scheduling_strategy: ScheduleStrategy = field(default_factory=ScheduleStrategy) diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index fe0241dd1..a80f3987c 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -4,7 +4,6 @@ import os import shlex import subprocess -import sys import time from dataclasses import dataclass, field from pathlib import Path @@ -357,22 +356,18 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: log_file = self.log_dir / f"{worker_id.replace('/', '_')}.log" # Build command to start RPC server - if scheduling.cmd: - # Use custom command from scheduling - cmd = shlex.split(scheduling.cmd) - else: - # Default: start RPC server - cmd = [ - sys.executable, - "-m", - "areal.scheduler.rpc.rpc_server", - "--port", - str(ports[0]), # Main RPC port - ] - - # Add any additional arguments - if args: - cmd.extend(args) + if not scheduling.cmd: + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"SchedulingSpec.cmd is required but not set for worker {worker_id}", + "Specify either 'python -m areal.scheduler.rpc.async_rpc_server' or " + "'python -m areal.scheduler.rpc.sync_rpc_server' in your config.", + ) + + cmd = shlex.split(scheduling.cmd) + # Append --port argument to command + cmd.extend(["--port", str(ports[0])]) logger.info(f"Starting worker {worker_id}: {' '.join(cmd)}") diff --git a/areal/scheduler/rpc/async_rpc_server.py b/areal/scheduler/rpc/async_rpc_server.py new file mode 100644 index 000000000..576362807 --- /dev/null +++ b/areal/scheduler/rpc/async_rpc_server.py @@ -0,0 +1,439 @@ +"""Async FastAPI-based RPC server for InferenceEngine workers. + +This server runs on worker nodes to expose InferenceEngine methods via HTTP/JSON RPC. +It uses safe JSON serialization instead of cloudpickle and supports async workflow +execution via the /run_workflow endpoint. + +Key differences from sync_rpc_server: +- Multi-threaded: Uses FastAPI/uvicorn with async support +- InferenceEngine: Primarily for InferenceEngine (async rollout generation) +- Has /run_workflow: Supports direct workflow execution +- All async endpoints: All HTTP handlers are async functions +""" + +import argparse +import importlib +import traceback +from contextlib import asynccontextmanager +from typing import Any + +import orjson +import uvicorn +from fastapi import Body, FastAPI, HTTPException, Request +from fastapi.responses import ORJSONResponse + +from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.scheduler.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging, stats_tracker + +logger = logging.getLogger("RPCServer") + +# Global engine instance - must be TrainEngine or InferenceEngine +_engine: TrainEngine | InferenceEngine | None = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for startup and shutdown events.""" + # Startup + logger.info("RPC server starting up...") + yield + # Shutdown + global _engine + logger.info("Shutting down RPC server...") + if _engine is not None: + try: + # Call destroy method if available + if hasattr(_engine, "destroy"): + _engine.destroy() + logger.info("Engine destroyed successfully") + except Exception as e: + logger.error(f"Error destroying engine: {e}") + _engine = None + + +app = FastAPI( + title="AReaL Worker RPC Server", + description="FastAPI-based RPC server for remote engine operations", + default_response_class=ORJSONResponse, + lifespan=lifespan, +) +app._expected_trajectory_keys = None + + +@app.get("/health") +async def health_check(): + """Health check endpoint to verify server is alive.""" + return {"status": "healthy", "engine_initialized": _engine is not None} + + +@app.post("/create_engine") +async def create_engine(data: dict[str, Any] = Body(...)): + """ + Create and initialize an engine instance on this worker. + + Expected JSON payload: + { + "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path + "init_args": [...], # Positional arguments + "init_kwargs": {...} # Keyword arguments + } + """ + global _engine + + try: + engine_path = data.get("engine") + # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) + init_args = deserialize_value(data.get("init_args", [])) + init_kwargs = deserialize_value(data.get("init_kwargs", {})) + + if not engine_path: + raise HTTPException( + status_code=400, detail="Missing 'engine' field in request" + ) + + # Dynamic import + try: + module_path, class_name = engine_path.rsplit(".", 1) + module = importlib.import_module(module_path) + engine_class = getattr(module, class_name) + + # Validate that the class is a TrainEngine or InferenceEngine + if not ( + issubclass(engine_class, TrainEngine) + or issubclass(engine_class, InferenceEngine) + ): + raise TypeError( + f"Engine class must be a subclass of TrainEngine or InferenceEngine, " + f"got {engine_class}" + ) + except (ValueError, ImportError, AttributeError) as e: + logger.error(f"Failed to import engine '{engine_path}': {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to import engine '{engine_path}': {str(e)}", + ) + except TypeError as e: + logger.error(f"Invalid engine type: {e}") + raise HTTPException( + status_code=400, + detail=str(e), + ) + + # Instantiate engine + try: + _engine = engine_class(*init_args, **init_kwargs) + logger.info(f"Engine '{engine_path}' instantiated successfully") + return { + "status": "success", + "message": f"Engine '{engine_path}' created and initialized", + "result": None, + } + except Exception as e: + logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") + raise HTTPException( + status_code=500, + detail=f"Failed to instantiate engine: {str(e)}", + ) + + except HTTPException: + raise + except Exception as e: + logger.error( + f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" + ) + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +@app.post("/call") +async def call_engine_method(data: dict[str, Any] = Body(...)): + """ + Call a method on the engine instance. + + Expected JSON payload: + { + "method": "train_batch", + "args": [...], + "kwargs": {...} + } + """ + global _engine + + if _engine is None: + raise HTTPException( + status_code=503, + detail="Engine not initialized. Call /create_engine first.", + ) + + try: + method_name = data.get("method") + args = data.get("args", []) + kwargs = data.get("kwargs", {}) + + if not method_name: + raise HTTPException( + status_code=400, detail="Missing 'method' field in request" + ) + + # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) + args = deserialize_value(args) + kwargs = deserialize_value(kwargs) + + try: + should_bcast = kwargs.pop("_should_bcast", True) + if isinstance(_engine, TrainEngine) and should_bcast: + logger.info(f"Broadcasting data for TrainEngine method: {method_name}") + from areal.utils.data import broadcast_tensor_container + + args = broadcast_tensor_container( + args, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + kwargs = broadcast_tensor_container( + kwargs, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + logger.info("Broadcasting data done.") + except Exception as e: + logger.error( + f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=500, + detail=f"Data bcast '{method_name}' failed: {str(e)}", + ) + + # Call method directly (no need for hasattr/getattr with typed engine) + logger.info(f"Calling engine method: {method_name}") + try: + # Get the method - will raise AttributeError if it doesn't exist + method = getattr(_engine, method_name) + result = method(*args, **kwargs) + + # Serialize result (convert tensors to SerializedTensor dicts) + serialized_result = serialize_value(result) + return {"status": "success", "result": serialized_result} + + except AttributeError as e: + logger.error(f"Method '{method_name}' not found on engine: {e}") + raise HTTPException( + status_code=400, + detail=f"Engine does not have method '{method_name}'", + ) + except Exception as e: + logger.error( + f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=500, + detail=f"Engine method '{method_name}' failed: {str(e)}", + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +@app.post("/run_workflow") +async def run_workflow(request: Request): + """ + Run a workflow's arun_episode method directly without using the engine. + + Expected JSON payload: + { + "workflow": "areal.api.workflow_api.RolloutWorkflow", # Import path + "workflow_kwargs": {...}, # Keyword arguments for workflow instantiation + "data": {...} # Data to pass to arun_episode + } + """ + try: + body = await request.body() + data = orjson.loads(body) + + workflow_path = data.get("workflow") + workflow_kwargs = data.get("workflow_kwargs") + episode_data = data.get("data") + should_accept_path = data.get("should_accept_path", None) + check_trajectory_format = data.get("check_trajectory_format") + + if not workflow_path: + raise HTTPException( + status_code=400, detail="Missing 'workflow' field in request" + ) + + if episode_data is None: + raise HTTPException( + status_code=400, detail="Missing 'data' field in request" + ) + + # Deserialize episode_data (may contain tensors) + episode_data = deserialize_value(episode_data) + + # Dynamic import workflow + try: + module_path, class_name = workflow_path.rsplit(".", 1) + module = importlib.import_module(module_path) + workflow_class = getattr(module, class_name) + logger.info(f"Imported workflow class: {workflow_path}") + except (ValueError, ImportError, AttributeError) as e: + logger.error(f"Failed to import workflow '{workflow_path}': {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to import workflow '{workflow_path}': {str(e)}", + ) + # Instantiate workflow + try: + workflow = workflow_class(**workflow_kwargs) + logger.info(f"Workflow '{workflow_path}' instantiated successfully") + except Exception as e: + logger.error( + f"Failed to instantiate workflow: {e}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=500, + detail=f"Failed to instantiate workflow: {str(e)}", + ) + + should_accept = None + if should_accept_path is not None: + # Dynamic import filtering function + try: + module_path, fn_name = should_accept_path.rsplit(".", 1) + module = importlib.import_module(module_path) + should_accept = getattr(module, fn_name) + logger.info(f"Imported filtering function: {should_accept_path}") + except (ValueError, ImportError, AttributeError) as e: + logger.error( + f"Failed to import filtering function '{should_accept_path}': {e}" + ) + raise HTTPException( + status_code=400, + detail=f"Failed to import filtering function '{should_accept_path}': {str(e)}", + ) + + # Run episode + try: + global _engine + traj = await workflow.arun_episode(_engine, episode_data) + + global app + if check_trajectory_format and traj is not None: + from areal.core.workflow_executor import ( + check_trajectory_format as check_fn, + ) + + check_fn( + traj, + expected_keys=app._expected_trajectory_keys, + logger=logger, + ) + # Track expected keys for consistency checking + if isinstance(traj, dict) and "input_ids" in traj: + if app._expected_trajectory_keys is None: + app._expected_trajectory_keys = set(traj.keys()) + logger.info( + f"Trajectory format check: tracking keys " + f"{app._expected_trajectory_keys}" + ) + + from areal.experimental.openai.types import InteractionWithTokenLogpReward + from areal.utils.data import concat_padded_tensors + + # Convert InteractionWithTokenLogpReward to tensor dict if needed + if isinstance(traj, dict) and all( + isinstance(v, InteractionWithTokenLogpReward) for v in traj.values() + ): + traj = concat_padded_tensors( + [v.to_tensor_dict() for v in traj.values()] + ) + + assert traj is None or isinstance(traj, dict), traj + + # Apply should_accept filtering + accept_this = traj is not None and ( + should_accept is None or should_accept(traj) + ) + + # Serialize trajectory result (convert tensors to SerializedTensor dicts) + if accept_this: + serialized_traj = serialize_value(traj) + return {"status": "success", "result": serialized_traj} + else: + return {"status": "success", "result": None} + except Exception as e: + logger.error(f"Workflow arun_episode failed: {e}\n{traceback.format_exc()}") + raise HTTPException( + status_code=500, + detail=f"Workflow arun_episode failed: {str(e)}", + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error in run_workflow: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +@app.post("/export_stats") +async def export_stats(data: dict[str, Any] | None = Body(None)): + try: + assert data is None + + global _engine + if isinstance(_engine, TrainEngine): + return { + "status": "success", + "result": stats_tracker.export( + reduce_group=_engine.data_parallel_group + ), + } + else: + assert isinstance(_engine, InferenceEngine) + # Rollout engines do not have the collective communication channel. + # Return individual results and reduce them in the client side. + raw_stats = {} + for name, tracker in stats_tracker.TRACKERS.items(): + s = {name.strip("/") + k: v for k, v in tracker.stats.items()} + raw_stats.update(s) + # clear stats tracker + stats_tracker.export_all() + return {"status": "success", "result": raw_stats} + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error in export_stats: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +def main(): + """Main entry point for the async RPC server.""" + parser = argparse.ArgumentParser( + description="AReaL Async RPC Server for InferenceEngine" + ) + parser.add_argument("--port", type=int, required=True, help="Port to serve on") + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" + ) + + args, _ = parser.parse_known_args() + + logger.info(f"Starting async RPC server on {args.host}:{args.port}") + + # Run uvicorn server with a single worker (required for GPU workloads) + uvicorn.run( + app, + host=args.host, + port=args.port, + workers=1, # Single worker required for GPU memory management + log_level="info", + access_log=True, + ) + + +if __name__ == "__main__": + main() diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index bee6dddbe..08e081601 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -15,14 +15,14 @@ from fastapi import Body, FastAPI, HTTPException, Request from fastapi.responses import ORJSONResponse -from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.api.engine_api import InferenceEngine from areal.scheduler.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging, stats_tracker logger = logging.getLogger("RPCServer") -# Global engine instance - must be TrainEngine or InferenceEngine -_engine: TrainEngine | InferenceEngine | None = None +# Global engine instance - must be InferenceEngine +_engine: InferenceEngine | None = None @asynccontextmanager @@ -91,13 +91,9 @@ def create_engine(data: dict[str, Any] = Body(...)): module = importlib.import_module(module_path) engine_class = getattr(module, class_name) - # Validate that the class is a TrainEngine or InferenceEngine - if not ( - issubclass(engine_class, TrainEngine) - or issubclass(engine_class, InferenceEngine) - ): + if not (issubclass(engine_class, InferenceEngine)): raise TypeError( - f"Engine class must be a subclass of TrainEngine or InferenceEngine, " + f"Engine class must be a subclass of InferenceEngine, " f"got {engine_class}" ) except (ValueError, ImportError, AttributeError) as e: @@ -172,32 +168,6 @@ def call_engine_method(data: dict[str, Any] = Body(...)): args = deserialize_value(args) kwargs = deserialize_value(kwargs) - try: - should_bcast = kwargs.pop("_should_bcast", True) - if isinstance(_engine, TrainEngine) and should_bcast: - logger.info(f"Broadcasting data for TrainEngine method: {method_name}") - from areal.utils.data import broadcast_tensor_container - - args = broadcast_tensor_container( - args, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, - ) - kwargs = broadcast_tensor_container( - kwargs, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, - ) - logger.info("Broadcasting data done.") - except Exception as e: - logger.error( - f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" - ) - raise HTTPException( - status_code=500, - detail=f"Data bcast '{method_name}' failed: {str(e)}", - ) - # Call method directly (no need for hasattr/getattr with typed engine) logger.info(f"Calling engine method: {method_name}") try: @@ -377,24 +347,16 @@ def export_stats(data: dict[str, Any] | None = Body(None)): assert data is None global _engine - if isinstance(_engine, TrainEngine): - return { - "status": "success", - "result": stats_tracker.export( - reduce_group=_engine.data_parallel_group - ), - } - else: - assert isinstance(_engine, InferenceEngine) - # Rollout engines do not have the collective communication channel. - # Return individual results and reduce them in the client side. - raw_stats = {} - for name, tracker in stats_tracker.TRACKERS.items(): - s = {name.strip("/") + k: v for k, v in tracker.stats.items()} - raw_stats.update(s) - # clear stats tracker - stats_tracker.export_all() - return {"status": "success", "result": raw_stats} + assert isinstance(_engine, InferenceEngine) + # Rollout engines do not have the collective communication channel. + # Return individual results and reduce them in the client side. + raw_stats = {} + for name, tracker in stats_tracker.TRACKERS.items(): + s = {name.strip("/") + k: v for k, v in tracker.stats.items()} + raw_stats.update(s) + # clear stats tracker + stats_tracker.export_all() + return {"status": "success", "result": raw_stats} except HTTPException: raise diff --git a/areal/scheduler/rpc/sync_rpc_server.py b/areal/scheduler/rpc/sync_rpc_server.py new file mode 100644 index 000000000..f02a2d65c --- /dev/null +++ b/areal/scheduler/rpc/sync_rpc_server.py @@ -0,0 +1,249 @@ +"""Single-threaded Flask-based RPC server for distributed TrainEngine workers. + +This server runs on worker nodes to expose TrainEngine methods via HTTP/JSON RPC. +It uses a single-threaded WSGI server to avoid threading conflicts with PyTorch +distributed communication (NCCL). + +Key differences from async_rpc_server: +- Single-threaded: Uses Flask with threaded=False for NCCL compatibility +- TrainEngine only: Only accepts TrainEngine subclasses +- No /run_workflow: Workflow execution is handled by async_rpc_server +""" + +import argparse +import importlib +import traceback + +from flask import Flask, jsonify, request + +from areal.api.engine_api import TrainEngine +from areal.scheduler.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging, stats_tracker + +logger = logging.getLogger("SyncRPCServer") + +# Global engine instance - must be TrainEngine +_engine: TrainEngine | None = None + + +app = Flask(__name__) + + +@app.route("/health", methods=["GET"]) +def health_check(): + """Health check endpoint to verify server is alive.""" + return jsonify({"status": "healthy", "engine_initialized": _engine is not None}) + + +@app.route("/create_engine", methods=["POST"]) +def create_engine(): + """ + Create and initialize a TrainEngine instance on this worker. + + Expected JSON payload: + { + "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path + "init_args": [...], # Positional arguments + "init_kwargs": {...} # Keyword arguments + } + """ + global _engine + + try: + data = request.get_json() + engine_path = data.get("engine") + # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) + init_args = deserialize_value(data.get("init_args", [])) + init_kwargs = deserialize_value(data.get("init_kwargs", {})) + + if not engine_path: + return jsonify({"error": "Missing 'engine' field in request"}), 400 + + # Dynamic import + try: + module_path, class_name = engine_path.rsplit(".", 1) + module = importlib.import_module(module_path) + engine_class = getattr(module, class_name) + + # Validate that the class is a TrainEngine + if not issubclass(engine_class, TrainEngine): + raise TypeError( + f"Engine class must be a subclass of TrainEngine, " + f"got {engine_class}. Use async_rpc_server for InferenceEngine." + ) + except (ValueError, ImportError, AttributeError) as e: + logger.error(f"Failed to import engine '{engine_path}': {e}") + return ( + jsonify( + {"error": f"Failed to import engine '{engine_path}': {str(e)}"} + ), + 400, + ) + except TypeError as e: + logger.error(f"Invalid engine type: {e}") + return jsonify({"error": str(e)}), 400 + + # Instantiate engine + try: + _engine = engine_class(*init_args, **init_kwargs) + logger.info(f"Engine '{engine_path}' instantiated successfully") + return jsonify( + { + "status": "success", + "message": f"Engine '{engine_path}' created and initialized", + "result": None, + } + ) + except Exception as e: + logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") + return ( + jsonify({"error": f"Failed to instantiate engine: {str(e)}"}), + 500, + ) + + except Exception as e: + logger.error( + f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" + ) + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +@app.route("/call", methods=["POST"]) +def call_engine_method(): + """ + Call a method on the TrainEngine instance. + + Expected JSON payload: + { + "method": "train_batch", + "args": [...], + "kwargs": {...} + } + """ + global _engine + + if _engine is None: + return ( + jsonify({"error": "Engine not initialized. Call /create_engine first."}), + 503, + ) + + try: + data = request.get_json() + method_name = data.get("method") + args = data.get("args", []) + kwargs = data.get("kwargs", {}) + + if not method_name: + return jsonify({"error": "Missing 'method' field in request"}), 400 + + # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) + args = deserialize_value(args) + kwargs = deserialize_value(kwargs) + + try: + should_bcast = kwargs.pop("_should_bcast", True) + if should_bcast: + logger.info(f"Broadcasting data for TrainEngine method: {method_name}") + from areal.utils.data import broadcast_tensor_container + + args = broadcast_tensor_container( + args, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + kwargs = broadcast_tensor_container( + kwargs, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + logger.info("Broadcasting data done.") + except Exception as e: + logger.error( + f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + return ( + jsonify({"error": f"Data broadcast '{method_name}' failed: {str(e)}"}), + 500, + ) + + # Call method directly + logger.info(f"Calling engine method: {method_name}") + try: + # Get the method - will raise AttributeError if it doesn't exist + method = getattr(_engine, method_name) + result = method(*args, **kwargs) + + # Serialize result (convert tensors to SerializedTensor dicts) + serialized_result = serialize_value(result) + return jsonify({"status": "success", "result": serialized_result}) + + except AttributeError as e: + logger.error(f"Method '{method_name}' not found on engine: {e}") + return ( + jsonify({"error": f"Engine does not have method '{method_name}'"}), + 400, + ) + except Exception as e: + logger.error( + f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + return ( + jsonify({"error": f"Engine method '{method_name}' failed: {str(e)}"}), + 500, + ) + + except Exception as e: + logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +@app.route("/export_stats", methods=["POST"]) +def export_stats(): + """Export training statistics from stats_tracker.""" + try: + global _engine + if _engine is None: + return ( + jsonify({"error": "Engine not initialized"}), + 503, + ) + + # TrainEngine: reduce stats across data_parallel_group + result = stats_tracker.export(reduce_group=_engine.data_parallel_group) + return jsonify({"status": "success", "result": result}) + + except Exception as e: + logger.error(f"Unexpected error in export_stats: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +def main(): + """Main entry point for the sync RPC server.""" + parser = argparse.ArgumentParser( + description="AReaL Sync RPC Server for TrainEngine" + ) + parser.add_argument("--port", type=int, required=True, help="Port to serve on") + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" + ) + + args, _ = parser.parse_known_args() + + logger.info(f"Starting sync RPC server on {args.host}:{args.port}") + + # Run Flask with single-threaded WSGI server + # threaded=False ensures no thread pool (required for NCCL compatibility) + # processes=1 ensures single process (no forking) + app.run( + host=args.host, + port=args.port, + threaded=False, + processes=1, + debug=False, + use_reloader=False, + ) + + +if __name__ == "__main__": + main() From ece5152f924820f9a6592cf4e84db48c32abaa14 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Wed, 29 Oct 2025 22:58:38 +0800 Subject: [PATCH 25/52] refactor to http server instead of flask --- areal/scheduler/rpc/sync_rpc_server.py | 427 ++++++++++++++----------- 1 file changed, 235 insertions(+), 192 deletions(-) diff --git a/areal/scheduler/rpc/sync_rpc_server.py b/areal/scheduler/rpc/sync_rpc_server.py index f02a2d65c..03866327d 100644 --- a/areal/scheduler/rpc/sync_rpc_server.py +++ b/areal/scheduler/rpc/sync_rpc_server.py @@ -1,20 +1,9 @@ -"""Single-threaded Flask-based RPC server for distributed TrainEngine workers. - -This server runs on worker nodes to expose TrainEngine methods via HTTP/JSON RPC. -It uses a single-threaded WSGI server to avoid threading conflicts with PyTorch -distributed communication (NCCL). - -Key differences from async_rpc_server: -- Single-threaded: Uses Flask with threaded=False for NCCL compatibility -- TrainEngine only: Only accepts TrainEngine subclasses -- No /run_workflow: Workflow execution is handled by async_rpc_server -""" - import argparse import importlib +import json import traceback - -from flask import Flask, jsonify, request +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any from areal.api.engine_api import TrainEngine from areal.scheduler.rpc.serialization import deserialize_value, serialize_value @@ -26,196 +15,251 @@ _engine: TrainEngine | None = None -app = Flask(__name__) - - -@app.route("/health", methods=["GET"]) -def health_check(): - """Health check endpoint to verify server is alive.""" - return jsonify({"status": "healthy", "engine_initialized": _engine is not None}) - +class SyncRPCHandler(BaseHTTPRequestHandler): + """HTTP request handler for sync RPC server endpoints.""" -@app.route("/create_engine", methods=["POST"]) -def create_engine(): - """ - Create and initialize a TrainEngine instance on this worker. + def log_message(self, format: str, *args: Any) -> None: + """Override to use our logger instead of stderr.""" + logger.debug(f"{self.address_string()} - {format % args}") - Expected JSON payload: - { - "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path - "init_args": [...], # Positional arguments - "init_kwargs": {...} # Keyword arguments - } - """ - global _engine + def _send_json_response(self, data: dict, status_code: int = 200) -> None: + """Send JSON response with appropriate headers.""" + self.send_response(status_code) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(data).encode("utf-8")) - try: - data = request.get_json() - engine_path = data.get("engine") - # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) - init_args = deserialize_value(data.get("init_args", [])) - init_kwargs = deserialize_value(data.get("init_kwargs", {})) + def _read_json_body(self) -> dict | None: + """Read and parse JSON request body.""" + try: + content_length = int(self.headers.get("Content-Length", 0)) + if content_length == 0: + return {} + body = self.rfile.read(content_length) + return json.loads(body.decode("utf-8")) + except (json.JSONDecodeError, ValueError) as e: + logger.error(f"Failed to parse JSON body: {e}") + self._send_json_response( + {"error": f"Invalid JSON in request body: {str(e)}"}, 400 + ) + return None + + def do_GET(self) -> None: + """Handle GET requests.""" + if self.path == "/health": + self._handle_health_check() + else: + self._send_json_response({"error": f"Not found: {self.path}"}, 404) + + def do_POST(self) -> None: + """Handle POST requests.""" + if self.path == "/create_engine": + self._handle_create_engine() + elif self.path == "/call": + self._handle_call_engine_method() + elif self.path == "/export_stats": + self._handle_export_stats() + else: + self._send_json_response({"error": f"Not found: {self.path}"}, 404) + + def _handle_health_check(self) -> None: + """Health check endpoint to verify server is alive.""" + global _engine + self._send_json_response( + {"status": "healthy", "engine_initialized": _engine is not None} + ) - if not engine_path: - return jsonify({"error": "Missing 'engine' field in request"}), 400 + def _handle_create_engine(self) -> None: + """ + Create and initialize a TrainEngine instance on this worker. + + Expected JSON payload: + { + "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path + "init_args": [...], # Positional arguments + "init_kwargs": {...} # Keyword arguments + } + """ + global _engine - # Dynamic import try: - module_path, class_name = engine_path.rsplit(".", 1) - module = importlib.import_module(module_path) - engine_class = getattr(module, class_name) - - # Validate that the class is a TrainEngine - if not issubclass(engine_class, TrainEngine): - raise TypeError( - f"Engine class must be a subclass of TrainEngine, " - f"got {engine_class}. Use async_rpc_server for InferenceEngine." + data = self._read_json_body() + if data is None: + return + + engine_path = data.get("engine") + # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) + init_args = deserialize_value(data.get("init_args", [])) + init_kwargs = deserialize_value(data.get("init_kwargs", {})) + + if not engine_path: + self._send_json_response( + {"error": "Missing 'engine' field in request"}, 400 + ) + return + + # Dynamic import + try: + module_path, class_name = engine_path.rsplit(".", 1) + module = importlib.import_module(module_path) + engine_class = getattr(module, class_name) + + # Validate that the class is a TrainEngine + if not issubclass(engine_class, TrainEngine): + raise TypeError( + f"Engine class must be a subclass of TrainEngine, " + f"got {engine_class}. Use async_rpc_server for InferenceEngine." + ) + except (ValueError, ImportError, AttributeError) as e: + logger.error(f"Failed to import engine '{engine_path}': {e}") + self._send_json_response( + {"error": f"Failed to import engine '{engine_path}': {str(e)}"}, + 400, + ) + return + except TypeError as e: + logger.error(f"Invalid engine type: {e}") + self._send_json_response({"error": str(e)}, 400) + return + + # Instantiate engine + try: + _engine = engine_class(*init_args, **init_kwargs) + logger.info(f"Engine '{engine_path}' instantiated successfully") + self._send_json_response( + { + "status": "success", + "message": f"Engine '{engine_path}' created and initialized", + "result": None, + } + ) + except Exception as e: + logger.error( + f"Failed to instantiate engine: {e}\n{traceback.format_exc()}" + ) + self._send_json_response( + {"error": f"Failed to instantiate engine: {str(e)}"}, 500 ) - except (ValueError, ImportError, AttributeError) as e: - logger.error(f"Failed to import engine '{engine_path}': {e}") - return ( - jsonify( - {"error": f"Failed to import engine '{engine_path}': {str(e)}"} - ), - 400, - ) - except TypeError as e: - logger.error(f"Invalid engine type: {e}") - return jsonify({"error": str(e)}), 400 - # Instantiate engine - try: - _engine = engine_class(*init_args, **init_kwargs) - logger.info(f"Engine '{engine_path}' instantiated successfully") - return jsonify( - { - "status": "success", - "message": f"Engine '{engine_path}' created and initialized", - "result": None, - } - ) except Exception as e: - logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") - return ( - jsonify({"error": f"Failed to instantiate engine: {str(e)}"}), - 500, + logger.error( + f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" ) + self._send_json_response({"error": f"Internal server error: {str(e)}"}, 500) + + def _handle_call_engine_method(self) -> None: + """ + Call a method on the TrainEngine instance. + + Expected JSON payload: + { + "method": "train_batch", + "args": [...], + "kwargs": {...} + } + """ + global _engine - except Exception as e: - logger.error( - f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" - ) - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - - -@app.route("/call", methods=["POST"]) -def call_engine_method(): - """ - Call a method on the TrainEngine instance. - - Expected JSON payload: - { - "method": "train_batch", - "args": [...], - "kwargs": {...} - } - """ - global _engine - - if _engine is None: - return ( - jsonify({"error": "Engine not initialized. Call /create_engine first."}), - 503, - ) + if _engine is None: + self._send_json_response( + {"error": "Engine not initialized. Call /create_engine first."}, 503 + ) + return - try: - data = request.get_json() - method_name = data.get("method") - args = data.get("args", []) - kwargs = data.get("kwargs", {}) + try: + data = self._read_json_body() + if data is None: + return - if not method_name: - return jsonify({"error": "Missing 'method' field in request"}), 400 + method_name = data.get("method") + args = data.get("args", []) + kwargs = data.get("kwargs", {}) - # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) - args = deserialize_value(args) - kwargs = deserialize_value(kwargs) + if not method_name: + self._send_json_response( + {"error": "Missing 'method' field in request"}, 400 + ) + return + + # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) + args = deserialize_value(args) + kwargs = deserialize_value(kwargs) + + try: + should_bcast = kwargs.pop("_should_bcast", True) + if should_bcast: + logger.info( + f"Broadcasting data for TrainEngine method: {method_name}" + ) + from areal.utils.data import broadcast_tensor_container + + args = broadcast_tensor_container( + args, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + kwargs = broadcast_tensor_container( + kwargs, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + logger.info("Broadcasting data done.") + except Exception as e: + logger.error( + f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + self._send_json_response( + {"error": f"Data broadcast '{method_name}' failed: {str(e)}"}, 500 + ) + return + + # Call method directly + logger.info(f"Calling engine method: {method_name}") + try: + # Get the method - will raise AttributeError if it doesn't exist + method = getattr(_engine, method_name) + result = method(*args, **kwargs) + + # Serialize result (convert tensors to SerializedTensor dicts) + serialized_result = serialize_value(result) + self._send_json_response( + {"status": "success", "result": serialized_result} + ) - try: - should_bcast = kwargs.pop("_should_bcast", True) - if should_bcast: - logger.info(f"Broadcasting data for TrainEngine method: {method_name}") - from areal.utils.data import broadcast_tensor_container - - args = broadcast_tensor_container( - args, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, + except AttributeError as e: + logger.error(f"Method '{method_name}' not found on engine: {e}") + self._send_json_response( + {"error": f"Engine does not have method '{method_name}'"}, 400 ) - kwargs = broadcast_tensor_container( - kwargs, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, + except Exception as e: + logger.error( + f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + self._send_json_response( + {"error": f"Engine method '{method_name}' failed: {str(e)}"}, 500 ) - logger.info("Broadcasting data done.") - except Exception as e: - logger.error( - f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" - ) - return ( - jsonify({"error": f"Data broadcast '{method_name}' failed: {str(e)}"}), - 500, - ) - # Call method directly - logger.info(f"Calling engine method: {method_name}") - try: - # Get the method - will raise AttributeError if it doesn't exist - method = getattr(_engine, method_name) - result = method(*args, **kwargs) - - # Serialize result (convert tensors to SerializedTensor dicts) - serialized_result = serialize_value(result) - return jsonify({"status": "success", "result": serialized_result}) - - except AttributeError as e: - logger.error(f"Method '{method_name}' not found on engine: {e}") - return ( - jsonify({"error": f"Engine does not have method '{method_name}'"}), - 400, - ) except Exception as e: - logger.error( - f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" - ) - return ( - jsonify({"error": f"Engine method '{method_name}' failed: {str(e)}"}), - 500, - ) + logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") + self._send_json_response({"error": f"Internal server error: {str(e)}"}, 500) - except Exception as e: - logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + def _handle_export_stats(self) -> None: + """Export training statistics from stats_tracker.""" + try: + global _engine + if _engine is None: + self._send_json_response({"error": "Engine not initialized"}, 503) + return + # TrainEngine: reduce stats across data_parallel_group + result = stats_tracker.export(reduce_group=_engine.data_parallel_group) + self._send_json_response({"status": "success", "result": result}) -@app.route("/export_stats", methods=["POST"]) -def export_stats(): - """Export training statistics from stats_tracker.""" - try: - global _engine - if _engine is None: - return ( - jsonify({"error": "Engine not initialized"}), - 503, + except Exception as e: + logger.error( + f"Unexpected error in export_stats: {e}\n{traceback.format_exc()}" ) - - # TrainEngine: reduce stats across data_parallel_group - result = stats_tracker.export(reduce_group=_engine.data_parallel_group) - return jsonify({"status": "success", "result": result}) - - except Exception as e: - logger.error(f"Unexpected error in export_stats: {e}\n{traceback.format_exc()}") - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + self._send_json_response({"error": f"Internal server error: {str(e)}"}, 500) def main(): @@ -232,17 +276,16 @@ def main(): logger.info(f"Starting sync RPC server on {args.host}:{args.port}") - # Run Flask with single-threaded WSGI server - # threaded=False ensures no thread pool (required for NCCL compatibility) - # processes=1 ensures single process (no forking) - app.run( - host=args.host, - port=args.port, - threaded=False, - processes=1, - debug=False, - use_reloader=False, - ) + # Create and run single-threaded HTTP server + # HTTPServer is single-threaded by default (processes one request at a time) + # This ensures NCCL compatibility + server = HTTPServer((args.host, args.port), SyncRPCHandler) + + try: + server.serve_forever() + except KeyboardInterrupt: + logger.info("Shutting down sync RPC server") + server.shutdown() if __name__ == "__main__": From 69805e8c13f088164eabd9e5d517dc8f16ae40cd Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Thu, 30 Oct 2025 10:31:26 +0800 Subject: [PATCH 26/52] sft run --- areal/controller/train_controller.py | 8 ++++++-- areal/scheduler/local.py | 23 +++++++++++++++-------- areal/scheduler/rpc/sync_rpc_server.py | 11 ++++++++++- examples/single-controller/gsm8k_sft.py | 1 + examples/single-controller/gsm8k_sft.yaml | 4 +--- 5 files changed, 33 insertions(+), 14 deletions(-) diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index c7196b47e..7a445e0ef 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -127,13 +127,17 @@ def initialize( role=self._worker_role, ) # Create environment variables to mimic torchrun + # FIXME: here master_addr and master_port only work in the local setting + port = find_free_ports(1)[0] for i, task in enumerate(job.tasks): task.env_vars["RANK"] = str(i) task.env_vars["WORLD_SIZE"] = str(alloc_mode.train.world_size) - task.env_vars["LOCAL_RANK"] = str(i) + task.env_vars["LOCAL_RANK"] = str( + 0 + ) # because we have only set 1 CUDA_VISIBLE_DEVICES for each process # TODO: find a real master addr with scheduler task.env_vars["MASTER_ADDR"] = "localhost" - task.env_vars["MASTER_PORT"] = str(find_free_ports(1)[0]) + task.env_vars["MASTER_PORT"] = str(port) # Create workers via scheduler self.logger.info("Creating workers via scheduler...") diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index a80f3987c..ed19d1513 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -4,6 +4,7 @@ import os import shlex import subprocess +import sys import time from dataclasses import dataclass, field from pathlib import Path @@ -370,17 +371,23 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: cmd.extend(["--port", str(ports[0])]) logger.info(f"Starting worker {worker_id}: {' '.join(cmd)}") + if cmd[0].startswith("python"): + cmd[0] = sys.executable + cmd = ( + " ".join(str(k) + "=" + str(v) for k, v in env.items()) + + " stdbuf -oL " + + " ".join(cmd) + ) + cmd = f"{cmd} 2>&1 | tee -a {log_file}" # Spawn subprocess try: - with open(log_file, "w") as log_f: - process = subprocess.Popen( - cmd, - env=env, - stdout=log_f, - stderr=subprocess.STDOUT, - start_new_session=True, # Create new process group - ) + process = subprocess.Popen( + cmd, + shell=isinstance(cmd, str), + stdout=sys.stdout, + stderr=sys.stdout, + ) except Exception as e: self._cleanup_workers(workers) raise WorkerCreationError( diff --git a/areal/scheduler/rpc/sync_rpc_server.py b/areal/scheduler/rpc/sync_rpc_server.py index 03866327d..12b539e44 100644 --- a/areal/scheduler/rpc/sync_rpc_server.py +++ b/areal/scheduler/rpc/sync_rpc_server.py @@ -6,6 +6,7 @@ from typing import Any from areal.api.engine_api import TrainEngine +from areal.platforms import current_platform from areal.scheduler.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging, stats_tracker @@ -191,13 +192,21 @@ def _handle_call_engine_method(self) -> None: logger.info( f"Broadcasting data for TrainEngine method: {method_name}" ) - from areal.utils.data import broadcast_tensor_container + from areal.utils.data import ( + broadcast_tensor_container, + tensor_container_to, + ) + # TODO: to device here + args = tensor_container_to(args, current_platform.current_device()) args = broadcast_tensor_container( args, src_rank=_engine.current_data_parallel_head(), group=_engine.context_and_model_parallel_group, ) + kwargs = tensor_container_to( + kwargs, current_platform.current_device() + ) kwargs = broadcast_tensor_container( kwargs, src_rank=_engine.current_data_parallel_head(), diff --git a/examples/single-controller/gsm8k_sft.py b/examples/single-controller/gsm8k_sft.py index b7906750f..724041a9f 100644 --- a/examples/single-controller/gsm8k_sft.py +++ b/examples/single-controller/gsm8k_sft.py @@ -71,6 +71,7 @@ def main(args): alloc_mode=allocation_mode, ft_spec=ft_spec, schedule_strategy=ScheduleStrategy(), + addr=None, ) # Run training. diff --git a/examples/single-controller/gsm8k_sft.yaml b/examples/single-controller/gsm8k_sft.yaml index 936435ddd..a3401fa17 100644 --- a/examples/single-controller/gsm8k_sft.yaml +++ b/examples/single-controller/gsm8k_sft.yaml @@ -38,9 +38,7 @@ model: type: worker port_count: 1 gpu: 1 - # AReaL will by default uses `python3 -m areal.scheduler.rpc.rpc_server --port {PORT}` - # where ${PORT} is dynamically allocated - # cmd: python3 -m areal.scheduler.rpc.rpc_server + cmd: python3 -m areal.scheduler.rpc.sync_rpc_server train_dataset: batch_size: 128 From c37732cbb311e83167a71e983b26c9570856217b Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Thu, 30 Oct 2025 14:46:32 +0800 Subject: [PATCH 27/52] fix sft; init grpo --- areal/api/cli_args.py | 6 +- areal/api/scheduler_api.py | 4 +- areal/controller/rollout_controller.py | 6 +- areal/controller/train_controller.py | 37 ++- areal/scheduler/local.py | 2 +- areal/tests/test_local_scheduler.py | 6 +- areal/tests/test_train_controller.py | 126 +++++----- areal/workflow/rlvr.py | 24 +- docs/cli_reference.md | 272 +++++++++++----------- examples/single-controller/gsm8k_sft.py | 117 +++++----- examples/single-controller/gsm8k_sft.yaml | 4 +- realhf/api/cli_args.py | 2 +- realhf/apps/main.py | 2 +- realhf/scheduler/client.py | 2 +- realhf/scheduler/slurm/client.py | 4 +- 15 files changed, 335 insertions(+), 279 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 51ae0c492..bad699fe6 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -313,7 +313,7 @@ class MegatronEngineConfig: @dataclass -class ScheduleStrategy: +class SchedulingStrategy: type: str = field( default="separation", metadata={"choices": ["separation", "colocation"]} ) @@ -437,7 +437,7 @@ class TrainEngineConfig: ), metadata={"help": "train engine schedule specs"}, ) - scheduling_strategy: ScheduleStrategy = field(default_factory=ScheduleStrategy) + scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy) @dataclass @@ -912,7 +912,7 @@ class InferenceEngineConfig: ), metadata={"help": "inference engine schedule specs"}, ) - scheduling_strategy: ScheduleStrategy = field(default_factory=ScheduleStrategy) + scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy) @dataclass diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index cd7685f94..7ad616ccb 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import Any -from areal.api.cli_args import ScheduleStrategy, SchedulingSpec +from areal.api.cli_args import SchedulingSpec, SchedulingStrategy @dataclass @@ -27,7 +27,7 @@ class Worker: class Job: replicas: int = 0 tasks: list[SchedulingSpec] = field(default_factory=list) - schedule_strategy: ScheduleStrategy | None = None + scheduling_strategy: SchedulingStrategy | None = None role: str = "" diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index 680fa9e4d..61156a435 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -16,7 +16,7 @@ from areal.api.controller_api import DistributedBatch from areal.api.engine_api import InferenceEngine from areal.api.io_struct import ModelRequest, ModelResponse, ParamSpec, WeightUpdateMeta -from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker +from areal.api.scheduler_api import Job, Scheduler, Worker from areal.controller.batch import DistributedBatchMemory from areal.core.async_task_runner import AsyncTaskRunner, TaskQueueFullError from areal.core.staleness_manager import StalenessManager @@ -76,6 +76,7 @@ def __init__( config: InferenceEngineConfig, scheduler: Scheduler, ): + # FIXME: add seeding """Initialize the RolloutController. Parameters @@ -120,7 +121,6 @@ def initialize( self, role: str, alloc_mode: AllocationMode, - schedule_strategy: ScheduleStrategy | None = None, *args, **kwargs, ): @@ -148,7 +148,7 @@ def initialize( job = Job( replicas=alloc_mode.gen.dp_size, tasks=[self.config.scheduling_spec for _ in range(alloc_mode.gen.dp_size)], - schedule_strategy=schedule_strategy, + scheduling_strategy=self.config.scheduling_strategy, role=self._worker_role, ) diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index 7a445e0ef..e237d10a8 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -5,6 +5,7 @@ from typing import Any import torch +from torchdata.stateful_dataloader import StatefulDataLoader from areal.api.alloc_mode import ParallelStrategy from areal.api.cli_args import TrainEngineConfig @@ -16,7 +17,7 @@ SaveLoadMeta, WeightUpdateMeta, ) -from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker +from areal.api.scheduler_api import Job, Scheduler, Worker from areal.controller.batch import DistributedBatchMemory from areal.controller.rollout_controller import RolloutController from areal.platforms import current_platform @@ -60,6 +61,7 @@ def __init__( config: TrainEngineConfig, scheduler: Scheduler, ): + # FIXME: add seeding self.train_engine = train_engine self.config = config self.scheduler = scheduler @@ -84,7 +86,6 @@ def initialize( role: str, alloc_mode: AllocationMode, ft_spec: FinetuneSpec, - schedule_strategy: ScheduleStrategy, **kwargs, ): """Initialize environments for distributed training and load models. @@ -99,8 +100,6 @@ def initialize( Allocation mode configuration for distributed setup ft_spec : FinetuneSpec Finetune specification for model initialization - schedule_strategy : ScheduleStrategy - Strategy for scheduling workers **kwargs Additional keyword arguments passed to engine initialization """ @@ -123,7 +122,7 @@ def initialize( deepcopy(self.config.scheduling_spec) for _ in range(alloc_mode.train.world_size) ], - schedule_strategy=schedule_strategy, + scheduling_strategy=self.config.scheduling_strategy, role=self._worker_role, ) # Create environment variables to mimic torchrun @@ -370,6 +369,34 @@ def connect_engine(self, rollout: RolloutController, meta: WeightUpdateMeta): self._init_weight_update_from_distributed(meta) self.weight_update_group_initialized = True + def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> DistributedBatch: + return self.rollout.prepare_batch( + dataloader=dataloader, + workflow_path=workflow_path, + workflow_kwargs=workflow_kwargs, + should_accept_path=should_accept_path, + ) + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> DistributedBatch: + return self.rollout.rollout_batch( + data=data, + workflow_path=workflow_path, + workflow_kwargs=workflow_kwargs, + should_accept_path=should_accept_path, + ) + def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta): raise NotImplementedError() diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index ed19d1513..0160aa31a 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -287,7 +287,7 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: schedulings = self._prepare_worker_specs(role, num_workers, job.tasks) # Determine scheduling strategy - strategy = job.schedule_strategy + strategy = job.scheduling_strategy if strategy is None: strategy_type = "separation" colocate_role = None diff --git a/areal/tests/test_local_scheduler.py b/areal/tests/test_local_scheduler.py index a06ca170a..b5eba66bc 100644 --- a/areal/tests/test_local_scheduler.py +++ b/areal/tests/test_local_scheduler.py @@ -9,8 +9,8 @@ from areal.api.scheduler_api import ( Job, - ScheduleStrategy, SchedulingSpec, + SchedulingStrategy, Worker, ) from areal.scheduler.exceptions import ( @@ -531,7 +531,7 @@ def test_create_workers_with_colocate_strategy( replicas=2, role="critic", tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=2, port_count=2)], - schedule_strategy=ScheduleStrategy(type="colocation", target="actor"), + scheduling_strategy=SchedulingStrategy(type="colocation", target="actor"), ) critic_ids = scheduler.create_workers(critic_job) @@ -673,7 +673,7 @@ def test_create_workers_colocate_strategy_missing_target(self, tmp_path): replicas=1, role="test", tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2)], - schedule_strategy=ScheduleStrategy( + scheduling_strategy=SchedulingStrategy( type="colocation", target="" ), # Missing target ) diff --git a/areal/tests/test_train_controller.py b/areal/tests/test_train_controller.py index a89be5181..5ca2fa7f2 100644 --- a/areal/tests/test_train_controller.py +++ b/areal/tests/test_train_controller.py @@ -11,7 +11,7 @@ import torch from areal.api.alloc_mode import ParallelStrategy -from areal.api.cli_args import ScheduleStrategy, SchedulingSpec, TrainEngineConfig +from areal.api.cli_args import SchedulingSpec, SchedulingStrategy, TrainEngineConfig from areal.api.engine_api import TrainEngine from areal.api.io_struct import ( AllocationMode, @@ -154,9 +154,9 @@ def ft_spec(): @pytest.fixture -def schedule_strategy(): - """Provide a ScheduleStrategy for testing.""" - return ScheduleStrategy(type="separation", target="") +def scheduling_strategy(): + """Provide a SchedulingStrategy for testing.""" + return SchedulingStrategy(type="separation", target="") @pytest.fixture @@ -196,13 +196,15 @@ def test_constructor(self, mock_scheduler, train_config): assert controller.worker_is_dp_head == [] assert controller.logger is None - def test_initialize(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_initialize( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): """Test initialize method creates workers and engines.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) # Verify workers were created @@ -248,14 +250,14 @@ def test_create_process_group_sets_parallel_strategy( ) def test_identify_dp_heads( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test _identify_dp_heads correctly identifies DP head workers.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) # MockScheduler returns True for even-indexed workers @@ -266,14 +268,14 @@ def test_identify_dp_heads( class TestTrainControllerDestroy: """Tests for TrainController cleanup and destruction.""" - def test_destroy(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_destroy(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): """Test destroy method cleans up resources.""" # Initialize first train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) initial_worker_count = len(train_controller.workers) @@ -288,14 +290,14 @@ def test_destroy(self, train_controller, alloc_mode, ft_spec, schedule_strategy) assert "train_worker" in train_controller.scheduler.deleted_roles def test_destroy_handles_errors( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test destroy handles errors gracefully.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) # Make delete_workers raise an exception @@ -315,14 +317,14 @@ class TestTrainControllerBatchOperations: """Tests for batch splitting and alignment operations.""" def test_align_batches_with_dp_rebalance( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test _align_batches_with_dp with rebalance=True.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=16) @@ -336,14 +338,14 @@ def test_align_batches_with_dp_rebalance( assert isinstance(chunk, DistributedBatchMemory) def test_align_batches_with_dp_no_rebalance( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test _align_batches_with_dp with rebalance=False.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=16) @@ -413,13 +415,15 @@ def test_merge_results_accepts_method_parameter(self, train_controller): class TestTrainControllerRPCWrappers: """Tests for RPC wrapper methods.""" - def test_train_mode(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_train_mode( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): """Test train() method sets training mode.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) result = train_controller.train(mode=True) @@ -431,13 +435,15 @@ def test_train_mode(self, train_controller, alloc_mode, ft_spec, schedule_strate engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] assert "train" in engine_calls - def test_eval_mode(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_eval_mode( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): """Test eval() method sets evaluation mode.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) result = train_controller.eval() @@ -449,13 +455,13 @@ def test_eval_mode(self, train_controller, alloc_mode, ft_spec, schedule_strateg engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] assert "train" in engine_calls - def test_forward(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_forward(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): """Test forward() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=8) @@ -469,14 +475,14 @@ def test_forward(self, train_controller, alloc_mode, ft_spec, schedule_strategy) assert "forward" in engine_calls def test_train_batch( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test train_batch() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=8) @@ -496,13 +502,15 @@ def loss_weight_fn(batch_data): engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] assert "train_batch" in engine_calls - def test_eval_batch(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_eval_batch( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): """Test eval_batch() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=8) @@ -523,14 +531,14 @@ def loss_weight_fn(batch_data): assert "eval_batch" in engine_calls def test_step_lr_scheduler( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test step_lr_scheduler() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) train_controller.step_lr_scheduler() @@ -544,14 +552,14 @@ class TestTrainControllerPPOMethods: """Tests for PPO-specific methods.""" def test_compute_logp( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test compute_logp() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) result = train_controller.compute_logp() @@ -564,14 +572,14 @@ def test_compute_logp( assert "compute_logp" in engine_calls def test_compute_advantages( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test compute_advantages() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) result = train_controller.compute_advantages() @@ -583,13 +591,15 @@ def test_compute_advantages( engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] assert "compute_advantages" in engine_calls - def test_ppo_update(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_ppo_update( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): """Test ppo_update() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=8) @@ -606,13 +616,13 @@ def test_ppo_update(self, train_controller, alloc_mode, ft_spec, schedule_strate class TestTrainControllerSFTMethods: """Tests for SFT-specific methods.""" - def test_train_lm(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_train_lm(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): """Test train_lm() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=8) @@ -626,14 +636,14 @@ def test_train_lm(self, train_controller, alloc_mode, ft_spec, schedule_strategy assert "train_lm" in engine_calls def test_evaluate_lm( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test evaluate_lm() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=8) @@ -651,14 +661,14 @@ class TestTrainControllerWeightManagement: """Tests for weight management operations.""" def test_set_version( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test set_version() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) train_controller.set_version(42) @@ -668,14 +678,14 @@ def test_set_version( assert "set_version" in engine_calls def test_get_version( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test get_version() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) version = train_controller.get_version() @@ -688,14 +698,14 @@ def test_get_version( assert "get_version" in engine_calls def test_update_weights( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test update_weights() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) meta = WeightUpdateMeta(type="disk", path="/tmp/weights") @@ -705,13 +715,13 @@ def test_update_weights( engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] assert "update_weights" in engine_calls - def test_save(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_save(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): """Test save() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) meta = SaveLoadMeta( @@ -723,13 +733,13 @@ def test_save(self, train_controller, alloc_mode, ft_spec, schedule_strategy): engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] assert "save" in engine_calls - def test_load(self, train_controller, alloc_mode, ft_spec, schedule_strategy): + def test_load(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): """Test load() method.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) meta = SaveLoadMeta( @@ -759,14 +769,14 @@ class TestTrainControllerCustomFunctionCall: """Tests for custom_function_call orchestration.""" def test_custom_function_call_with_distributed_batch( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test custom_function_call with DistributedBatch argument.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) # Clear previous calls from initialization @@ -785,14 +795,14 @@ def test_custom_function_call_with_distributed_batch( assert worker_calls == len(train_controller.workers) def test_custom_function_call_with_regular_args( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test custom_function_call with non-DistributedBatch arguments.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) # Clear previous calls @@ -810,14 +820,14 @@ def test_custom_function_call_with_regular_args( ) def test_custom_function_call_filters_dp_heads( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test custom_function_call only returns results from DP heads.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) batch = create_mock_distributed_batch(size=8) @@ -831,14 +841,14 @@ class TestTrainControllerEdgeCases: """Tests for edge cases and error handling.""" def test_empty_distributed_batch( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test handling of empty DistributedBatch.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) empty_batch = DistributedBatchMemory.from_dict({}) @@ -856,14 +866,14 @@ def test_create_process_group_requires_workers( train_controller.create_process_group(parallel_strategy) def test_method_chaining( - self, train_controller, alloc_mode, ft_spec, schedule_strategy + self, train_controller, alloc_mode, ft_spec, scheduling_strategy ): """Test that train() and eval() support method chaining.""" train_controller.initialize( role="train_worker", alloc_mode=alloc_mode, ft_spec=ft_spec, - schedule_strategy=schedule_strategy, + scheduling_strategy=scheduling_strategy, ) # Should be able to chain calls diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index e4ecf6860..693505e61 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -1,7 +1,8 @@ import asyncio +import importlib import os import uuid -from typing import Callable +from collections.abc import Callable import aiofiles import aiofiles.os @@ -37,9 +38,9 @@ def default_data_extract_prompt_fn(data): class RLVRWorkflow(RolloutWorkflow): def __init__( self, - reward_fn, + reward_fn: Callable | str, gconfig: GenerationHyperparameters, - tokenizer: PreTrainedTokenizerFast, + tokenizer: PreTrainedTokenizerFast | str, enable_thinking: bool = False, rollout_stat_scope: str = "rollout", dump_dir: str | None = None, @@ -59,6 +60,23 @@ def __init__( os.makedirs(self.dump_dir, exist_ok=True) async def arun_episode(self, engine: InferenceEngine, data): + # NOTE: tokenizer and reward_fn are not jsonifiable for remote execution, + # so we need to load it eagerly during execution. + if isinstance(self.tokenizer, str): + from areal.utils.hf_utils import load_hf_tokenizer + + tokenizer = load_hf_tokenizer(self.tokenizer) + if tokenizer.pad_token_id not in self.gconfig.stop_token_ids: + self.gconfig.stop_token_ids.append(tokenizer.pad_token_id) + if tokenizer.eos_token_id not in self.gconfig.stop_token_ids: + self.gconfig.stop_token_ids.append(tokenizer.eos_token_id) + self.tokenizer = tokenizer + + if isinstance(self.reward_fn, str): + module_path, fname = self.reward_fn.rsplit(".", 1) + module = importlib.import_module(module_path) + self.reward_fn = getattr(module, fname) + input_ids = self.get_input_ids_fn( self.data_extract_prompt_fn(data), self.tokenizer, self.enable_thinking ) diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 22b66b775..861afdf0d 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -73,9 +73,9 @@ For detailed examples, see the experiment configurations in the `examples/` dire - [DistributedDataParallel Configuration](section-distributed-data-parallel) - [MegatronEngine Configuration](section-megatron-engine) -- [ScheduleStrategy](section-schedule-strategy) - [Scheduler Configuration](section-scheduler) - [Scheduling Specification](section-scheduling) +- [SchedulingStrategy](section-scheduling-strategy) ______________________________________________________________________ @@ -312,58 +312,58 @@ Configuration for model optimization during training. Configuration for PPO actor model, a subclass of a TrainEngine. -| Parameter | Type | Default | Description | -| ------------------------- | ------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"disk"` | - | -| `backend` | string | `""` | Training backend (refer to documentation) | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | -| `scheduling_strategy` | [`ScheduleStrategy`](section-schedule-strategy) | **Required** | - | -| `group_size` | integer | `1` | Number of sequences in each group | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `temperature` | float | `1.0` | Temperature during generation. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `dynamic_sampling` | boolean | `False` | Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. Note that enabling this option will lead to variable batch sizes. If you want to use a constant batch size with dynamic filtering, you should use the `should_accept` parameter in `rollout_batch` and `prepare_batch`. | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| Parameter | Type | Default | Description | +| ------------------------- | --------------------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"disk"` | - | +| `backend` | string | `""` | Training backend (refer to documentation) | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | - | +| `group_size` | integer | `1` | Number of sequences in each group | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `temperature` | float | `1.0` | Temperature during generation. | +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | +| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | +| `dynamic_sampling` | boolean | `False` | Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. Note that enabling this option will lead to variable batch sizes. If you want to use a constant batch size with dynamic filtering, you should use the `should_accept` parameter in `rollout_batch` and `prepare_batch`. | +| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | +| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | +| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | (section-ppo-critic)= @@ -371,35 +371,35 @@ Configuration for PPO actor model, a subclass of a TrainEngine. Configuration for PPO critic model, a subclass of a TrainEngine. -| Parameter | Type | Default | Description | -| ------------------------ | ------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"disk"` | - | -| `backend` | string | `""` | Training backend (refer to documentation) | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | -| `scheduling_strategy` | [`ScheduleStrategy`](section-schedule-strategy) | **Required** | - | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.5` | Clipping factor for value loss | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"disk"` | - | +| `backend` | string | `""` | Training backend (refer to documentation) | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | - | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.5` | Clipping factor for value loss | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | (section-train-engine)= @@ -407,32 +407,32 @@ Configuration for PPO critic model, a subclass of a TrainEngine. Core configuration for model training, including optimization and backend settings. -| Parameter | Type | Default | Description | -| ------------------------ | ------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"disk"` | - | -| `backend` | string | `""` | Training backend (refer to documentation) | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | -| `scheduling_strategy` | [`ScheduleStrategy`](section-schedule-strategy) | **Required** | - | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"disk"` | - | +| `backend` | string | `""` | Training backend (refer to documentation) | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | - | (section-generation-hyperparameters)= @@ -460,23 +460,23 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | ----------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | (Deprecated) Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `setup_timeout` | float | `120.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. | -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | -| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | inference engine schedule specs | -| `scheduling_strategy` | [`ScheduleStrategy`](section-schedule-strategy) | **Required** | - | +| Parameter | Type | Default | Description | +| ------------------------- | --------------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | (Deprecated) Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `setup_timeout` | float | `120.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | inference engine schedule specs | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | - | (section-sg-lang)= @@ -807,17 +807,6 @@ Refer to Megatron-LM documentation for implementation details. | `distribute_saved_activations` | boolean \| None | `None` | - | | `recompute_modules` | list of string \| None | `None` | - | -(section-schedule-strategy)= - -## ScheduleStrategy - -Configuration class: ScheduleStrategy - -| Parameter | Type | Default | Description | -| --------- | -------------- | -------------- | ----------------------------------------- | -| `type` | string | `"separation"` | - **Choices:** `separation`, `colocation` | -| `target` | string \| None | `None` | The target role to be colocated with | - (section-scheduler)= ## Scheduler Configuration @@ -855,3 +844,14 @@ Configuration class: SchedulingSpec | `time_limit` | string \| None | `None` | - | | `begin` | string \| None | `None` | - | | `deadline` | string \| None | `None` | - | + +(section-scheduling-strategy)= + +## SchedulingStrategy + +Configuration class: SchedulingStrategy + +| Parameter | Type | Default | Description | +| --------- | -------------- | -------------- | ----------------------------------------- | +| `type` | string | `"separation"` | - **Choices:** `separation`, `colocation` | +| `target` | string \| None | `None` | The target role to be colocated with | diff --git a/examples/single-controller/gsm8k_sft.py b/examples/single-controller/gsm8k_sft.py index 724041a9f..6bf231874 100644 --- a/examples/single-controller/gsm8k_sft.py +++ b/examples/single-controller/gsm8k_sft.py @@ -3,7 +3,7 @@ from areal.api.alloc_mode import AllocationMode from areal.api.cli_args import SFTConfig, load_expr_config from areal.api.io_struct import FinetuneSpec, StepInfo -from areal.api.scheduler_api import ScheduleStrategy +from areal.controller.batch import DistributedBatchMemory from areal.controller.train_controller import TrainController from areal.dataset import get_custom_dataset from areal.engine.sft.lm_engine import FSDPLMEngine @@ -50,7 +50,6 @@ def main(args): collate_fn=pad_sequences_to_tensors, ) - # Initialize engine ft_spec = FinetuneSpec( total_train_epochs=config.total_train_epochs, dataset_size=len(train_dataloader) * config.train_dataset.batch_size, @@ -70,77 +69,79 @@ def main(args): role="default", alloc_mode=allocation_mode, ft_spec=ft_spec, - schedule_strategy=ScheduleStrategy(), addr=None, ) - # Run training. saver = Saver(config.saver, ft_spec) stats_logger = StatsLogger(config, ft_spec) evaluator = Evaluator(config.evaluator, ft_spec) recover_handler = RecoverHandler(config.recover, ft_spec) - recover_info = recover_handler.load( - engine, - saver, - evaluator, - stats_logger, - train_dataloader, - ) - start_step = ( - recover_info.last_step_info.next().global_step - if recover_info is not None - else 0 - ) - total_epochs = config.total_train_epochs - - global_step = 0 - for epoch in range(total_epochs): - for step, data in enumerate(train_dataloader): - if global_step < start_step: - global_step += 1 - continue - step_info = StepInfo( - global_step=global_step, - epoch=epoch, - epoch_step=step, - steps_per_epoch=len(train_dataloader), - ) - - with ( - stats_tracker.record_timing("train_step"), - ): - engine.train_lm(data) - engine.step_lr_scheduler() - - with stats_tracker.record_timing("save"): - saver.save(engine, epoch, step, global_step, tokenizer=tokenizer) - - with stats_tracker.record_timing("checkpoint_for_recover"): - recover_handler.dump( - engine, - step_info, - saver, - evaluator, - stats_logger, - train_dataloader, - tokenizer=tokenizer, + try: + # Run training. + recover_info = recover_handler.load( + engine, + saver, + evaluator, + stats_logger, + train_dataloader, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + + global_step = 0 + for epoch in range(total_epochs): + for step, data in enumerate(train_dataloader): + if global_step < start_step: + global_step += 1 + continue + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=len(train_dataloader), ) - with stats_tracker.record_timing("eval"): + with ( + stats_tracker.record_timing("train_step"), + ): + engine.train_lm(DistributedBatchMemory.from_dict(data)) + engine.step_lr_scheduler() - def evaluate_fn(): - for data in valid_dataloader: - engine.evaluate_lm(data) + with stats_tracker.record_timing("save"): + saver.save(engine, epoch, step, global_step, tokenizer=tokenizer) - evaluator.evaluate(evaluate_fn, epoch, step, global_step) + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + engine, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) - stats_logger.commit(epoch, step, global_step, engine.export_stats()) - global_step += 1 + with stats_tracker.record_timing("eval"): + + def evaluate_fn(): + for data in valid_dataloader: + engine.evaluate_lm(DistributedBatchMemory.from_dict(data)) + + evaluator.evaluate(evaluate_fn, epoch, step, global_step) + + stats_logger.commit(epoch, step, global_step, engine.export_stats()) + global_step += 1 - stats_logger.close() - engine.destroy() + finally: + stats_logger.close() + engine.destroy() if __name__ == "__main__": diff --git a/examples/single-controller/gsm8k_sft.yaml b/examples/single-controller/gsm8k_sft.yaml index a3401fa17..0efeeaac3 100644 --- a/examples/single-controller/gsm8k_sft.yaml +++ b/examples/single-controller/gsm8k_sft.yaml @@ -78,8 +78,8 @@ evaluator: experiment_name: ${experiment_name} trial_name: ${trial_name} fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null + freq_epochs: null + freq_steps: 1 freq_secs: null stats_logger: diff --git a/realhf/api/cli_args.py b/realhf/api/cli_args.py index 6de76c710..0a9934028 100644 --- a/realhf/api/cli_args.py +++ b/realhf/api/cli_args.py @@ -980,7 +980,7 @@ class BaseExperimentConfig: partition: str = field( default="dev", metadata={"help": "SLURM partition for running the experiment."} ) - schedule_strategy: str = field( + scheduling_strategy: str = field( default="empty_first", metadata={"help": "Resource scheduling strategy."} ) wandb: WandBConfig = field( diff --git a/realhf/apps/main.py b/realhf/apps/main.py index 416d656a2..5165d1d24 100644 --- a/realhf/apps/main.py +++ b/realhf/apps/main.py @@ -335,7 +335,7 @@ def main(): choices=["local", "slurm", "ray"], ) subparser.add_argument( - "--schedule_strategy", + "--scheduling_strategy", default="empty_first", choices=["empty_first", "allocated_first"], help="Schedule strategy for scheduler. Currently only effective in slurm mode. " diff --git a/realhf/scheduler/client.py b/realhf/scheduler/client.py index 3662b7c24..894f0e66e 100644 --- a/realhf/scheduler/client.py +++ b/realhf/scheduler/client.py @@ -160,7 +160,7 @@ def make(args: "BaseExperimentConfig", **kwargs) -> SchedulerClient: evaluator = kwargs.get("evaluator", None) return SlurmSchedulerClient( args, - args.schedule_strategy, + args.scheduling_strategy, evaluator, job_group_id, job_group_index, diff --git a/realhf/scheduler/slurm/client.py b/realhf/scheduler/slurm/client.py index 772b60138..e094ee6f4 100644 --- a/realhf/scheduler/slurm/client.py +++ b/realhf/scheduler/slurm/client.py @@ -81,14 +81,14 @@ class SlurmSchedulerClient(SchedulerClient): def __init__( self, args, - schedule_strategy: str, + scheduling_strategy: str, evaluator: Optional[AutomaticEvaluator], job_group_id: str, job_group_index: int, ): super().__init__(args) - self.__schedule_strategy = schedule_strategy + self.__schedule_strategy = scheduling_strategy self.__pending_jobs: Dict[str, SlurmLaunchInfo] = dict() self.__committed_jobs: Dict[str, SlurmLaunchInfo] = dict() From e50b9b06ec9674ed2f9c99353598b2af052228bf Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Thu, 30 Oct 2025 16:53:59 +0800 Subject: [PATCH 28/52] add rpc server configuration --- areal/api/cli_args.py | 10 +- areal/controller/rollout_controller.py | 27 +++-- areal/controller/train_controller.py | 45 +++----- areal/scheduler/exceptions.py | 11 ++ areal/scheduler/local.py | 81 ++++++++++++- areal/scheduler/rpc/async_rpc_server.py | 40 ++++++- areal/scheduler/rpc/sync_rpc_server.py | 40 ++++++- areal/utils/data.py | 144 ++++++++++++------------ areal/workflow/rlvr.py | 1 + 9 files changed, 277 insertions(+), 122 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index bad699fe6..6e0bc80e7 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -791,8 +791,8 @@ def build_args( sglang_config: "SGLangConfig", tp_size, base_gpu_id, - host, - port, + host=None, + port=None, dist_init_addr: str | None = None, n_nodes: int = 1, node_rank: int = 0, @@ -819,8 +819,6 @@ def build_args( x.replace("-linear", "") for x in args["lora_target_modules"] ] args = dict( - host=host, - port=port, # Model and tokenizer tokenizer_path=sglang_config.model_path, tokenizer_mode="auto", @@ -838,6 +836,10 @@ def build_args( dist_init_addr=dist_init_addr, **args, ) + if host is not None: + args["host"] = host + if port is not None: + args["port"] = port if not pkg_version.is_version_greater_or_equal("sglang", "0.4.9.post2"): raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.") return args diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index 61156a435..a010ec035 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -76,18 +76,6 @@ def __init__( config: InferenceEngineConfig, scheduler: Scheduler, ): - # FIXME: add seeding - """Initialize the RolloutController. - - Parameters - ---------- - inf_engine : type[InferenceEngine] - The inference engine class (not instance) to create on workers - config : InferenceEngineConfig - Configuration for the inference engines - scheduler : Scheduler - Scheduler for managing workers - """ self.inf_engine = inf_engine self.config = config self.scheduler = scheduler @@ -121,6 +109,7 @@ def initialize( self, role: str, alloc_mode: AllocationMode, + engine_args: dict[str, Any], *args, **kwargs, ): @@ -156,6 +145,7 @@ def initialize( asyncio.run( self._async_initialize( job, + engine_args, *args, **kwargs, ) @@ -183,7 +173,9 @@ def initialize( max_staleness=self.config.max_head_offpolicyness, ) - async def _async_initialize(self, job: Job, *args, **kwargs): + async def _async_initialize( + self, job: Job, engine_args: dict[str, Any], *args, **kwargs + ): # Create workers via scheduler self.logger.info("Creating workers via scheduler...") worker_ids = self.scheduler.create_workers(job=job) @@ -212,6 +204,13 @@ async def _async_initialize(self, job: Job, *args, **kwargs): self.logger.info("Engine created on all workers!") self.logger.info("Calling engine initialization...") + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, method="create_engine", engine_args=engine_args + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) tasks = [ self.scheduler.async_call_engine( worker_id=worker.id, method="initialize", *args, **kwargs @@ -688,7 +687,7 @@ async def _update_all_workers(): await asyncio.gather(*tasks) def update_all_workers(): - asyncio.run(_update_all_workers) + asyncio.run(_update_all_workers()) return self.executor.submit(update_all_workers) diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index e237d10a8..34b3e0903 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -28,40 +28,12 @@ class TrainController: - """A centralized controller that manages multiple distributed TrainEngine workers. - - TrainController serves as a high-level orchestrator for distributed training across - multiple concurrent workers, each running TrainEngine instances. It provides a - unified interface for coordinating training operations while abstracting away the - complexities of inter-worker communication and data distribution. - - Key differences from TrainEngine: - - Operates at a higher abstraction level, managing multiple engine instances - - Does not directly perform collective communications (no rank and process group APIs) - - Uses `DistributedBatch` for data that spans multiple workers - - Provides centralized coordination for distributed training workflows - - The controller handles workload distribution, synchronization, and aggregation - of results from the underlying TrainEngine workers, enabling scalable and - efficient distributed training. - - Parameters - ---------- - train_engine : type[TrainEngine] - The engine class (not instance) to instantiate on each worker - config : TrainEngineConfig - Configuration for training engines - scheduler : Scheduler - Scheduler for worker management - """ - def __init__( self, train_engine: type[TrainEngine], config: TrainEngineConfig, scheduler: Scheduler, ): - # FIXME: add seeding self.train_engine = train_engine self.config = config self.scheduler = scheduler @@ -71,7 +43,7 @@ def __init__( self.workers_is_dp_head: list[bool] = [] # Only DP head workers self.parallel_strategy: ParallelStrategy | None = None - self.rollout: RolloutController + self.rollout: RolloutController = None self.weight_update_group_initialized = False self._worker_role: str @@ -312,6 +284,21 @@ def _merge_results(self, results, method): first_result = results[0] # FIXME: should use a more general data conversion strategy + if isinstance(first_result, torch.Tensor): + # Assume that tensor shapes are [bs, seqlen, *] + max_length = max(tensor.shape[1] for tensor in results) + n_dim = first_result.ndim + padded_tensors = [] + for tensor in results: + pad_mode = ( + (0,) * (2 * (n_dim - 2)) + + (0, max_length - tensor.shape[1]) + + (0, 0) + ) + padded_tensor = torch.nn.functional.pad(tensor, pad_mode, value=0.0) + padded_tensors.append(padded_tensor) + return torch.cat(padded_tensors, dim=0) + if isinstance(first_result, dict): if len(first_result) == 0: return DistributedBatchMemory.from_dict({}) diff --git a/areal/scheduler/exceptions.py b/areal/scheduler/exceptions.py index 93a007b64..29f746eff 100644 --- a/areal/scheduler/exceptions.py +++ b/areal/scheduler/exceptions.py @@ -20,6 +20,17 @@ def __init__(self, worker_key: str, reason: str, details: str = ""): super().__init__(message) +class WorkerConfigurationError(SchedulerError): + def __init__(self, worker_key: str, reason: str, details: str = ""): + self.worker_key = worker_key + self.reason = reason + self.details = details + message = f"Failed to configure worker '{worker_key}': {reason}" + if details: + message += f"\nDetails: {details}" + super().__init__(message) + + class WorkerFailedError(SchedulerError): """Raised when a worker process fails or exits unexpectedly.""" diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index 0160aa31a..8e9ff9d7e 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -14,6 +14,7 @@ import orjson import psutil +from areal.api.cli_args import BaseExperimentConfig from areal.api.scheduler_api import Job, Scheduler, SchedulingSpec, Worker from areal.platforms import current_platform from areal.scheduler.exceptions import ( @@ -24,6 +25,7 @@ PortAllocationError, RPCConnectionError, SchedulerError, + WorkerConfigurationError, WorkerCreationError, WorkerFailedError, WorkerNotFoundError, @@ -70,6 +72,7 @@ class LocalScheduler(Scheduler): def __init__( self, gpu_devices: list[int] | None = None, + exp_config: BaseExperimentConfig | None = None, fileroot: str | None = None, experiment_name: str | None = None, trial_name: str | None = None, @@ -94,6 +97,9 @@ def __init__( assert experiment_name is not None assert trial_name is not None assert fileroot is not None + experiment_name = experiment_name or exp_config.experiment_name + trial_name = trial_name or exp_config.trial_name + fileroot = fileroot or exp_config.cluster.fileroot self.log_dir = ( Path(fileroot) / "logs" @@ -101,7 +107,9 @@ def __init__( / experiment_name / trial_name ) - self.cluster_name = cluster_name + self.cluster_name = cluster_name or exp_config.cluster.cluster_name + self.exp_config = exp_config + self.startup_timeout = startup_timeout self.health_check_interval = health_check_interval @@ -438,7 +446,6 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: logger.info( f"Successfully created {len(workers)} workers for role '{role}'" ) - return worker_ids except Exception as e: # Clean up any workers created before the failure @@ -447,6 +454,73 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: raise raise WorkerCreationError(role, "Unexpected error", str(e)) from e + # Send HTTP request to configure workers + for worker_rank, worker_info in enumerate(workers): + worker_id = worker_info.worker.id + port = int(worker_info.worker.worker_ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/configure" + + try: + response = self._http_client.post( + url, + content=orjson.dumps( + serialize_value( + dict( + config=self.exp_config, + role=worker_info.role, + rank=worker_rank, + ) + ) + ), + headers={"Content-Type": "application/json"}, + timeout=300.0, + ) + + if response.status_code == 200: + result = response.json() + logger.info(f"Configuration successfully on worker '{worker_id}'") + return result.get("result") + elif response.status_code == 400: + # Import error or bad request + error_detail = response.json().get("detail", "Unknown error") + raise WorkerConfigurationError(worker_id, error_detail, 400) + elif response.status_code == 500: + # Engine initialization failed + error_detail = response.json().get("detail", "Unknown error") + raise WorkerConfigurationError(worker_id, error_detail, 500) + else: + raise WorkerConfigurationError( + worker_id, + f"Unexpected status code: {response.status_code}", + response.status_code, + ) + + except httpx.ConnectError as e: + # Check if worker died + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, worker_info.process.returncode, stderr + ) from e + raise RPCConnectionError( + worker_id, worker_info.worker.ip, port, str(e) + ) from e + + except httpx.TimeoutException as e: + raise WorkerConfigurationError( + worker_id, f"Request timed out: {e}" + ) from e + + except WorkerConfigurationError: + raise + + except Exception as e: + raise WorkerConfigurationError( + worker_id, f"Unexpected error: {str(e)}" + ) from e + + return worker_ids + def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: """ Get workers and wait for them to be ready. @@ -875,6 +949,9 @@ async def async_call_engine( url = f"http://{worker_info.worker.ip}:{port}/run_workflow" # Serialize kwargs for workflow execution payload = serialize_value(kwargs) + elif method == "configure": + url = f"http://{worker_info.worker.ip}:{port}/configure" + payload = serialize_value(kwargs) elif method == "export_stats": url = f"http://{worker_info.worker.ip}:{port}/export_stats" payload = None diff --git a/areal/scheduler/rpc/async_rpc_server.py b/areal/scheduler/rpc/async_rpc_server.py index 576362807..763687ee8 100644 --- a/areal/scheduler/rpc/async_rpc_server.py +++ b/areal/scheduler/rpc/async_rpc_server.py @@ -22,9 +22,10 @@ from fastapi import Body, FastAPI, HTTPException, Request from fastapi.responses import ORJSONResponse +from areal.api.cli_args import BaseExperimentConfig from areal.api.engine_api import InferenceEngine, TrainEngine from areal.scheduler.rpc.serialization import deserialize_value, serialize_value -from areal.utils import logging, stats_tracker +from areal.utils import logging, name_resolve, seeding, stats_tracker logger = logging.getLogger("RPCServer") @@ -287,6 +288,7 @@ async def run_workflow(request: Request): ) # Instantiate workflow try: + workflow_kwargs = deserialize_value(workflow_kwargs) workflow = workflow_class(**workflow_kwargs) logger.info(f"Workflow '{workflow_path}' instantiated successfully") except Exception as e: @@ -378,6 +380,42 @@ async def run_workflow(request: Request): raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") +@app.post("/configure") +async def configure(request: Request): + try: + body = await request.body() + data = orjson.loads(body) + + config = data.get("config") + if not config: + raise HTTPException( + status_code=400, detail="Missing 'config' field in request" + ) + role = data.get("role") + if not role: + raise HTTPException( + status_code=400, detail="Missing 'role' field in request" + ) + rank = data.get("rank") + if not rank: + raise HTTPException( + status_code=400, detail="Missing 'rank' field in request" + ) + + config = deserialize_value(config) + config: BaseExperimentConfig + + name_resolve.reconfigure(config.cluster.name_resolve) + + seeding.set_random_seed(config.seed, key=f"{role}{rank}") + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error in configure: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + @app.post("/export_stats") async def export_stats(data: dict[str, Any] | None = Body(None)): try: diff --git a/areal/scheduler/rpc/sync_rpc_server.py b/areal/scheduler/rpc/sync_rpc_server.py index 12b539e44..66643e6f4 100644 --- a/areal/scheduler/rpc/sync_rpc_server.py +++ b/areal/scheduler/rpc/sync_rpc_server.py @@ -5,10 +5,11 @@ from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any +from areal.api.cli_args import BaseExperimentConfig from areal.api.engine_api import TrainEngine from areal.platforms import current_platform from areal.scheduler.rpc.serialization import deserialize_value, serialize_value -from areal.utils import logging, stats_tracker +from areal.utils import logging, name_resolve, seeding, stats_tracker logger = logging.getLogger("SyncRPCServer") @@ -60,6 +61,8 @@ def do_POST(self) -> None: self._handle_call_engine_method() elif self.path == "/export_stats": self._handle_export_stats() + elif self.path == "/configure": + self._handle_configure() else: self._send_json_response({"error": f"Not found: {self.path}"}, 404) @@ -70,6 +73,41 @@ def _handle_health_check(self) -> None: {"status": "healthy", "engine_initialized": _engine is not None} ) + def _handle_configure(self) -> None: + try: + data = self._read_json_body() + if data is None: + return + + config = data.get("config") + if not config: + raise self._send_json_response( + {"error": "Missing 'config' field in request"}, 400 + ) + role = data.get("role") + if not role: + raise self._send_json_response( + {"error": "Missing 'role' field in request"}, 400 + ) + rank = data.get("rank") + if not rank: + raise self._send_json_response( + {"error": "Missing 'rank' field in request"}, 400 + ) + + config = deserialize_value(config) + config: BaseExperimentConfig + + name_resolve.reconfigure(config.cluster.name_resolve) + + seeding.set_random_seed(config.seed, key=f"{role}{rank}") + + except Exception as e: + logger.error( + f"Unexpected error in configure: {e}\n{traceback.format_exc()}" + ) + self._send_json_response({"error": f"Internal server error: {str(e)}"}, 500) + def _handle_create_engine(self) -> None: """ Create and initialize a TrainEngine instance on this worker. diff --git a/areal/utils/data.py b/areal/utils/data.py index 8d905613f..58d7f7654 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1,8 +1,9 @@ # Pad/unpad operations are modified from flash-attention under BSD-3 license. # Copyright (c) 2023, Tri Dao. +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any import numpy as np import torch @@ -18,7 +19,7 @@ logger = logging.getLogger("data utils") -def get_batch_size(data: Dict[str, Any]) -> int: +def get_batch_size(data: dict[str, Any]) -> int: if not data: return 0 @@ -41,18 +42,18 @@ def get_batch_size(data: Dict[str, Any]) -> int: return 0 -def reorder_list(xs: List, indices: List[int]) -> List: +def reorder_list(xs: list, indices: list[int]) -> list: assert len(set(indices)) == len(xs) return [xs[i] for i in indices] -def dict_map(x: Dict, fn: Callable) -> Dict: +def dict_map(x: dict, fn: Callable) -> dict: return {k: fn(v) for k, v in x.items()} def dict_of_list2list_of_dict( - dict_of_lists: Dict[str, List[Any]], -) -> List[Dict[str, Any]]: + dict_of_lists: dict[str, list[Any]], +) -> list[dict[str, Any]]: if not dict_of_lists: return [] keys = list(dict_of_lists.keys()) @@ -66,8 +67,8 @@ def dict_of_list2list_of_dict( def list_of_dict2dict_of_list( - list_of_dicts: List[Dict[str, Any]], -) -> Dict[str, List[Any]]: + list_of_dicts: list[dict[str, Any]], +) -> dict[str, list[Any]]: if not list_of_dicts: return {} keys = list(list_of_dicts[0].keys()) @@ -80,8 +81,8 @@ def list_of_dict2dict_of_list( def pad_sequences_to_tensors( - sequence_list: List[Dict[str, Any]], pad_value: float = 0.0 -) -> Dict[str, Any]: + sequence_list: list[dict[str, Any]], pad_value: float = 0.0 +) -> dict[str, Any]: if not sequence_list: return {} skip_keys = {"multi_modal_input"} @@ -130,7 +131,7 @@ def pad_sequences_to_tensors( def unpad_input( hidden_states, attention_mask -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() @@ -150,8 +151,8 @@ def pad_input(hidden_states, indices, batch, seqlen): def concat_padded_tensors( - tensor_dicts: List[Dict[str, Any]], pad_value: float = 0.0 -) -> Dict[str, Any]: + tensor_dicts: list[dict[str, Any]], pad_value: float = 0.0 +) -> dict[str, Any]: """Concatenate and pad tensors from multiple dictionaries of padded tensors.""" if not tensor_dicts: return {} @@ -223,8 +224,8 @@ def concat_padded_tensors( def unpack_sequence( x: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, - lens: Optional[List[int]] = None, + cu_seqlens: torch.Tensor | None = None, + lens: list[int] | None = None, dim: int = 0, ): """Unpack a sequence tensor into a list of tensors based on cumulative sequence lengths.""" @@ -237,7 +238,7 @@ def unpack_sequence( raise ValueError("Either cu_seqlens or input_lens must be provided.") -def allocate_balanced_mbs(mb_spec: MicroBatchSpec, lens: List[int]) -> List[List[int]]: +def allocate_balanced_mbs(mb_spec: MicroBatchSpec, lens: list[int]) -> list[list[int]]: assert mb_spec.max_tokens_per_mb is not None group_indices = datapack.ffd_allocate( lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs @@ -248,9 +249,9 @@ def allocate_balanced_mbs(mb_spec: MicroBatchSpec, lens: List[int]) -> List[List def allocate_balanced_mbs_synced( mb_spec: MicroBatchSpec, - lens: List[int], - group: Optional[dist.ProcessGroup] = None, -) -> List[List[int]]: + lens: list[int], + group: dist.ProcessGroup | None = None, +) -> list[list[int]]: group_indices = allocate_balanced_mbs(mb_spec, lens) if not dist.is_initialized(): return group_indices @@ -263,7 +264,7 @@ def allocate_balanced_mbs_synced( ) -def pack_tensor_dict(data: Dict[str, Any]) -> Dict[str, Any]: +def pack_tensor_dict(data: dict[str, Any]) -> dict[str, Any]: """Pack a dict of tensors of shape [B, S, ...] into [total_length, ...], leaving other keys unchanged. Args: @@ -315,12 +316,12 @@ def pack_tensor_dict(data: Dict[str, Any]) -> Dict[str, Any]: return packed_data -def pad_and_stack_tensors_along_first_dim(tensor_list: List[torch.Tensor]): +def pad_and_stack_tensors_along_first_dim(tensor_list: list[torch.Tensor]): max_length = max(tensor.shape[0] for tensor in tensor_list) n_dim = tensor_list[0].ndim - assert all( - tensor.ndim == n_dim for tensor in tensor_list - ), "All tensors must have the same number of dimensions." + assert all(tensor.ndim == n_dim for tensor in tensor_list), ( + "All tensors must have the same number of dimensions." + ) padded_tensors = [] for tensor in tensor_list: @@ -331,17 +332,19 @@ def pad_and_stack_tensors_along_first_dim(tensor_list: List[torch.Tensor]): def tensor_container_to( - d: Dict[str, Any] | torch.Tensor | List[torch.Tensor], *args, **kwargs + d: dict[str, Any] | torch.Tensor | list[torch.Tensor], *args, **kwargs ): """Apply `t.to(*args, **kwargs)` to all tensors in the dictionary. Support nested dictionaries. """ - new_dict = {} if torch.is_tensor(d): return d.to(*args, **kwargs) - elif isinstance(d, list): + + if isinstance(d, list): return [tensor_container_to(v, *args, **kwargs) for v in d] - elif isinstance(d, dict): + + if isinstance(d, dict): + new_dict = {} for key, value in d.items(): if isinstance(value, dict) or isinstance(value, list): new_dict[key] = tensor_container_to(value, *args, **kwargs) @@ -350,25 +353,25 @@ def tensor_container_to( else: new_dict[key] = value return new_dict - else: - raise ValueError(f"Unsupported type: {type(d)}") + + return d @dataclass class MicroBatchList: - data: Dict[str, Any] + data: dict[str, Any] mb_spec: MicroBatchSpec - mbs: List[Dict[str, Any]] - forward_indices: List[int] - backward_indices: List[int] - group_lens: List[int] - padded_mbs: List[Dict[str, Any]] | None = None + mbs: list[dict[str, Any]] + forward_indices: list[int] + backward_indices: list[int] + group_lens: list[int] + padded_mbs: list[dict[str, Any]] | None = None # Batch-level padding information - padding_lengths: List[int] | None = None - padded_to_lengths: List[int] | None = None + padding_lengths: list[int] | None = None + padded_to_lengths: list[int] | None = None # sequence-level padding information - align_to_lengths: List[int] | None = None - old_cu_seqlens_list: List[torch.Tensor] | None = None + align_to_lengths: list[int] | None = None + old_cu_seqlens_list: list[torch.Tensor] | None = None def to(self, *args, **kwargs): mbs = [tensor_container_to(mb, *args, **kwargs) for mb in self.mbs] @@ -402,9 +405,9 @@ def to(self, *args, **kwargs): def split_padded_tensor_dict_into_mb_list( - data: Dict[str, Any], + data: dict[str, Any], mb_spec: MicroBatchSpec, - group: Optional[dist.ProcessGroup] = None, + group: dist.ProcessGroup | None = None, ) -> MicroBatchList: """Split a padded dict of tensors into micro-batches based on the attention mask. @@ -417,9 +420,9 @@ def split_padded_tensor_dict_into_mb_list( MicroBatchList: A structure containing the split micro-batches and metadata. """ # TODO: should align sequences first and then split, needs refactor - assert ( - "attention_mask" in data - ), "Input data must be padded and contain 'attention_mask' key." + assert "attention_mask" in data, ( + "Input data must be padded and contain 'attention_mask' key." + ) if mb_spec.max_tokens_per_mb is None: mb_spec = MicroBatchSpec.new( mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB @@ -522,12 +525,12 @@ def _split(tensor): def pad_packed_tensor_dict( - data: Dict[str, Any], + data: dict[str, Any], pad_to_length: int, pad_value: float = 0.0, align_sequences: bool = False, - align_to_multiple_of: Optional[int] = None, -) -> Tuple[Dict[str, Any], int, torch.Tensor, int]: + align_to_multiple_of: int | None = None, +) -> tuple[dict[str, Any], int, torch.Tensor, int]: """Pad a packed dict of tensors to a specified length. This function assumes that the input data contains "cu_seqlens" and "max_seqlen" key, and all other tensors of shape [total_length, ] will be padded to `pad_to_length`. @@ -553,9 +556,9 @@ def pad_packed_tensor_dict( sequence_padded_data = {} align_to_length = None if align_sequences: - assert ( - align_to_multiple_of is not None - ), "align_to_multiple_of must be specified when align_sequences is True." + assert align_to_multiple_of is not None, ( + "align_to_multiple_of must be specified when align_sequences is True." + ) input_lens = cu_seqlens[1:] - cu_seqlens[:-1] batch_size = input_lens.shape[0] # Align sequences to an integer multiple of align_to_multiple_of @@ -642,9 +645,9 @@ def pad_packed_tensor_dict( # Pad batch pad_length = pad_to_length - total_length - assert ( - pad_length >= 0 - ), f"pad_to_length {pad_to_length} must be greater than or equal to total length {total_length}." + assert pad_length >= 0, ( + f"pad_to_length {pad_to_length} must be greater than or equal to total length {total_length}." + ) new_cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_to_length) new_max_seqlen = max(max_seqlen, pad_length) padded_data = {} @@ -687,7 +690,7 @@ def pad_mb_list( pad_value: float = 0.0, pad_to_maximum: bool = False, align_sequences: bool = False, - align_to_multiple_of: Optional[int] = None, + align_to_multiple_of: int | None = None, ) -> MicroBatchList: """Pad the micro-batch list to the maximum length or to a specific size to: 1. Reduce memory fragmentation. @@ -705,9 +708,9 @@ def pad_mb_list( MicroBatchList: The padded micro-batch list. """ if align_sequences: - assert ( - align_to_multiple_of is not None - ), "align_to_multiple_of must be specified when align_sequences is True." + assert align_to_multiple_of is not None, ( + "align_to_multiple_of must be specified when align_sequences is True." + ) padded_mb_inputs, pad_lengths = [], [] pad_to_lengths = [] old_cu_seqlens_list = [] @@ -717,10 +720,10 @@ def pad_mb_list( or mb_list.mb_spec.max_tokens_per_mb == DEFAULT_MAX_TOKENS_PER_MB ): logger.warning( - f"Unable to pad to maximum because max_tokens_per_mb is not properly set." + "Unable to pad to maximum because max_tokens_per_mb is not properly set." ) pad_to_maximum = False - for mb, l in zip(mb_list.mbs, mb_list.group_lens): + for mb, _len in zip(mb_list.mbs, mb_list.group_lens): if pad_to_maximum: pad_to_length = mb_list.mb_spec.max_tokens_per_mb else: @@ -728,7 +731,7 @@ def pad_mb_list( # Take hidden size 4096 with bf16 dtype as an example, # the batch size of a page is 256 pad_to_length = ( - (int(l) + N_TOKENS_PER_PAGE - 1) + (int(_len) + N_TOKENS_PER_PAGE - 1) // N_TOKENS_PER_PAGE * N_TOKENS_PER_PAGE ) @@ -756,8 +759,8 @@ def pad_mb_list( def unpad_logits( logits: torch.Tensor, padding_length: int, - cu_seqlens: Optional[torch.Tensor] = None, - old_cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, + old_cu_seqlens: torch.Tensor | None = None, ): # TODO: when using megatron, logits are in fp32, # create new logits in bucket to reduce peak memory usage @@ -785,8 +788,8 @@ def unpad_logits( def unsqueeze_packed_tensor_dict( - data: Dict[str, Any], -) -> Dict[str, Any]: + data: dict[str, Any], +) -> dict[str, Any]: assert "cu_seqlens" in data, "Input data must contain 'cu_seqlens' key." assert "max_seqlen" in data, "Input data must contain 'max_seqlen' key." @@ -820,7 +823,7 @@ def unsqueeze_mb_list( return mb_list -def amend_position_ids(data: Dict) -> Dict: +def amend_position_ids(data: dict) -> dict: assert "attention_mask" in data, "Input data must contain 'attention_mask' key." attn_mask = data["attention_mask"] @@ -912,9 +915,8 @@ def _flatten_pad_to_max_numel(x, shapes): return torch.nn.functional.pad(x.view(-1), (0, pad_size), value=0) -def all_gather_tensor_container(data, group=None) -> List: +def all_gather_tensor_container(data, group=None) -> list: if torch.is_tensor(data): - local_shape = list(data.shape) shapes = [None for _ in range(dist.get_world_size(group))] dist.all_gather_object(shapes, local_shape, group=group) @@ -1105,7 +1107,7 @@ def __init__(self, config: NormConfig): def __call__( self, x: torch.Tensor, - loss_mask: Optional[torch.Tensor] = None, + loss_mask: torch.Tensor | None = None, high_precision: bool = True, reduce_group=None, ) -> torch.Tensor: @@ -1206,7 +1208,7 @@ def __call__( @staticmethod def _compute_mean( x: torch.Tensor, - mask: Optional[torch.Tensor], + mask: torch.Tensor | None, high_precision: bool, leave_one_out: bool, all_reduce: bool, @@ -1263,7 +1265,7 @@ def _compute_mean( @staticmethod def _compute_std( x: torch.Tensor, - mask: Optional[torch.Tensor], + mask: torch.Tensor | None, mean: torch.Tensor, unbiased: bool, high_precision: bool, diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index 693505e61..9efb4dcfe 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -76,6 +76,7 @@ async def arun_episode(self, engine: InferenceEngine, data): module_path, fname = self.reward_fn.rsplit(".", 1) module = importlib.import_module(module_path) self.reward_fn = getattr(module, fname) + self.async_reward_fn = AsyncRewardWrapper(self.reward_fn) input_ids = self.get_input_ids_fn( self.data_extract_prompt_fn(data), self.tokenizer, self.enable_thinking From a8e75dec0615845ed686d031b7d401b9df170142 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Thu, 30 Oct 2025 19:34:16 +0800 Subject: [PATCH 29/52] except update weight --- areal/controller/rollout_controller.py | 139 ++++------------- areal/controller/train_controller.py | 196 +++++++++++++++--------- areal/core/local_inf_engine.py | 109 ++----------- areal/engine/sglang_local.py | 1 - areal/scheduler/local.py | 20 ++- areal/scheduler/rpc/async_rpc_server.py | 33 ++-- areal/scheduler/rpc/sync_rpc_server.py | 30 ++-- 7 files changed, 212 insertions(+), 316 deletions(-) diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index a010ec035..fc5a500d9 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -5,7 +5,6 @@ import random import time from collections.abc import Callable -from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from typing import Any @@ -90,9 +89,6 @@ def __init__( # Async task execution self.runner: AsyncTaskRunner | None = None - # Thread pool for weight updates - self.executor: ThreadPoolExecutor | None = None - # Logging self.logger = None @@ -158,9 +154,6 @@ def initialize( ) self.runner.initialize(logger=self.logger) - # Initialize thread pool for weight updates - self.executor = ThreadPoolExecutor(max_workers=alloc_mode.gen.dp_size) - # Initialize staleness manager for global capacity control max_concurrent_rollouts = ( self.config.max_concurrent_rollouts or self.config.consumer_batch_size @@ -241,11 +234,6 @@ def destroy(self): self.workers.clear() - # Shutdown executor - if self.executor is not None: - self.executor.shutdown(wait=True) - self.executor = None - self.logger.info("RolloutController destroyed") def get_capacity(self) -> int: @@ -592,104 +580,41 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: req=req, ) - def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: - """Initialize the weight update process group for distributed weight updates. - - This method should be called before performing any weight updates to ensure - that the necessary communication groups are set up correctly across all workers. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update, such as the - type of communication backend and allocation mode. - - Returns - ------- - Future[None] - A future object representing the asynchronous initialization operation. - """ - - async def _init_all_workers(): - tasks = [ - self.scheduler.async_call_engine( - worker_id=worker.id, - method="init_weights_update_group", - meta=meta, - ) - for worker in self.workers - ] - await asyncio.gather(*tasks) - - def init_all_workers(): - asyncio.run(_init_all_workers()) - - return self.executor.submit(init_all_workers) + async def init_weights_update_group(self, meta: WeightUpdateMeta) -> None: + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="init_weights_update_group", + meta=meta, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) - def update_weights_from_distributed( + async def update_weights_from_distributed( self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] - ) -> Future[None]: - """Update weights in the inference engine in a non-blocking manner from distributed memory. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - param_specs : list[ParamSpec] - A list of parameter specifications for the weights to be updated - - Returns - ------- - Future[None] - A future object representing the asynchronous weight update operation - """ - - async def _update_all_workers(): - tasks = [ - self.scheduler.call_engine( - worker_id=worker.id, - method="update_weights_from_distributed", - meta=meta, - param_specs=param_specs, - ) - for worker in self.workers - ] - await asyncio.gather(*tasks) - - def update_all_workers(): - asyncio.run(_update_all_workers()) - - return self.executor.submit(update_all_workers) - - def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: - """Update weights in the inference engine from disk in a non-blocking manner. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - - Returns - ------- - Future[None] - A future object representing the asynchronous weight update operation - """ - - async def _update_all_workers(): - tasks = [ - self.scheduler.call_engine( - worker_id=worker.id, - method="update_weights_from_disk", - meta=meta, - ) - for worker in self.workers - ] - await asyncio.gather(*tasks) - - def update_all_workers(): - asyncio.run(_update_all_workers()) + ): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="update_weights_from_distributed", + meta=meta, + param_specs=param_specs, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) - return self.executor.submit(update_all_workers) + async def update_weights_from_disk(self, meta: WeightUpdateMeta): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="update_weights_from_disk", + meta=meta, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) def set_version(self, version: int) -> None: """Set the current weight version in the inference engine. diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index 34b3e0903..819a9105e 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -125,14 +125,29 @@ def initialize( engine_path = f"{engine_class.__module__}.{engine_class.__name__}" # Create and initialize engines on workers - asyncio.run(self._async_create_engines(engine_path)) - asyncio.run(self._async_initialize_engines(ft_spec, **kwargs)) + self._run_async_task(self._async_create_engines(engine_path)) + self._run_async_task(self._async_initialize_engines(ft_spec, **kwargs)) # Identify DP head workers self._identify_dp_heads() self.logger.info("TrainController initialization complete") + def _run_async_task(self, task): + def _run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete(task) + finally: + new_loop.close() + + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor() as executor: + future = executor.submit(_run_in_thread) + return future.result() + async def _async_create_engines(self, engine_path: str): # Create engines on workers self.logger.info("Creating engines on workers...") @@ -187,7 +202,7 @@ async def _get_dp_head(): ] return await asyncio.gather(*tasks) - self.workers_is_dp_head = asyncio.run(_get_dp_head()) + self.workers_is_dp_head = self._run_async_task(_get_dp_head()) def destroy(self): """Destroy the controller and release GPU memory of models. @@ -209,11 +224,29 @@ def destroy(self): self.logger.info("TrainController destroyed") - def custom_function_call(self, method: str, *args, **kwargs): + def _custom_function_call(self, method: str, *args, **kwargs): """Dispatch method call to appropriate workers based on input type. If any argument is a DistributedBatch, split data. Call only DP heads. """ + dp_split_args, dp_split_kwargs = self._dispatch_inputs(*args, **kwargs) + results = self._run_async_task( + self._call_with_dispatched_inputs(method, dp_split_args, dp_split_kwargs) + ) + # Only remain data from DP head. + results = [r for idx, r in enumerate(results) if self.workers_is_dp_head[idx]] + return self._merge_results(results, method) + + async def _async_custom_function_call(self, method: str, *args, **kwargs): + dp_split_args, dp_split_kwargs = self._dispatch_inputs(*args, **kwargs) + results = await self._call_with_dispatched_inputs( + method, dp_split_args, dp_split_kwargs + ) + # Only remain data from DP head. + results = [r for idx, r in enumerate(results) if self.workers_is_dp_head[idx]] + return self._merge_results(results, method) + + def _dispatch_inputs(self, *args, **kwargs): # Find and split DistributedBatch arguments split_args = [] for arg in args: @@ -230,50 +263,51 @@ def custom_function_call(self, method: str, *args, **kwargs): split_kwargs[k] = self._align_batches_with_dp(v, rebalance=True) else: split_kwargs[k] = [v] * self.parallel_strategy.dp_size + return split_args, split_kwargs + async def _call_with_dispatched_inputs( + self, + method: str, + dp_split_args: list[list[Any]], + dp_worker_kwargs: list[dict[str, Any]], + ): # Call all workers. # ONLY DP head workers get their data slice. # Other workers will get data by broadcasting in RPC server. - async def _call_all(): - tasks = [] - dp_idx = 0 - for idx, worker in enumerate(self.workers): - if self.workers_is_dp_head[idx]: - # Get this worker's slice of each argument - worker_args = [splits[dp_idx] for splits in split_args] - worker_kwargs = { - k: splits[dp_idx] for k, splits in split_kwargs.items() - } - - # Convert DistributedBatch to dict for RPC - # FIXME: pass metadata instead of real tensors - worker_args = [ - arg.get_data() if isinstance(arg, DistributedBatch) else arg - for arg in worker_args - ] - worker_kwargs = { - k: v.get_data() if isinstance(v, DistributedBatch) else v - for k, v in worker_kwargs.items() - } - dp_idx += 1 - else: - worker_args = [] - worker_kwargs = {} - - tasks.append( - self.scheduler.async_call_engine( - worker.id, - method, - *worker_args, - **worker_kwargs, - ) - ) - return await asyncio.gather(*tasks) + tasks = [] + dp_idx = 0 + for idx, worker in enumerate(self.workers): + if self.workers_is_dp_head[idx]: + # Get this worker's slice of each argument + worker_args = [splits[dp_idx] for splits in dp_split_args] + worker_kwargs = { + k: splits[dp_idx] for k, splits in dp_worker_kwargs.items() + } + + # Convert DistributedBatch to dict for RPC + # FIXME: pass metadata instead of real tensors + worker_args = [ + arg.get_data() if isinstance(arg, DistributedBatch) else arg + for arg in worker_args + ] + worker_kwargs = { + k: v.get_data() if isinstance(v, DistributedBatch) else v + for k, v in worker_kwargs.items() + } + dp_idx += 1 + else: + worker_args = [] + worker_kwargs = {} - results = asyncio.run(_call_all()) - # Only remain data from DP head. - results = [r for idx, r in enumerate(results) if self.workers_is_dp_head[idx]] - return self._merge_results(results, method) + tasks.append( + self.scheduler.async_call_engine( + worker.id, + method, + *worker_args, + **worker_kwargs, + ) + ) + return await asyncio.gather(*tasks) def _merge_results(self, results, method): """Merge results from DP head workers based on result type. @@ -392,26 +426,34 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta): def _update_weights_from_disk(self, meta: WeightUpdateMeta): # Update all LocalInfEngine's local weight - fut = self.rollout.update_weights_from_disk(meta) - self.save( - SaveLoadMeta( - path=meta.path, - weight_format="hf", - with_optim=False, - tokenizer=None, - processor=None, - ) - ) - update_name = names.update_weights_from_disk( - self.config.experiment_name, - self.config.trial_name, - self.get_version(), - ) - name_resolve.add( - update_name, str(datetime.now().timestamp()), keepalive_ttl=120 + save_meta = SaveLoadMeta( + path=meta.path, + weight_format="hf", + with_optim=False, + tokenizer=None, + processor=None, ) - fut.result() + async def _actor_save(): + await self._async_custom_function_call("save", save_meta) + update_name = names.update_weights_from_disk( + self.config.experiment_name, + self.config.trial_name, + self.get_version(), + ) + name_resolve.add( + update_name, + str(datetime.now().timestamp()), + keepalive_ttl=120, + replace=True, + ) + + async def _run(): + rollout_load = self.rollout.update_weights_from_disk(meta) + actor_save = _actor_save() + await asyncio.gather(rollout_load, actor_save) + + self._run_async_task(_run()) def _check_rollout_engine_connected(self): """Validate that rollout engine has been connected via connect_engine().""" @@ -439,7 +481,7 @@ async def _call_all(): ] return await asyncio.gather(*tasks) - results = asyncio.run(_call_all()) + results = self._run_async_task(_call_all()) # stats have been aggregated and synchronized. return results[0] @@ -457,7 +499,7 @@ def train(self, mode: bool = True): TrainController Returns self for method chaining """ - self.custom_function_call("train", mode) + self._custom_function_call("train", mode) return self def eval(self): @@ -480,7 +522,7 @@ def set_version(self, version: int): version : int The weight version number to set """ - self.custom_function_call("set_version", version) + self._custom_function_call("set_version", version) def get_version(self) -> int: """Get the current weight version in the training engine. @@ -490,7 +532,7 @@ def get_version(self) -> int: int The current weight version number """ - return self.custom_function_call("get_version") + return self._custom_function_call("get_version") def save(self, meta: SaveLoadMeta): """Save model weights and optimizer states for later use. @@ -500,7 +542,7 @@ def save(self, meta: SaveLoadMeta): meta : SaveLoadMeta Metadata containing information about where and how to save """ - self.custom_function_call("save", meta) + self._custom_function_call("save", meta) def load(self, meta: SaveLoadMeta): """Load model weights and optimizer states from a file. @@ -510,7 +552,7 @@ def load(self, meta: SaveLoadMeta): meta : SaveLoadMeta Metadata containing information about where and how to load """ - self.custom_function_call("load", meta) + self._custom_function_call("load", meta) def step_lr_scheduler(self): """Step the learning rate scheduler. @@ -519,7 +561,7 @@ def step_lr_scheduler(self): (e.g., once per PPO step). It is separated from train_batch to allow for more flexible learning rate scheduling. """ - self.custom_function_call("step_lr_scheduler") + self._custom_function_call("step_lr_scheduler") def forward( self, @@ -553,7 +595,7 @@ def forward( Any or None The result produced by `post_hook` and `aggregate_fn`. """ - return self.custom_function_call( + return self._custom_function_call( "forward", input_=input_, output_seqlens=output_seqlens, @@ -593,7 +635,9 @@ def train_batch( Scalar statistics after training, e.g., the current learning rate, gradient norm, etc. """ - return self.custom_function_call("train_batch", input_, loss_fn, loss_weight_fn) + return self._custom_function_call( + "train_batch", input_, loss_fn, loss_weight_fn + ) def eval_batch( self, @@ -627,7 +671,7 @@ def eval_batch( A scalar loss or None. The evaluation statistics should be aggregated with `stats_tracker`. """ - return self.custom_function_call("eval_batch", input_, loss_fn, loss_weight_fn) + return self._custom_function_call("eval_batch", input_, loss_fn, loss_weight_fn) # ==================== SFT RPC WRAPPERS ==================== def train_lm( @@ -652,7 +696,7 @@ def train_lm( Dict[str, float] Scalar statistics after training """ - return self.custom_function_call("train_lm", input_, *args, **kwargs) + return self._custom_function_call("train_lm", input_, *args, **kwargs) def evaluate_lm( self, @@ -676,7 +720,7 @@ def evaluate_lm( torch.Tensor or None A scalar loss or None """ - return self.custom_function_call("evaluate_lm", input_, *args, **kwargs) + return self._custom_function_call("evaluate_lm", input_, *args, **kwargs) # ==================== PPO RPC WRAPPERS ==================== def compute_logp( @@ -698,7 +742,7 @@ def compute_logp( Any Log probabilities computed by the engine """ - return self.custom_function_call("compute_logp", *args, **kwargs) + return self._custom_function_call("compute_logp", *args, **kwargs) def compute_advantages( self, @@ -719,7 +763,7 @@ def compute_advantages( Any Advantages computed by the engine """ - return self.custom_function_call("compute_advantages", *args, **kwargs) + return self._custom_function_call("compute_advantages", *args, **kwargs) def ppo_update( self, @@ -737,4 +781,4 @@ def ppo_update( Dict[str, float] Scalar statistics after PPO update """ - return self.custom_function_call("ppo_update", input_) + return self._custom_function_call("ppo_update", input_) diff --git a/areal/core/local_inf_engine.py b/areal/core/local_inf_engine.py index cb58334a4..9e4964c67 100644 --- a/areal/core/local_inf_engine.py +++ b/areal/core/local_inf_engine.py @@ -2,7 +2,6 @@ import time import uuid from collections.abc import Callable -from concurrent.futures import Future, ThreadPoolExecutor from threading import Lock from typing import Any, Protocol @@ -203,7 +202,7 @@ def initialize( self.logger = logging.getLogger(f"[Local Inference Engine Rank {engine_id}]") # Initialize thread pool for non-blocking weight updates - self.executor = ThreadPoolExecutor(max_workers=1) + # FIXME: develop a principled update methods with/without thread pool # Initialize workflow executor self.workflow_executor = WorkflowExecutor( @@ -219,9 +218,6 @@ def destroy(self): if getattr(self, "workflow_executor"): self.workflow_executor.destroy() self.workflow_executor = None - if getattr(self, "executor"): - self.executor.shutdown() - self.executor = None def set_version(self, version: int): """Set the current weight version.""" @@ -339,19 +335,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: ) return response - def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: - """Initialize the weight update process group for distributed weight updates. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - - Returns - ------- - Future[None] - A future object representing the asynchronous initialization operation - """ + def init_weights_update_group(self, meta: WeightUpdateMeta) -> None: assert meta.type == current_platform.communication_backend assert not self.distributed_weight_update_initialized, ( "Weight update group already initialized." @@ -362,47 +346,18 @@ def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: "Local inference engine is not initialized, " "cannot init weight update group." ) - - # Compute rank offset for this engine - # For local engines, we assume single instance per process - rank_offset = 1 # Offset by 1 to leave rank 0 for the training engine - - fut = self.executor.submit( - self._init_weights_update_group_sync, meta, rank_offset - ) - - def callback(fut): - self.logger.info( - f"Initialized {current_platform.communication_backend.upper()} group " - f"for distributed weight update for {meta.nccl_group_name}." - ) - self.distributed_weight_update_initialized = True - - fut.add_done_callback(callback) - - return fut - - def _init_weights_update_group_sync(self, meta: WeightUpdateMeta, rank_offset: int): - """Synchronously initialize weight update group in thread pool.""" + # FIXME: get the real rank_offset from local process rank and tp size + rank_offset = 1 self.backend.init_update_weight_group(self.engine, meta, rank_offset) + self.logger.info( + f"Initialized {current_platform.communication_backend.upper()} group " + f"for distributed weight update for {meta.nccl_group_name}." + ) + self.distributed_weight_update_initialized = True def update_weights_from_distributed( self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] - ) -> Future[None]: - """Update weights in the inference engine from distributed memory. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - param_specs : List[ParamSpec] - A list of parameter specifications for the weights to be updated - - Returns - ------- - Future[None] - A future object representing the asynchronous weight update operation - """ + ): assert meta.type == current_platform.communication_backend if self.engine is None: @@ -410,31 +365,9 @@ def update_weights_from_distributed( "Local inference engine is not initialized, cannot update weights." ) - fut = self.executor.submit( - self._update_weights_from_distributed_sync, meta, param_specs - ) - - return fut - - def _update_weights_from_distributed_sync( - self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] - ): - """Synchronously update weights from distributed memory in thread pool.""" self.backend.update_weight_xccl(self.engine, meta, param_specs) - def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: - """Update weights in the inference engine from disk. - - Parameters - ---------- - meta : WeightUpdateMeta - Metadata containing information about the weight update - - Returns - ------- - Future[None] - A future object representing the asynchronous weight update operation - """ + def update_weights_from_disk(self, meta: WeightUpdateMeta): assert meta.type == "disk" if self.engine is None: @@ -442,30 +375,12 @@ def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: "Local inference engine is not initialized, cannot update weights." ) - tik = time.perf_counter() - # Validate experiment and trial names if self.config.experiment_name is None or self.config.trial_name is None: raise RuntimeError( "Experiment and trial names must be set for disk-based weight updates." ) - fut = self.executor.submit(self._update_weights_from_disk_sync, meta) - - def callback(fut): - respond_time = fut.result() - self.logger.info( - f"Loading weights from disk done in " - f"{(time.perf_counter() - tik):.2f}s. " - f"Respond time: {respond_time:.2f}s." - ) - - fut.add_done_callback(callback) - - return fut - - def _update_weights_from_disk_sync(self, meta: WeightUpdateMeta) -> float: - """Synchronously update weights from disk in thread pool.""" # Wait for training engine to signal that weights are ready update_name = names.update_weights_from_disk( self.config.experiment_name, @@ -487,8 +402,6 @@ def _update_weights_from_disk_sync(self, meta: WeightUpdateMeta) -> float: f"Loading weights done in {(time.time() - load_timestamp) * 1000:.2f} ms" ) - return load_timestamp - save_timestamp - def submit( self, data: dict[str, Any], diff --git a/areal/engine/sglang_local.py b/areal/engine/sglang_local.py index ee08bbd08..b565da12e 100644 --- a/areal/engine/sglang_local.py +++ b/areal/engine/sglang_local.py @@ -134,7 +134,6 @@ def update_weight_disk(self, engine: Any, model_path: str) -> None: model_path : str Path to the model weights on disk """ - # Call SGLang's update_weights_from_disk method engine.update_weights_from_disk(model_path=model_path) def update_weight_xccl( diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index 8e9ff9d7e..bee6318d1 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -94,12 +94,12 @@ def __init__( if log_dir is not None: self.log_dir = Path(log_dir) else: - assert experiment_name is not None - assert trial_name is not None - assert fileroot is not None experiment_name = experiment_name or exp_config.experiment_name trial_name = trial_name or exp_config.trial_name fileroot = fileroot or exp_config.cluster.fileroot + assert experiment_name is not None + assert trial_name is not None + assert fileroot is not None self.log_dir = ( Path(fileroot) / "logs" @@ -456,6 +456,8 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: # Send HTTP request to configure workers for worker_rank, worker_info in enumerate(workers): + while not self._is_worker_ready(worker_info): + time.sleep(0.1) worker_id = worker_info.worker.id port = int(worker_info.worker.worker_ports[0]) url = f"http://{worker_info.worker.ip}:{port}/configure" @@ -477,22 +479,21 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: ) if response.status_code == 200: - result = response.json() logger.info(f"Configuration successfully on worker '{worker_id}'") - return result.get("result") + continue elif response.status_code == 400: # Import error or bad request error_detail = response.json().get("detail", "Unknown error") - raise WorkerConfigurationError(worker_id, error_detail, 400) + raise WorkerConfigurationError(worker_id, error_detail, str(400)) elif response.status_code == 500: # Engine initialization failed error_detail = response.json().get("detail", "Unknown error") - raise WorkerConfigurationError(worker_id, error_detail, 500) + raise WorkerConfigurationError(worker_id, error_detail, str(500)) else: raise WorkerConfigurationError( worker_id, f"Unexpected status code: {response.status_code}", - response.status_code, + str(response.status_code), ) except httpx.ConnectError as e: @@ -949,9 +950,6 @@ async def async_call_engine( url = f"http://{worker_info.worker.ip}:{port}/run_workflow" # Serialize kwargs for workflow execution payload = serialize_value(kwargs) - elif method == "configure": - url = f"http://{worker_info.worker.ip}:{port}/configure" - payload = serialize_value(kwargs) elif method == "export_stats": url = f"http://{worker_info.worker.ip}:{port}/export_stats" payload = None diff --git a/areal/scheduler/rpc/async_rpc_server.py b/areal/scheduler/rpc/async_rpc_server.py index 763687ee8..70eeaba02 100644 --- a/areal/scheduler/rpc/async_rpc_server.py +++ b/areal/scheduler/rpc/async_rpc_server.py @@ -387,17 +387,17 @@ async def configure(request: Request): data = orjson.loads(body) config = data.get("config") - if not config: + if config is None: raise HTTPException( status_code=400, detail="Missing 'config' field in request" ) role = data.get("role") - if not role: + if role is None: raise HTTPException( status_code=400, detail="Missing 'role' field in request" ) rank = data.get("rank") - if not rank: + if rank is None: raise HTTPException( status_code=400, detail="Missing 'rank' field in request" ) @@ -406,9 +406,9 @@ async def configure(request: Request): config: BaseExperimentConfig name_resolve.reconfigure(config.cluster.name_resolve) - seeding.set_random_seed(config.seed, key=f"{role}{rank}") + return {"status": "success", "result": None} except HTTPException: raise except Exception as e: @@ -462,15 +462,22 @@ def main(): logger.info(f"Starting async RPC server on {args.host}:{args.port}") - # Run uvicorn server with a single worker (required for GPU workloads) - uvicorn.run( - app, - host=args.host, - port=args.port, - workers=1, # Single worker required for GPU memory management - log_level="info", - access_log=True, - ) + try: + # Run uvicorn server with a single worker (required for GPU workloads) + uvicorn.run( + app, + host=args.host, + port=args.port, + workers=1, # Single worker required for GPU memory management + log_level="info", + access_log=True, + ) + finally: + global _engine + if _engine is not None: + assert isinstance(_engine, InferenceEngine) + _engine.destroy_engine() + _engine.destroy() if __name__ == "__main__": diff --git a/areal/scheduler/rpc/sync_rpc_server.py b/areal/scheduler/rpc/sync_rpc_server.py index 66643e6f4..8d0260627 100644 --- a/areal/scheduler/rpc/sync_rpc_server.py +++ b/areal/scheduler/rpc/sync_rpc_server.py @@ -80,19 +80,19 @@ def _handle_configure(self) -> None: return config = data.get("config") - if not config: - raise self._send_json_response( - {"error": "Missing 'config' field in request"}, 400 + if config is None: + self._send_json_response( + {"detail": "Missing 'config' field in request"}, 400 ) role = data.get("role") - if not role: - raise self._send_json_response( - {"error": "Missing 'role' field in request"}, 400 + if role is None: + self._send_json_response( + {"detail": "Missing 'role' field in request"}, 400 ) rank = data.get("rank") - if not rank: - raise self._send_json_response( - {"error": "Missing 'rank' field in request"}, 400 + if rank is None: + self._send_json_response( + {"detail": "Missing 'rank' field in request"}, 400 ) config = deserialize_value(config) @@ -101,7 +101,13 @@ def _handle_configure(self) -> None: name_resolve.reconfigure(config.cluster.name_resolve) seeding.set_random_seed(config.seed, key=f"{role}{rank}") - + self._send_json_response( + { + "status": "success", + "message": "Worker configured successful.", + "result": None, + } + ) except Exception as e: logger.error( f"Unexpected error in configure: {e}\n{traceback.format_exc()}" @@ -333,6 +339,10 @@ def main(): except KeyboardInterrupt: logger.info("Shutting down sync RPC server") server.shutdown() + finally: + global _engine + if _engine is not None: + _engine.destroy() if __name__ == "__main__": From beeedd7acbcb2d66d33c45b662dc32f240c873b6 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Thu, 30 Oct 2025 21:43:58 +0800 Subject: [PATCH 30/52] grpo run --- areal/api/io_struct.py | 4 + areal/controller/rollout_controller.py | 99 +++++++------- areal/controller/train_controller.py | 49 ++++--- areal/core/local_inf_engine.py | 8 ++ areal/core/remote_inf_engine.py | 3 +- areal/engine/ppo/actor.py | 177 ++++++++++++------------- areal/engine/sglang_local.py | 41 ++++++ areal/engine/sglang_remote.py | 25 +++- areal/scheduler/local.py | 3 +- areal/scheduler/rpc/sync_rpc_server.py | 70 ++++++++-- 10 files changed, 293 insertions(+), 186 deletions(-) diff --git a/areal/api/io_struct.py b/areal/api/io_struct.py index 0f9552fc9..90fed028c 100644 --- a/areal/api/io_struct.py +++ b/areal/api/io_struct.py @@ -114,6 +114,8 @@ class WeightUpdateMeta: use_lora: bool = False + clear_checkpoint: bool = True + @classmethod def from_disk( cls, @@ -122,6 +124,7 @@ def from_disk( file_root: str, name: str = "default", use_lora: bool = False, + clear_checkpoint: bool = True, ) -> "WeightUpdateMeta": from areal.utils.saver import Saver @@ -133,6 +136,7 @@ def from_disk( type="disk", path=path, use_lora=use_lora, + clear_checkpoint=clear_checkpoint, ) @classmethod diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py index fc5a500d9..36b3d40a6 100644 --- a/areal/controller/rollout_controller.py +++ b/areal/controller/rollout_controller.py @@ -260,51 +260,6 @@ def _choose_worker(self) -> Worker: self._current_worker_idx = (self._current_worker_idx + 1) % len(self.workers) return worker - async def _run_workflow_on_worker( - self, - worker: Worker, - data: dict[str, Any], - workflow_path: str, - workflow_kwargs: dict[str, Any], - should_accept_path: str | None = None, - ) -> dict[str, Any] | None: - # Call run_workflow on worker via scheduler - # This will hit the /run_workflow endpoint - result = await self.scheduler.async_call_engine( - worker_id=worker.id, - method="run_workflow", - workflow=workflow_path, - workflow_kwargs=workflow_kwargs, - data=data, - should_accept_path=should_accept_path, - check_trajectory_format=self.config.check_trajectory_format, - ) - - # The RPCServer will return None if the - # trajectory is rejected. - if result is not None: - self.staleness_manager.on_rollout_accepted() - if self.config.enable_rollout_tracing: - stat = self.staleness_manager.get_stats() - self.logger.info( - f"Finish and accept rollout. " - f"Submit: {stat.submitted}, " - f"running: {stat.running}, " - f"accepted: {stat.accepted}." - ) - return result - else: - self.staleness_manager.on_rollout_rejected() - if self.config.enable_rollout_tracing: - stat = self.staleness_manager.get_stats() - self.logger.info( - f"Finish but reject rollout. " - f"Submit: {stat.submitted}, " - f"running: {stat.running}, " - f"accepted: {stat.accepted}." - ) - return None - def submit( self, data: dict[str, Any], @@ -340,6 +295,40 @@ def submit( ) ) + async def _wait_callback(self, worker: Worker): + # Wait for a generation to return + result = "NO_RESULT" + tik = time.time() + while result == "NO_RESULT" and time.time() - tik < self.config.request_timeout: + result = await self.scheduler.async_call_engine( + worker.id, "wait_quiet", count=1, timeout=1, max_retries=1 + ) + + # The RPCServer will return None if the + # trajectory is rejected. + if result is not None: + self.staleness_manager.on_rollout_accepted() + if self.config.enable_rollout_tracing: + stat = self.staleness_manager.get_stats() + self.logger.info( + f"Finish and accept rollout. " + f"Submit: {stat.submitted}, " + f"running: {stat.running}, " + f"accepted: {stat.accepted}." + ) + return result + else: + self.staleness_manager.on_rollout_rejected() + if self.config.enable_rollout_tracing: + stat = self.staleness_manager.get_stats() + self.logger.info( + f"Finish but reject rollout. " + f"Submit: {stat.submitted}, " + f"running: {stat.running}, " + f"accepted: {stat.accepted}." + ) + return None + def _commit_one_to_runner(self): """Commit one pending input to task runner with staleness tracking.""" task_input = self._pending_inputs.pop(0) @@ -347,15 +336,20 @@ def _commit_one_to_runner(self): # Choose worker via round-robin worker = self._choose_worker() - # Submit to AsyncTaskRunner + self.scheduler.call_engine( + worker.id, + "submit", + data=task_input.data, + workflow_path=task_input.workflow_path, + workflow_kwargs=task_input.workflow_kwargs, + should_accept_path=task_input.should_accept_path, + ) + + # Submit a wait callback to AsyncTaskRunner try: self.runner.submit( - self._run_workflow_on_worker, + self._wait_callback, worker, - task_input.data, - task_input.workflow_path, - task_input.workflow_kwargs, - task_input.should_accept_path, ) except TaskQueueFullError: raise queue.Full("Input queue full") @@ -600,6 +594,7 @@ async def update_weights_from_distributed( method="update_weights_from_distributed", meta=meta, param_specs=param_specs, + max_retries=1, ) for worker in self.workers ] @@ -611,6 +606,7 @@ async def update_weights_from_disk(self, meta: WeightUpdateMeta): worker_id=worker.id, method="update_weights_from_disk", meta=meta, + max_retries=1, ) for worker in self.workers ] @@ -634,6 +630,7 @@ def set_version(self, version: int) -> None: worker_id=worker.id, method="set_version", version=version, + max_retries=1, ) except Exception as e: self.logger.error(f"Error setting version for worker {worker.id}: {e}") diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index 819a9105e..d6d6e80a4 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -1,4 +1,5 @@ import asyncio +import shutil from collections.abc import Callable from copy import deepcopy from datetime import datetime @@ -426,34 +427,30 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta): def _update_weights_from_disk(self, meta: WeightUpdateMeta): # Update all LocalInfEngine's local weight - save_meta = SaveLoadMeta( - path=meta.path, - weight_format="hf", - with_optim=False, - tokenizer=None, - processor=None, - ) - - async def _actor_save(): - await self._async_custom_function_call("save", save_meta) - update_name = names.update_weights_from_disk( - self.config.experiment_name, - self.config.trial_name, - self.get_version(), + self.save( + SaveLoadMeta( + path=meta.path, + weight_format="hf", + with_optim=False, + tokenizer=None, + processor=None, ) - name_resolve.add( - update_name, - str(datetime.now().timestamp()), - keepalive_ttl=120, - replace=True, - ) - - async def _run(): - rollout_load = self.rollout.update_weights_from_disk(meta) - actor_save = _actor_save() - await asyncio.gather(rollout_load, actor_save) + ) + update_name = names.update_weights_from_disk( + self.config.experiment_name, + self.config.trial_name, + self.get_version(), + ) + name_resolve.add( + update_name, + str(datetime.now().timestamp()), + keepalive_ttl=120, + replace=True, + ) - self._run_async_task(_run()) + meta.clear_checkpoint = False + asyncio.run(self.rollout.update_weights_from_disk(meta)) + shutil.rmtree(meta.path, ignore_errors=True) def _check_rollout_engine_connected(self): """Validate that rollout engine has been connected via connect_engine().""" diff --git a/areal/core/local_inf_engine.py b/areal/core/local_inf_engine.py index 9e4964c67..f19d3ed17 100644 --- a/areal/core/local_inf_engine.py +++ b/areal/core/local_inf_engine.py @@ -446,6 +446,14 @@ def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: """ return self.workflow_executor.wait(count, timeout=timeout) + def wait_quiet( + self, count: int, timeout: float | None = None + ) -> dict[str, Any] | None: + try: + return self.workflow_executor.wait(count, timeout=timeout) + except TimeoutError: + return "NO_RESULT" + def rollout_batch( self, data: list[dict[str, Any]], diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py index 44f82f52b..8e1050210 100644 --- a/areal/core/remote_inf_engine.py +++ b/areal/core/remote_inf_engine.py @@ -620,7 +620,8 @@ def callback(fut): # Update LoRA state if this was a LoRA update if meta.use_lora: self.lora_initialized = True - shutil.rmtree(meta.path, ignore_errors=True) + if meta.clear_checkpoint: + shutil.rmtree(meta.path, ignore_errors=True) fut.add_done_callback(callback) diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index c35527a04..c8a0d7ed3 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, List, Optional +from typing import Any import torch @@ -51,8 +51,8 @@ def __init__(self, config: PPOActorConfig, engine: TrainEngine): @torch.no_grad() def compute_logp( self, - data: Dict[str, Any], - temperature: Optional[float] = None, + data: dict[str, Any], + temperature: float | None = None, ) -> torch.Tensor | None: def calc_logprobs(logits, input_data): labels = input_data.get( @@ -69,7 +69,7 @@ def calc_logprobs(logits, input_data): aggregate_fn=lambda xs: torch.cat(xs, dim=-1), ) - def compute_advantages(self, data: Dict[str, Any]) -> Dict[str, Any]: + def compute_advantages(self, data: dict[str, Any]) -> dict[str, Any]: bs = data["input_ids"].shape[0] max_seqlen = data["input_ids"].shape[1] batch_indices = torch.arange( @@ -165,7 +165,7 @@ def compute_advantages(self, data: Dict[str, Any]) -> Dict[str, Any]: return data - def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: + def ppo_update(self, data: dict[str, Any]) -> list[dict[str, float]]: if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: data, sampling_stat = dynamic_sampling(data, self.group_size) @@ -174,78 +174,69 @@ def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: reward_score = data["rewards"] seqlens = attn_mask.sum(-1) - all_stats = [] ########## Logging code starts ########## - result_denominators = { - "correct_n_seqs": (reward_score > 0).bool(), - "incorrect_n_seqs": (reward_score <= 0).bool(), - } - if self.config.log_agent_stats: - assert ( - "begin_of_trajectory" in data - ), "'begin_of_trajectory' is expected to log agent statistics" - assert ( - len(self.config.log_agent_stats_keys) > 0 - ), "`log_agent_stats_keys` should not be empty when log_agent_stats=True" - agent_denominator = (data["begin_of_trajectory"] > 0).bool() - result_denominators["agent"] = agent_denominator - global_denominators = dict( - n_seqs=torch.ones_like(reward_score, dtype=torch.bool), - n_tokens=torch.ones_like(loss_mask, dtype=torch.bool), - n_valid_tokens=loss_mask.bool(), - **result_denominators, - ) - stats_tracker.denominator(**global_denominators) - stats_tracker.stat( - correct_seq_len=seqlens.float(), denominator="correct_n_seqs" - ) - stats_tracker.stat( - incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs" - ) - - stats = dict( - advantages=data["advantages"], - kl_rewards=data["kl_rewards"], - final_reward=data["tot_rewards"], - ) - stats_tracker.stat(**stats, denominator="n_valid_tokens") - - prompt_lens = [] - prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1) - seq_stats = dict( - no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(), - task_reward=reward_score.float(), - prompt_len=prompt_lens.float(), - seq_len=seqlens.float(), - ) - stats_tracker.stat(**seq_stats, denominator="n_seqs") - scalars = dict( - mask_no_eos_with_zero=self.config.mask_no_eos_with_zero, - eps_clip=self.config.eps_clip, - ) - if self.config.c_clip is not None: - scalars["c_clip"] = self.config.c_clip - scalars["use_dual_clip"] = 1 - else: - scalars["use_dual_clip"] = 0 - if self.config.behav_imp_weight_cap is not None: - scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap - stats_tracker.scalar(**scalars) - - if self.config.log_agent_stats: + with stats_tracker.scope("ppo_actor"): + result_denominators = { + "correct_n_seqs": (reward_score > 0).bool(), + "incorrect_n_seqs": (reward_score <= 0).bool(), + } + if self.config.log_agent_stats: + assert "begin_of_trajectory" in data, ( + "'begin_of_trajectory' is expected to log agent statistics" + ) + assert len(self.config.log_agent_stats_keys) > 0, ( + "`log_agent_stats_keys` should not be empty when log_agent_stats=True" + ) + agent_denominator = (data["begin_of_trajectory"] > 0).bool() + result_denominators["agent"] = agent_denominator + global_denominators = dict( + n_seqs=torch.ones_like(reward_score, dtype=torch.bool), + n_tokens=torch.ones_like(loss_mask, dtype=torch.bool), + n_valid_tokens=loss_mask.bool(), + **result_denominators, + ) + stats_tracker.denominator(**global_denominators) + stats_tracker.stat( + correct_seq_len=seqlens.float(), denominator="correct_n_seqs" + ) stats_tracker.stat( - **{k: data[k].float() for k in self.config.log_agent_stats_keys}, - denominator="agent", + incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs" ) - global_stats = stats_tracker.export( - reduce_group=self.engine.data_parallel_group - ) - for k in global_denominators: - keys = list(global_stats.keys()) - for k2 in keys: - if k2.endswith(k): - global_stats.pop(k2) + stats = dict( + advantages=data["advantages"], + kl_rewards=data["kl_rewards"], + final_reward=data["tot_rewards"], + ) + stats_tracker.stat(**stats, denominator="n_valid_tokens") + + prompt_lens = [] + prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1) + seq_stats = dict( + no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(), + task_reward=reward_score.float(), + prompt_len=prompt_lens.float(), + seq_len=seqlens.float(), + ) + stats_tracker.stat(**seq_stats, denominator="n_seqs") + scalars = dict( + mask_no_eos_with_zero=self.config.mask_no_eos_with_zero, + eps_clip=self.config.eps_clip, + ) + if self.config.c_clip is not None: + scalars["c_clip"] = self.config.c_clip + scalars["use_dual_clip"] = 1 + else: + scalars["use_dual_clip"] = 0 + if self.config.behav_imp_weight_cap is not None: + scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap + stats_tracker.scalar(**scalars) + + if self.config.log_agent_stats: + stats_tracker.stat( + **{k: data[k].float() for k in self.config.log_agent_stats_keys}, + denominator="agent", + ) ########## Logging code ends ########## for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]: @@ -257,24 +248,21 @@ def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches), ) for mb in mb_inputs.mbs: - train_stat = self.engine.train_batch( - mb, - loss_fn=functools.partial( - grpo_loss_fn, - temperature=self.temperature, - eps_clip=self.config.eps_clip, - eps_clip_higher=self.config.eps_clip_higher, - c_clip=self.config.c_clip, - behav_imp_weight_cap=self.config.behav_imp_weight_cap, - ), - loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), - ) - stats_tracker.scalar(**train_stat) - all_stats.append( - stats_tracker.export(reduce_group=self.engine.data_parallel_group) - ) - all_stats[0].update(global_stats) - return all_stats + with stats_tracker.scope("ppo_update"): + train_stat = self.engine.train_batch( + mb, + loss_fn=functools.partial( + grpo_loss_fn, + temperature=self.temperature, + eps_clip=self.config.eps_clip, + eps_clip_higher=self.config.eps_clip_higher, + c_clip=self.config.c_clip, + behav_imp_weight_cap=self.config.behav_imp_weight_cap, + ), + loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), + ) + stats_tracker.scalar(**train_stat) + return {} class FSDPPPOActor(FSDPEngine): @@ -290,12 +278,11 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None: def compute_advantages(self, *args, **kwargs): return self.actor.compute_advantages(*args, **kwargs) - def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]: + def ppo_update(self, *args, **kwargs) -> list[dict[str, float]]: return self.actor.ppo_update(*args, **kwargs) class MegatronPPOActor(MegatronEngine): - def __init__(self, config: PPOActorConfig): super().__init__(config) self.actor = PPOActor(config, self) @@ -308,13 +295,13 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None: def compute_advantages(self, *args, **kwargs) -> None: self.actor.compute_advantages(*args, **kwargs) - def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]: + def ppo_update(self, *args, **kwargs) -> list[dict[str, float]]: return self.actor.ppo_update(*args, **kwargs) def grpo_loss_fn( logits: torch.Tensor, - input_data: Dict, + input_data: dict, temperature: float, eps_clip: float, eps_clip_higher: float | None, diff --git a/areal/engine/sglang_local.py b/areal/engine/sglang_local.py index b565da12e..d7befda08 100644 --- a/areal/engine/sglang_local.py +++ b/areal/engine/sglang_local.py @@ -134,7 +134,43 @@ def update_weight_disk(self, engine: Any, model_path: str) -> None: model_path : str Path to the model weights on disk """ + # otherwise will encounter" eventloop is already running" issue + # def _run_in_thread(): + # print(11111111111111) + # # Call SGLang's update_weights_from_disk method + # try: + # cur_loop = asyncio.get_event_loop() + # except RuntimeError: + # cur_loop = None + # loop = asyncio.new_event_loop() + # asyncio.set_event_loop(loop) + # try: + # # Call SGLang's update_weights_from_distributed method + # from sglang.srt.managers.io_struct import ( + # UpdateWeightFromDiskReqInput, + # ) engine.update_weights_from_disk(model_path=model_path) + # obj = UpdateWeightFromDiskReqInput( + # model_path=model_path, + # abort_all_requests=False, + # ) + + # loop = asyncio.get_running_loop() + # future = asyncio.run_coroutine_threadsafe( + # engine.tokenizer_manager.update_weights_from_disk(obj, None), + # loop + # ) + # return future.result() # This blocks until complete + # print(2222222222, flush=True) + # finally: + # asyncio.set_event_loop(cur_loop) + # loop.close() + + # from concurrent.futures import ThreadPoolExecutor + + # with ThreadPoolExecutor() as executor: + # future = executor.submit(_run_in_thread) + # _ = future.result() def update_weight_xccl( self, @@ -325,3 +361,8 @@ def pause(self): def resume(self): """Resume request submission for async rollout.""" return self._engine.resume() + + def wait_quiet( + self, count: int, timeout: float | None = None + ) -> dict[str, Any] | None: + return self._engine.wait_quiet(count=count, timeout=timeout) diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 8f4f39d90..37d973176 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -1,10 +1,11 @@ +import os from collections.abc import Callable from concurrent.futures import Future from typing import Any, Optional from torchdata.stateful_dataloader import StatefulDataLoader -from areal.api.cli_args import InferenceEngineConfig +from areal.api.cli_args import InferenceEngineConfig, SGLangConfig from areal.api.engine_api import InferenceEngine from areal.api.io_struct import ( HttpGenerationResult, @@ -17,7 +18,9 @@ ) from areal.api.workflow_api import RolloutWorkflow from areal.core import RemoteInfEngine +from areal.launcher.sglang_server import launch_server_cmd, wait_for_server from areal.platforms import current_platform +from areal.utils.network import find_free_ports, gethostip class SGLangBackend: @@ -188,13 +191,19 @@ def __init__(self, config: InferenceEngineConfig): # Pure composition - create internal engine with SGLang backend self._engine = RemoteInfEngine(config, SGLangBackend()) + def create_engine(self, engine_args): + engine_args["host"] = host_ip = gethostip() + engine_args["port"] = server_port = find_free_ports(1)[0] + cmd = SGLangConfig.build_cmd_from_args(engine_args) + self.server_process = launch_server_cmd(cmd) + wait_for_server(f"http://{host_ip}:{server_port}") + print(f"SGLang server launched at: http://{host_ip}:{server_port}") + os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host_ip}:{server_port}" + def configure(self, config): self.config = config self._engine.configure(config) - def create_engine(self, *args, **kwargs): - return self._engine.create_engine(*args, **kwargs) - def destroy_engine(self, *args, **kwargs): return self._engine.destroy_engine(*args, **kwargs) @@ -286,3 +295,11 @@ def pause_generation(self): def continue_generation(self): return self._engine.continue_generation() + + def wait_quiet( + self, count: int, timeout: float | None = None + ) -> dict[str, Any] | None: + try: + return self._engine.wait(count, timeout=timeout) + except TimeoutError: + return "NO_RESULT" diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index bee6318d1..3865daedc 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -126,6 +126,7 @@ def __init__( self._allocated_ports = set() # HTTP clients for RPC communication + # FIXME: httpx may encounter "all connection attempts failed error" self._http_client = httpx.Client(timeout=3600.0) # Sync client - 1 hour timeout self._async_http_client = httpx.AsyncClient(timeout=3600.0) # Async client @@ -978,7 +979,7 @@ async def async_call_engine( ) try: - logger.debug( + logger.info( f"Async calling method '{method}' on worker '{worker_id}' (attempt {attempt})" ) diff --git a/areal/scheduler/rpc/sync_rpc_server.py b/areal/scheduler/rpc/sync_rpc_server.py index 8d0260627..a581ee0c9 100644 --- a/areal/scheduler/rpc/sync_rpc_server.py +++ b/areal/scheduler/rpc/sync_rpc_server.py @@ -2,11 +2,12 @@ import importlib import json import traceback +from concurrent.futures import Future from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any from areal.api.cli_args import BaseExperimentConfig -from areal.api.engine_api import TrainEngine +from areal.api.engine_api import InferenceEngine, TrainEngine from areal.platforms import current_platform from areal.scheduler.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging, name_resolve, seeding, stats_tracker @@ -14,7 +15,7 @@ logger = logging.getLogger("SyncRPCServer") # Global engine instance - must be TrainEngine -_engine: TrainEngine | None = None +_engine: TrainEngine | InferenceEngine | None = None class SyncRPCHandler(BaseHTTPRequestHandler): @@ -150,10 +151,12 @@ def _handle_create_engine(self) -> None: engine_class = getattr(module, class_name) # Validate that the class is a TrainEngine - if not issubclass(engine_class, TrainEngine): + if not issubclass(engine_class, TrainEngine) and not issubclass( + engine_class, InferenceEngine + ): raise TypeError( - f"Engine class must be a subclass of TrainEngine, " - f"got {engine_class}. Use async_rpc_server for InferenceEngine." + f"Engine class must be a subclass of TrainEngine or InferenceEngine, " + f"got {engine_class}.." ) except (ValueError, ImportError, AttributeError) as e: logger.error(f"Failed to import engine '{engine_path}': {e}") @@ -194,7 +197,7 @@ def _handle_create_engine(self) -> None: def _handle_call_engine_method(self) -> None: """ - Call a method on the TrainEngine instance. + Call a method on the engine instance. Expected JSON payload: { @@ -232,7 +235,7 @@ def _handle_call_engine_method(self) -> None: try: should_bcast = kwargs.pop("_should_bcast", True) - if should_bcast: + if should_bcast and isinstance(_engine, TrainEngine): logger.info( f"Broadcasting data for TrainEngine method: {method_name}" ) @@ -241,7 +244,6 @@ def _handle_call_engine_method(self) -> None: tensor_container_to, ) - # TODO: to device here args = tensor_container_to(args, current_platform.current_device()) args = broadcast_tensor_container( args, @@ -266,6 +268,53 @@ def _handle_call_engine_method(self) -> None: ) return + # Special case for `submit` on infernece engines + try: + if method_name == "submit" and isinstance(_engine, InferenceEngine): + workflow_path = kwargs["workflow_path"] + workflow_kwargs = kwargs["workflow_kwargs"] + episode_data = kwargs["data"] + should_accept_path = kwargs["should_accept_path"] + + # Deserialize episode_data (may contain tensors) + episode_data = deserialize_value(episode_data) + + # Dynamic import workflow + module_path, class_name = workflow_path.rsplit(".", 1) + module = importlib.import_module(module_path) + workflow_class = getattr(module, class_name) + logger.info(f"Imported workflow class: {workflow_path}") + + # Instantiate workflow + workflow_kwargs = deserialize_value(workflow_kwargs) + workflow = workflow_class(**workflow_kwargs) + logger.info(f"Workflow '{workflow_path}' instantiated successfully") + + should_accept = None + if should_accept_path is not None: + # Dynamic import filtering function + module_path, fn_name = should_accept_path.rsplit(".", 1) + module = importlib.import_module(module_path) + should_accept = getattr(module, fn_name) + logger.info( + f"Imported filtering function: {should_accept_path}" + ) + + args = [] + kwargs = dict( + data=episode_data, + workflow=workflow, + should_accept=should_accept, + ) + except Exception as e: + logger.error( + f"Worklow data conversion failed: {e}\n{traceback.format_exc()}" + ) + self._send_json_response( + {"error": f"workflow data conversion failed: {str(e)}"}, 500 + ) + return + # Call method directly logger.info(f"Calling engine method: {method_name}") try: @@ -273,6 +322,10 @@ def _handle_call_engine_method(self) -> None: method = getattr(_engine, method_name) result = method(*args, **kwargs) + # HACK: handle update weights future + if isinstance(result, Future): + result = result.result() + # Serialize result (convert tensors to SerializedTensor dicts) serialized_result = serialize_value(result) self._send_json_response( @@ -305,6 +358,7 @@ def _handle_export_stats(self) -> None: return # TrainEngine: reduce stats across data_parallel_group + assert isinstance(_engine, TrainEngine) result = stats_tracker.export(reduce_group=_engine.data_parallel_group) self._send_json_response({"status": "success", "result": result}) From ce23d475c6d3ac8eebf0631f2bef0e7cfc6874bd Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Fri, 31 Oct 2025 15:16:48 +0800 Subject: [PATCH 31/52] update to flask rpc server --- areal/api/cli_args.py | 4 +- areal/scheduler/local.py | 4 +- areal/scheduler/rpc/async_rpc_server.py | 484 --------------------- areal/scheduler/rpc/rpc_server.py | 497 ++++++++++------------ areal/scheduler/rpc/sync_rpc_server.py | 403 ------------------ examples/single-controller/gsm8k_sft.yaml | 2 +- 6 files changed, 240 insertions(+), 1154 deletions(-) delete mode 100644 areal/scheduler/rpc/async_rpc_server.py delete mode 100644 areal/scheduler/rpc/sync_rpc_server.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 6e0bc80e7..4b439b562 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -433,7 +433,7 @@ class TrainEngineConfig: ) scheduling_spec: SchedulingSpec = field( default_factory=lambda: SchedulingSpec( - cmd="python -m areal.scheduler.rpc.sync_rpc_server" + cmd="python -m areal.scheduler.rpc.rpc_server" ), metadata={"help": "train engine schedule specs"}, ) @@ -910,7 +910,7 @@ class InferenceEngineConfig: ) scheduling_spec: SchedulingSpec = field( default_factory=lambda: SchedulingSpec( - cmd="python -m areal.scheduler.rpc.async_rpc_server" + cmd="python -m areal.scheduler.rpc.rpc_server" ), metadata={"help": "inference engine schedule specs"}, ) diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py index 3865daedc..4cdcb5b55 100644 --- a/areal/scheduler/local.py +++ b/areal/scheduler/local.py @@ -371,8 +371,8 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: raise WorkerCreationError( role, f"SchedulingSpec.cmd is required but not set for worker {worker_id}", - "Specify either 'python -m areal.scheduler.rpc.async_rpc_server' or " - "'python -m areal.scheduler.rpc.sync_rpc_server' in your config.", + "Specify either 'python -m areal.scheduler.rpc.rpc_server' or " + "'python -m areal.scheduler.rpc.rpc_server' in your config.", ) cmd = shlex.split(scheduling.cmd) diff --git a/areal/scheduler/rpc/async_rpc_server.py b/areal/scheduler/rpc/async_rpc_server.py deleted file mode 100644 index 70eeaba02..000000000 --- a/areal/scheduler/rpc/async_rpc_server.py +++ /dev/null @@ -1,484 +0,0 @@ -"""Async FastAPI-based RPC server for InferenceEngine workers. - -This server runs on worker nodes to expose InferenceEngine methods via HTTP/JSON RPC. -It uses safe JSON serialization instead of cloudpickle and supports async workflow -execution via the /run_workflow endpoint. - -Key differences from sync_rpc_server: -- Multi-threaded: Uses FastAPI/uvicorn with async support -- InferenceEngine: Primarily for InferenceEngine (async rollout generation) -- Has /run_workflow: Supports direct workflow execution -- All async endpoints: All HTTP handlers are async functions -""" - -import argparse -import importlib -import traceback -from contextlib import asynccontextmanager -from typing import Any - -import orjson -import uvicorn -from fastapi import Body, FastAPI, HTTPException, Request -from fastapi.responses import ORJSONResponse - -from areal.api.cli_args import BaseExperimentConfig -from areal.api.engine_api import InferenceEngine, TrainEngine -from areal.scheduler.rpc.serialization import deserialize_value, serialize_value -from areal.utils import logging, name_resolve, seeding, stats_tracker - -logger = logging.getLogger("RPCServer") - -# Global engine instance - must be TrainEngine or InferenceEngine -_engine: TrainEngine | InferenceEngine | None = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events.""" - # Startup - logger.info("RPC server starting up...") - yield - # Shutdown - global _engine - logger.info("Shutting down RPC server...") - if _engine is not None: - try: - # Call destroy method if available - if hasattr(_engine, "destroy"): - _engine.destroy() - logger.info("Engine destroyed successfully") - except Exception as e: - logger.error(f"Error destroying engine: {e}") - _engine = None - - -app = FastAPI( - title="AReaL Worker RPC Server", - description="FastAPI-based RPC server for remote engine operations", - default_response_class=ORJSONResponse, - lifespan=lifespan, -) -app._expected_trajectory_keys = None - - -@app.get("/health") -async def health_check(): - """Health check endpoint to verify server is alive.""" - return {"status": "healthy", "engine_initialized": _engine is not None} - - -@app.post("/create_engine") -async def create_engine(data: dict[str, Any] = Body(...)): - """ - Create and initialize an engine instance on this worker. - - Expected JSON payload: - { - "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path - "init_args": [...], # Positional arguments - "init_kwargs": {...} # Keyword arguments - } - """ - global _engine - - try: - engine_path = data.get("engine") - # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) - init_args = deserialize_value(data.get("init_args", [])) - init_kwargs = deserialize_value(data.get("init_kwargs", {})) - - if not engine_path: - raise HTTPException( - status_code=400, detail="Missing 'engine' field in request" - ) - - # Dynamic import - try: - module_path, class_name = engine_path.rsplit(".", 1) - module = importlib.import_module(module_path) - engine_class = getattr(module, class_name) - - # Validate that the class is a TrainEngine or InferenceEngine - if not ( - issubclass(engine_class, TrainEngine) - or issubclass(engine_class, InferenceEngine) - ): - raise TypeError( - f"Engine class must be a subclass of TrainEngine or InferenceEngine, " - f"got {engine_class}" - ) - except (ValueError, ImportError, AttributeError) as e: - logger.error(f"Failed to import engine '{engine_path}': {e}") - raise HTTPException( - status_code=400, - detail=f"Failed to import engine '{engine_path}': {str(e)}", - ) - except TypeError as e: - logger.error(f"Invalid engine type: {e}") - raise HTTPException( - status_code=400, - detail=str(e), - ) - - # Instantiate engine - try: - _engine = engine_class(*init_args, **init_kwargs) - logger.info(f"Engine '{engine_path}' instantiated successfully") - return { - "status": "success", - "message": f"Engine '{engine_path}' created and initialized", - "result": None, - } - except Exception as e: - logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") - raise HTTPException( - status_code=500, - detail=f"Failed to instantiate engine: {str(e)}", - ) - - except HTTPException: - raise - except Exception as e: - logger.error( - f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" - ) - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post("/call") -async def call_engine_method(data: dict[str, Any] = Body(...)): - """ - Call a method on the engine instance. - - Expected JSON payload: - { - "method": "train_batch", - "args": [...], - "kwargs": {...} - } - """ - global _engine - - if _engine is None: - raise HTTPException( - status_code=503, - detail="Engine not initialized. Call /create_engine first.", - ) - - try: - method_name = data.get("method") - args = data.get("args", []) - kwargs = data.get("kwargs", {}) - - if not method_name: - raise HTTPException( - status_code=400, detail="Missing 'method' field in request" - ) - - # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) - args = deserialize_value(args) - kwargs = deserialize_value(kwargs) - - try: - should_bcast = kwargs.pop("_should_bcast", True) - if isinstance(_engine, TrainEngine) and should_bcast: - logger.info(f"Broadcasting data for TrainEngine method: {method_name}") - from areal.utils.data import broadcast_tensor_container - - args = broadcast_tensor_container( - args, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, - ) - kwargs = broadcast_tensor_container( - kwargs, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, - ) - logger.info("Broadcasting data done.") - except Exception as e: - logger.error( - f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" - ) - raise HTTPException( - status_code=500, - detail=f"Data bcast '{method_name}' failed: {str(e)}", - ) - - # Call method directly (no need for hasattr/getattr with typed engine) - logger.info(f"Calling engine method: {method_name}") - try: - # Get the method - will raise AttributeError if it doesn't exist - method = getattr(_engine, method_name) - result = method(*args, **kwargs) - - # Serialize result (convert tensors to SerializedTensor dicts) - serialized_result = serialize_value(result) - return {"status": "success", "result": serialized_result} - - except AttributeError as e: - logger.error(f"Method '{method_name}' not found on engine: {e}") - raise HTTPException( - status_code=400, - detail=f"Engine does not have method '{method_name}'", - ) - except Exception as e: - logger.error( - f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" - ) - raise HTTPException( - status_code=500, - detail=f"Engine method '{method_name}' failed: {str(e)}", - ) - - except HTTPException: - raise - except Exception as e: - logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post("/run_workflow") -async def run_workflow(request: Request): - """ - Run a workflow's arun_episode method directly without using the engine. - - Expected JSON payload: - { - "workflow": "areal.api.workflow_api.RolloutWorkflow", # Import path - "workflow_kwargs": {...}, # Keyword arguments for workflow instantiation - "data": {...} # Data to pass to arun_episode - } - """ - try: - body = await request.body() - data = orjson.loads(body) - - workflow_path = data.get("workflow") - workflow_kwargs = data.get("workflow_kwargs") - episode_data = data.get("data") - should_accept_path = data.get("should_accept_path", None) - check_trajectory_format = data.get("check_trajectory_format") - - if not workflow_path: - raise HTTPException( - status_code=400, detail="Missing 'workflow' field in request" - ) - - if episode_data is None: - raise HTTPException( - status_code=400, detail="Missing 'data' field in request" - ) - - # Deserialize episode_data (may contain tensors) - episode_data = deserialize_value(episode_data) - - # Dynamic import workflow - try: - module_path, class_name = workflow_path.rsplit(".", 1) - module = importlib.import_module(module_path) - workflow_class = getattr(module, class_name) - logger.info(f"Imported workflow class: {workflow_path}") - except (ValueError, ImportError, AttributeError) as e: - logger.error(f"Failed to import workflow '{workflow_path}': {e}") - raise HTTPException( - status_code=400, - detail=f"Failed to import workflow '{workflow_path}': {str(e)}", - ) - # Instantiate workflow - try: - workflow_kwargs = deserialize_value(workflow_kwargs) - workflow = workflow_class(**workflow_kwargs) - logger.info(f"Workflow '{workflow_path}' instantiated successfully") - except Exception as e: - logger.error( - f"Failed to instantiate workflow: {e}\n{traceback.format_exc()}" - ) - raise HTTPException( - status_code=500, - detail=f"Failed to instantiate workflow: {str(e)}", - ) - - should_accept = None - if should_accept_path is not None: - # Dynamic import filtering function - try: - module_path, fn_name = should_accept_path.rsplit(".", 1) - module = importlib.import_module(module_path) - should_accept = getattr(module, fn_name) - logger.info(f"Imported filtering function: {should_accept_path}") - except (ValueError, ImportError, AttributeError) as e: - logger.error( - f"Failed to import filtering function '{should_accept_path}': {e}" - ) - raise HTTPException( - status_code=400, - detail=f"Failed to import filtering function '{should_accept_path}': {str(e)}", - ) - - # Run episode - try: - global _engine - traj = await workflow.arun_episode(_engine, episode_data) - - global app - if check_trajectory_format and traj is not None: - from areal.core.workflow_executor import ( - check_trajectory_format as check_fn, - ) - - check_fn( - traj, - expected_keys=app._expected_trajectory_keys, - logger=logger, - ) - # Track expected keys for consistency checking - if isinstance(traj, dict) and "input_ids" in traj: - if app._expected_trajectory_keys is None: - app._expected_trajectory_keys = set(traj.keys()) - logger.info( - f"Trajectory format check: tracking keys " - f"{app._expected_trajectory_keys}" - ) - - from areal.experimental.openai.types import InteractionWithTokenLogpReward - from areal.utils.data import concat_padded_tensors - - # Convert InteractionWithTokenLogpReward to tensor dict if needed - if isinstance(traj, dict) and all( - isinstance(v, InteractionWithTokenLogpReward) for v in traj.values() - ): - traj = concat_padded_tensors( - [v.to_tensor_dict() for v in traj.values()] - ) - - assert traj is None or isinstance(traj, dict), traj - - # Apply should_accept filtering - accept_this = traj is not None and ( - should_accept is None or should_accept(traj) - ) - - # Serialize trajectory result (convert tensors to SerializedTensor dicts) - if accept_this: - serialized_traj = serialize_value(traj) - return {"status": "success", "result": serialized_traj} - else: - return {"status": "success", "result": None} - except Exception as e: - logger.error(f"Workflow arun_episode failed: {e}\n{traceback.format_exc()}") - raise HTTPException( - status_code=500, - detail=f"Workflow arun_episode failed: {str(e)}", - ) - - except HTTPException: - raise - except Exception as e: - logger.error(f"Unexpected error in run_workflow: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post("/configure") -async def configure(request: Request): - try: - body = await request.body() - data = orjson.loads(body) - - config = data.get("config") - if config is None: - raise HTTPException( - status_code=400, detail="Missing 'config' field in request" - ) - role = data.get("role") - if role is None: - raise HTTPException( - status_code=400, detail="Missing 'role' field in request" - ) - rank = data.get("rank") - if rank is None: - raise HTTPException( - status_code=400, detail="Missing 'rank' field in request" - ) - - config = deserialize_value(config) - config: BaseExperimentConfig - - name_resolve.reconfigure(config.cluster.name_resolve) - seeding.set_random_seed(config.seed, key=f"{role}{rank}") - - return {"status": "success", "result": None} - except HTTPException: - raise - except Exception as e: - logger.error(f"Unexpected error in configure: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post("/export_stats") -async def export_stats(data: dict[str, Any] | None = Body(None)): - try: - assert data is None - - global _engine - if isinstance(_engine, TrainEngine): - return { - "status": "success", - "result": stats_tracker.export( - reduce_group=_engine.data_parallel_group - ), - } - else: - assert isinstance(_engine, InferenceEngine) - # Rollout engines do not have the collective communication channel. - # Return individual results and reduce them in the client side. - raw_stats = {} - for name, tracker in stats_tracker.TRACKERS.items(): - s = {name.strip("/") + k: v for k, v in tracker.stats.items()} - raw_stats.update(s) - # clear stats tracker - stats_tracker.export_all() - return {"status": "success", "result": raw_stats} - - except HTTPException: - raise - except Exception as e: - logger.error(f"Unexpected error in export_stats: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -def main(): - """Main entry point for the async RPC server.""" - parser = argparse.ArgumentParser( - description="AReaL Async RPC Server for InferenceEngine" - ) - parser.add_argument("--port", type=int, required=True, help="Port to serve on") - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" - ) - - args, _ = parser.parse_known_args() - - logger.info(f"Starting async RPC server on {args.host}:{args.port}") - - try: - # Run uvicorn server with a single worker (required for GPU workloads) - uvicorn.run( - app, - host=args.host, - port=args.port, - workers=1, # Single worker required for GPU memory management - log_level="info", - access_log=True, - ) - finally: - global _engine - if _engine is not None: - assert isinstance(_engine, InferenceEngine) - _engine.destroy_engine() - _engine.destroy() - - -if __name__ == "__main__": - main() diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index 08e081601..ad54b4f99 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -1,69 +1,78 @@ -"""Modern FastAPI-based RPC server for engine workers. - -This server runs on worker nodes to expose engine methods via HTTP/JSON RPC. -It uses safe JSON serialization instead of cloudpickle. -""" - import argparse import importlib import traceback -from contextlib import asynccontextmanager -from typing import Any +from concurrent.futures import Future -import orjson -import uvicorn -from fastapi import Body, FastAPI, HTTPException, Request -from fastapi.responses import ORJSONResponse +from flask import Flask, jsonify, request -from areal.api.engine_api import InferenceEngine +from areal.api.cli_args import BaseExperimentConfig +from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.platforms import current_platform from areal.scheduler.rpc.serialization import deserialize_value, serialize_value -from areal.utils import logging, stats_tracker +from areal.utils import logging, name_resolve, seeding, stats_tracker +from areal.utils.data import ( + broadcast_tensor_container, + tensor_container_to, +) + +logger = logging.getLogger("SyncRPCServer") -logger = logging.getLogger("RPCServer") +# Global engine instance - must be TrainEngine or InferenceEngine +_engine: TrainEngine | InferenceEngine | None = None -# Global engine instance - must be InferenceEngine -_engine: InferenceEngine | None = None +# Create Flask app +app = Flask(__name__) -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events.""" - # Startup - logger.info("RPC server starting up...") - yield - # Shutdown +@app.route("/health", methods=["GET"]) +def health_check(): + """Health check endpoint to verify server is alive.""" global _engine - logger.info("Shutting down RPC server...") - if _engine is not None: - try: - # Call destroy method if available - if hasattr(_engine, "destroy"): - _engine.destroy() - logger.info("Engine destroyed successfully") - except Exception as e: - logger.error(f"Error destroying engine: {e}") - _engine = None + return jsonify({"status": "healthy", "engine_initialized": _engine is not None}) -app = FastAPI( - title="AReaL Worker RPC Server", - description="FastAPI-based RPC server for remote engine operations", - default_response_class=ORJSONResponse, - lifespan=lifespan, -) -app._expected_trajectory_keys = None +@app.route("/configure", methods=["POST"]) +def configure(): + """Configure worker with experiment config.""" + try: + data = request.get_json() + if data is None: + return jsonify({"detail": "Invalid JSON in request body"}), 400 + config = data.get("config") + if config is None: + return jsonify({"detail": "Missing 'config' field in request"}), 400 -@app.get("/health") -async def health_check(): - """Health check endpoint to verify server is alive.""" - return {"status": "healthy", "engine_initialized": _engine is not None} + role = data.get("role") + if role is None: + return jsonify({"detail": "Missing 'role' field in request"}), 400 + + rank = data.get("rank") + if rank is None: + return jsonify({"detail": "Missing 'rank' field in request"}), 400 + + config = deserialize_value(config) + config: BaseExperimentConfig + name_resolve.reconfigure(config.cluster.name_resolve) + seeding.set_random_seed(config.seed, key=f"{role}{rank}") -@app.post("/create_engine") -def create_engine(data: dict[str, Any] = Body(...)): + return jsonify( + { + "status": "success", + "message": "Worker configured successful.", + "result": None, + } + ) + except Exception as e: + logger.error(f"Unexpected error in configure: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +@app.route("/create_engine", methods=["POST"]) +def create_engine(): """ - Create and initialize an engine instance on this worker. + Create and initialize a TrainEngine or InferenceEngine instance on this worker. Expected JSON payload: { @@ -75,15 +84,17 @@ def create_engine(data: dict[str, Any] = Body(...)): global _engine try: + data = request.get_json() + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + engine_path = data.get("engine") # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) init_args = deserialize_value(data.get("init_args", [])) init_kwargs = deserialize_value(data.get("init_kwargs", {})) if not engine_path: - raise HTTPException( - status_code=400, detail="Missing 'engine' field in request" - ) + return jsonify({"error": "Missing 'engine' field in request"}), 400 # Dynamic import try: @@ -91,51 +102,50 @@ def create_engine(data: dict[str, Any] = Body(...)): module = importlib.import_module(module_path) engine_class = getattr(module, class_name) - if not (issubclass(engine_class, InferenceEngine)): + # Validate that the class is a TrainEngine or InferenceEngine + if not issubclass(engine_class, TrainEngine) and not issubclass( + engine_class, InferenceEngine + ): raise TypeError( - f"Engine class must be a subclass of InferenceEngine, " - f"got {engine_class}" + f"Engine class must be a subclass of TrainEngine or InferenceEngine, " + f"got {engine_class}.." ) except (ValueError, ImportError, AttributeError) as e: logger.error(f"Failed to import engine '{engine_path}': {e}") - raise HTTPException( - status_code=400, - detail=f"Failed to import engine '{engine_path}': {str(e)}", + return ( + jsonify( + {"error": f"Failed to import engine '{engine_path}': {str(e)}"} + ), + 400, ) except TypeError as e: logger.error(f"Invalid engine type: {e}") - raise HTTPException( - status_code=400, - detail=str(e), - ) + return jsonify({"error": str(e)}), 400 # Instantiate engine try: _engine = engine_class(*init_args, **init_kwargs) logger.info(f"Engine '{engine_path}' instantiated successfully") - return { - "status": "success", - "message": f"Engine '{engine_path}' created and initialized", - "result": None, - } + return jsonify( + { + "status": "success", + "message": f"Engine '{engine_path}' created and initialized", + "result": None, + } + ) except Exception as e: logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") - raise HTTPException( - status_code=500, - detail=f"Failed to instantiate engine: {str(e)}", - ) + return jsonify({"error": f"Failed to instantiate engine: {str(e)}"}), 500 - except HTTPException: - raise except Exception as e: logger.error( f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" ) - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 -@app.post("/call") -def call_engine_method(data: dict[str, Any] = Body(...)): +@app.route("/call", methods=["POST"]) +def call_engine_method(): """ Call a method on the engine instance. @@ -149,244 +159,207 @@ def call_engine_method(data: dict[str, Any] = Body(...)): global _engine if _engine is None: - raise HTTPException( - status_code=503, - detail="Engine not initialized. Call /create_engine first.", + return ( + jsonify({"error": "Engine not initialized. Call /create_engine first."}), + 503, ) try: + data = request.get_json() + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + method_name = data.get("method") args = data.get("args", []) kwargs = data.get("kwargs", {}) if not method_name: - raise HTTPException( - status_code=400, detail="Missing 'method' field in request" - ) + return jsonify({"error": "Missing 'method' field in request"}), 400 # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) args = deserialize_value(args) kwargs = deserialize_value(kwargs) - # Call method directly (no need for hasattr/getattr with typed engine) + try: + should_bcast = kwargs.pop("_should_bcast", True) + if should_bcast and isinstance(_engine, TrainEngine): + logger.info(f"Broadcasting data for TrainEngine method: {method_name}") + + args = tensor_container_to(args, current_platform.current_device()) + args = broadcast_tensor_container( + args, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + kwargs = tensor_container_to(kwargs, current_platform.current_device()) + kwargs = broadcast_tensor_container( + kwargs, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + logger.info("Broadcasting data done.") + except Exception as e: + logger.error( + f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + return ( + jsonify({"error": f"Data broadcast '{method_name}' failed: {str(e)}"}), + 500, + ) + + # Special case for `submit` on inference engines + try: + if method_name == "submit" and isinstance(_engine, InferenceEngine): + workflow_path = kwargs["workflow_path"] + workflow_kwargs = kwargs["workflow_kwargs"] + episode_data = kwargs["data"] + should_accept_path = kwargs["should_accept_path"] + + # Deserialize episode_data (may contain tensors) + episode_data = deserialize_value(episode_data) + + # Dynamic import workflow + module_path, class_name = workflow_path.rsplit(".", 1) + module = importlib.import_module(module_path) + workflow_class = getattr(module, class_name) + logger.info(f"Imported workflow class: {workflow_path}") + + # Instantiate workflow + workflow_kwargs = deserialize_value(workflow_kwargs) + workflow = workflow_class(**workflow_kwargs) + logger.info(f"Workflow '{workflow_path}' instantiated successfully") + + should_accept = None + if should_accept_path is not None: + # Dynamic import filtering function + module_path, fn_name = should_accept_path.rsplit(".", 1) + module = importlib.import_module(module_path) + should_accept = getattr(module, fn_name) + logger.info(f"Imported filtering function: {should_accept_path}") + + args = [] + kwargs = dict( + data=episode_data, + workflow=workflow, + should_accept=should_accept, + ) + except Exception as e: + logger.error( + f"Workflow data conversion failed: {e}\n{traceback.format_exc()}" + ) + return ( + jsonify({"error": f"workflow data conversion failed: {str(e)}"}), + 500, + ) + + # Call method directly logger.info(f"Calling engine method: {method_name}") try: # Get the method - will raise AttributeError if it doesn't exist method = getattr(_engine, method_name) result = method(*args, **kwargs) + # HACK: handle update weights future + if isinstance(result, Future): + result = result.result() + # Serialize result (convert tensors to SerializedTensor dicts) serialized_result = serialize_value(result) - return {"status": "success", "result": serialized_result} + return jsonify({"status": "success", "result": serialized_result}) except AttributeError as e: logger.error(f"Method '{method_name}' not found on engine: {e}") - raise HTTPException( - status_code=400, - detail=f"Engine does not have method '{method_name}'", + return ( + jsonify({"error": f"Engine does not have method '{method_name}'"}), + 400, ) except Exception as e: logger.error( f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" ) - raise HTTPException( - status_code=500, - detail=f"Engine method '{method_name}' failed: {str(e)}", + return ( + jsonify({"error": f"Engine method '{method_name}' failed: {str(e)}"}), + 500, ) - except HTTPException: - raise except Exception as e: logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 -@app.post("/run_workflow") -async def run_workflow(request: Request): - """ - Run a workflow's arun_episode method directly without using the engine. - Expected JSON payload: - { - "workflow": "areal.api.workflow_api.RolloutWorkflow", # Import path - "workflow_kwargs": {...}, # Keyword arguments for workflow instantiation - "data": {...} # Data to pass to arun_episode - } - """ +@app.route("/export_stats", methods=["POST"]) +def export_stats(): + """Export training statistics from stats_tracker.""" try: - body = await request.body() - data = orjson.loads(body) - - workflow_path = data.get("workflow") - workflow_kwargs = data.get("workflow_kwargs") - episode_data = data.get("data") - should_accept_path = data.get("should_accept_path", None) - check_trajectory_format = data.get("check_trajectory_format") - - if not workflow_path: - raise HTTPException( - status_code=400, detail="Missing 'workflow' field in request" - ) - - if episode_data is None: - raise HTTPException( - status_code=400, detail="Missing 'data' field in request" - ) + global _engine + if _engine is None: + return jsonify({"error": "Engine not initialized"}), 503 - # Deserialize episode_data (may contain tensors) - episode_data = deserialize_value(episode_data) + # TrainEngine: reduce stats across data_parallel_group + assert isinstance(_engine, TrainEngine) + result = stats_tracker.export(reduce_group=_engine.data_parallel_group) + return jsonify({"status": "success", "result": result}) - # Dynamic import workflow - try: - module_path, class_name = workflow_path.rsplit(".", 1) - module = importlib.import_module(module_path) - workflow_class = getattr(module, class_name) - logger.info(f"Imported workflow class: {workflow_path}") - except (ValueError, ImportError, AttributeError) as e: - logger.error(f"Failed to import workflow '{workflow_path}': {e}") - raise HTTPException( - status_code=400, - detail=f"Failed to import workflow '{workflow_path}': {str(e)}", - ) - # Instantiate workflow - try: - workflow = workflow_class(**workflow_kwargs) - logger.info(f"Workflow '{workflow_path}' instantiated successfully") - except Exception as e: - logger.error( - f"Failed to instantiate workflow: {e}\n{traceback.format_exc()}" - ) - raise HTTPException( - status_code=500, - detail=f"Failed to instantiate workflow: {str(e)}", - ) + except Exception as e: + logger.error(f"Unexpected error in export_stats: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - should_accept = None - if should_accept_path is not None: - # Dynamic import filtering function - try: - module_path, fn_name = should_accept_path.rsplit(".", 1) - module = importlib.import_module(module_path) - should_accept = getattr(module, fn_name) - logger.info(f"Imported filtering function: {should_accept_path}") - except (ValueError, ImportError, AttributeError) as e: - logger.error( - f"Failed to import filtering function '{should_accept_path}': {e}" - ) - raise HTTPException( - status_code=400, - detail=f"Failed to import filtering function '{should_accept_path}': {str(e)}", - ) - # Run episode +def cleanup_engine(): + """Clean up engine on shutdown.""" + global _engine + if _engine is not None: try: - global _engine - traj = await workflow.arun_episode(_engine, episode_data) - - global app - if check_trajectory_format and traj is not None: - from areal.core.workflow_executor import ( - check_trajectory_format as check_fn, - ) - - check_fn( - traj, - expected_keys=app._expected_trajectory_keys, - logger=logger, - ) - # Track expected keys for consistency checking - if isinstance(traj, dict) and "input_ids" in traj: - if app._expected_trajectory_keys is None: - app._expected_trajectory_keys = set(traj.keys()) - logger.info( - f"Trajectory format check: tracking keys " - f"{app._expected_trajectory_keys}" - ) - - from areal.experimental.openai.types import InteractionWithTokenLogpReward - from areal.utils.data import concat_padded_tensors - - # Convert InteractionWithTokenLogpReward to tensor dict if needed - if isinstance(traj, dict) and all( - isinstance(v, InteractionWithTokenLogpReward) for v in traj.values() - ): - traj = concat_padded_tensors( - [v.to_tensor_dict() for v in traj.values()] - ) - - assert traj is None or isinstance(traj, dict), traj - - # Apply should_accept filtering - accept_this = traj is not None and ( - should_accept is None or should_accept(traj) - ) - - # Serialize trajectory result (convert tensors to SerializedTensor dicts) - if accept_this: - serialized_traj = serialize_value(traj) - return {"status": "success", "result": serialized_traj} - else: - return {"status": "success", "result": None} + _engine.destroy() + logger.info("Engine destroyed successfully") except Exception as e: - logger.error(f"Workflow arun_episode failed: {e}\n{traceback.format_exc()}") - raise HTTPException( - status_code=500, - detail=f"Workflow arun_episode failed: {str(e)}", - ) - - except HTTPException: - raise - except Exception as e: - logger.error(f"Unexpected error in run_workflow: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post("/export_stats") -def export_stats(data: dict[str, Any] | None = Body(None)): - try: - assert data is None - - global _engine - assert isinstance(_engine, InferenceEngine) - # Rollout engines do not have the collective communication channel. - # Return individual results and reduce them in the client side. - raw_stats = {} - for name, tracker in stats_tracker.TRACKERS.items(): - s = {name.strip("/") + k: v for k, v in tracker.stats.items()} - raw_stats.update(s) - # clear stats tracker - stats_tracker.export_all() - return {"status": "success", "result": raw_stats} - - except HTTPException: - raise - except Exception as e: - logger.error(f"Unexpected error in run_workflow: {e}\n{traceback.format_exc()}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + logger.error(f"Error destroying engine: {e}") + _engine = None def main(): - """Main entry point for the RPC server.""" - parser = argparse.ArgumentParser(description="AReaL Worker RPC Server") + """Main entry point for the sync RPC server.""" + parser = argparse.ArgumentParser( + description="AReaL Sync RPC Server for TrainEngine/InferenceEngine" + ) parser.add_argument("--port", type=int, required=True, help="Port to serve on") parser.add_argument( "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" ) + parser.add_argument( + "--werkzeug-log-level", + type=str, + default="WARNING", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Log level for Werkzeug (Flask's WSGI server). Default: WARNING", + ) args, _ = parser.parse_known_args() - port = args.port - - logger.info(f"Starting RPC server on {args.host}:{port}") - - # Run uvicorn server with a single worker (required for GPU workloads) - uvicorn.run( - app, - host=args.host, - port=port, - workers=1, # Single worker required for GPU memory management - log_level="info", - access_log=True, - ) + + # Configure Werkzeug logging + import logging as stdlib_logging + + werkzeug_logger = stdlib_logging.getLogger("werkzeug") + werkzeug_logger.setLevel(getattr(stdlib_logging, args.werkzeug_log_level)) + + logger.info(f"Starting sync RPC server on {args.host}:{args.port}") + logger.info(f"Werkzeug log level: {args.werkzeug_log_level}") + + # Run Flask app with single-threaded synchronous mode + # threaded=False ensures NCCL compatibility + try: + app.run( + host=args.host, + port=args.port, + threaded=False, # Single-threaded synchronous execution + processes=1, # Single process + debug=False, + use_reloader=False, + ) + except KeyboardInterrupt: + logger.info("Shutting down sync RPC server") + finally: + cleanup_engine() if __name__ == "__main__": diff --git a/areal/scheduler/rpc/sync_rpc_server.py b/areal/scheduler/rpc/sync_rpc_server.py deleted file mode 100644 index a581ee0c9..000000000 --- a/areal/scheduler/rpc/sync_rpc_server.py +++ /dev/null @@ -1,403 +0,0 @@ -import argparse -import importlib -import json -import traceback -from concurrent.futures import Future -from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any - -from areal.api.cli_args import BaseExperimentConfig -from areal.api.engine_api import InferenceEngine, TrainEngine -from areal.platforms import current_platform -from areal.scheduler.rpc.serialization import deserialize_value, serialize_value -from areal.utils import logging, name_resolve, seeding, stats_tracker - -logger = logging.getLogger("SyncRPCServer") - -# Global engine instance - must be TrainEngine -_engine: TrainEngine | InferenceEngine | None = None - - -class SyncRPCHandler(BaseHTTPRequestHandler): - """HTTP request handler for sync RPC server endpoints.""" - - def log_message(self, format: str, *args: Any) -> None: - """Override to use our logger instead of stderr.""" - logger.debug(f"{self.address_string()} - {format % args}") - - def _send_json_response(self, data: dict, status_code: int = 200) -> None: - """Send JSON response with appropriate headers.""" - self.send_response(status_code) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(json.dumps(data).encode("utf-8")) - - def _read_json_body(self) -> dict | None: - """Read and parse JSON request body.""" - try: - content_length = int(self.headers.get("Content-Length", 0)) - if content_length == 0: - return {} - body = self.rfile.read(content_length) - return json.loads(body.decode("utf-8")) - except (json.JSONDecodeError, ValueError) as e: - logger.error(f"Failed to parse JSON body: {e}") - self._send_json_response( - {"error": f"Invalid JSON in request body: {str(e)}"}, 400 - ) - return None - - def do_GET(self) -> None: - """Handle GET requests.""" - if self.path == "/health": - self._handle_health_check() - else: - self._send_json_response({"error": f"Not found: {self.path}"}, 404) - - def do_POST(self) -> None: - """Handle POST requests.""" - if self.path == "/create_engine": - self._handle_create_engine() - elif self.path == "/call": - self._handle_call_engine_method() - elif self.path == "/export_stats": - self._handle_export_stats() - elif self.path == "/configure": - self._handle_configure() - else: - self._send_json_response({"error": f"Not found: {self.path}"}, 404) - - def _handle_health_check(self) -> None: - """Health check endpoint to verify server is alive.""" - global _engine - self._send_json_response( - {"status": "healthy", "engine_initialized": _engine is not None} - ) - - def _handle_configure(self) -> None: - try: - data = self._read_json_body() - if data is None: - return - - config = data.get("config") - if config is None: - self._send_json_response( - {"detail": "Missing 'config' field in request"}, 400 - ) - role = data.get("role") - if role is None: - self._send_json_response( - {"detail": "Missing 'role' field in request"}, 400 - ) - rank = data.get("rank") - if rank is None: - self._send_json_response( - {"detail": "Missing 'rank' field in request"}, 400 - ) - - config = deserialize_value(config) - config: BaseExperimentConfig - - name_resolve.reconfigure(config.cluster.name_resolve) - - seeding.set_random_seed(config.seed, key=f"{role}{rank}") - self._send_json_response( - { - "status": "success", - "message": "Worker configured successful.", - "result": None, - } - ) - except Exception as e: - logger.error( - f"Unexpected error in configure: {e}\n{traceback.format_exc()}" - ) - self._send_json_response({"error": f"Internal server error: {str(e)}"}, 500) - - def _handle_create_engine(self) -> None: - """ - Create and initialize a TrainEngine instance on this worker. - - Expected JSON payload: - { - "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path - "init_args": [...], # Positional arguments - "init_kwargs": {...} # Keyword arguments - } - """ - global _engine - - try: - data = self._read_json_body() - if data is None: - return - - engine_path = data.get("engine") - # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) - init_args = deserialize_value(data.get("init_args", [])) - init_kwargs = deserialize_value(data.get("init_kwargs", {})) - - if not engine_path: - self._send_json_response( - {"error": "Missing 'engine' field in request"}, 400 - ) - return - - # Dynamic import - try: - module_path, class_name = engine_path.rsplit(".", 1) - module = importlib.import_module(module_path) - engine_class = getattr(module, class_name) - - # Validate that the class is a TrainEngine - if not issubclass(engine_class, TrainEngine) and not issubclass( - engine_class, InferenceEngine - ): - raise TypeError( - f"Engine class must be a subclass of TrainEngine or InferenceEngine, " - f"got {engine_class}.." - ) - except (ValueError, ImportError, AttributeError) as e: - logger.error(f"Failed to import engine '{engine_path}': {e}") - self._send_json_response( - {"error": f"Failed to import engine '{engine_path}': {str(e)}"}, - 400, - ) - return - except TypeError as e: - logger.error(f"Invalid engine type: {e}") - self._send_json_response({"error": str(e)}, 400) - return - - # Instantiate engine - try: - _engine = engine_class(*init_args, **init_kwargs) - logger.info(f"Engine '{engine_path}' instantiated successfully") - self._send_json_response( - { - "status": "success", - "message": f"Engine '{engine_path}' created and initialized", - "result": None, - } - ) - except Exception as e: - logger.error( - f"Failed to instantiate engine: {e}\n{traceback.format_exc()}" - ) - self._send_json_response( - {"error": f"Failed to instantiate engine: {str(e)}"}, 500 - ) - - except Exception as e: - logger.error( - f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" - ) - self._send_json_response({"error": f"Internal server error: {str(e)}"}, 500) - - def _handle_call_engine_method(self) -> None: - """ - Call a method on the engine instance. - - Expected JSON payload: - { - "method": "train_batch", - "args": [...], - "kwargs": {...} - } - """ - global _engine - - if _engine is None: - self._send_json_response( - {"error": "Engine not initialized. Call /create_engine first."}, 503 - ) - return - - try: - data = self._read_json_body() - if data is None: - return - - method_name = data.get("method") - args = data.get("args", []) - kwargs = data.get("kwargs", {}) - - if not method_name: - self._send_json_response( - {"error": "Missing 'method' field in request"}, 400 - ) - return - - # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) - args = deserialize_value(args) - kwargs = deserialize_value(kwargs) - - try: - should_bcast = kwargs.pop("_should_bcast", True) - if should_bcast and isinstance(_engine, TrainEngine): - logger.info( - f"Broadcasting data for TrainEngine method: {method_name}" - ) - from areal.utils.data import ( - broadcast_tensor_container, - tensor_container_to, - ) - - args = tensor_container_to(args, current_platform.current_device()) - args = broadcast_tensor_container( - args, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, - ) - kwargs = tensor_container_to( - kwargs, current_platform.current_device() - ) - kwargs = broadcast_tensor_container( - kwargs, - src_rank=_engine.current_data_parallel_head(), - group=_engine.context_and_model_parallel_group, - ) - logger.info("Broadcasting data done.") - except Exception as e: - logger.error( - f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" - ) - self._send_json_response( - {"error": f"Data broadcast '{method_name}' failed: {str(e)}"}, 500 - ) - return - - # Special case for `submit` on infernece engines - try: - if method_name == "submit" and isinstance(_engine, InferenceEngine): - workflow_path = kwargs["workflow_path"] - workflow_kwargs = kwargs["workflow_kwargs"] - episode_data = kwargs["data"] - should_accept_path = kwargs["should_accept_path"] - - # Deserialize episode_data (may contain tensors) - episode_data = deserialize_value(episode_data) - - # Dynamic import workflow - module_path, class_name = workflow_path.rsplit(".", 1) - module = importlib.import_module(module_path) - workflow_class = getattr(module, class_name) - logger.info(f"Imported workflow class: {workflow_path}") - - # Instantiate workflow - workflow_kwargs = deserialize_value(workflow_kwargs) - workflow = workflow_class(**workflow_kwargs) - logger.info(f"Workflow '{workflow_path}' instantiated successfully") - - should_accept = None - if should_accept_path is not None: - # Dynamic import filtering function - module_path, fn_name = should_accept_path.rsplit(".", 1) - module = importlib.import_module(module_path) - should_accept = getattr(module, fn_name) - logger.info( - f"Imported filtering function: {should_accept_path}" - ) - - args = [] - kwargs = dict( - data=episode_data, - workflow=workflow, - should_accept=should_accept, - ) - except Exception as e: - logger.error( - f"Worklow data conversion failed: {e}\n{traceback.format_exc()}" - ) - self._send_json_response( - {"error": f"workflow data conversion failed: {str(e)}"}, 500 - ) - return - - # Call method directly - logger.info(f"Calling engine method: {method_name}") - try: - # Get the method - will raise AttributeError if it doesn't exist - method = getattr(_engine, method_name) - result = method(*args, **kwargs) - - # HACK: handle update weights future - if isinstance(result, Future): - result = result.result() - - # Serialize result (convert tensors to SerializedTensor dicts) - serialized_result = serialize_value(result) - self._send_json_response( - {"status": "success", "result": serialized_result} - ) - - except AttributeError as e: - logger.error(f"Method '{method_name}' not found on engine: {e}") - self._send_json_response( - {"error": f"Engine does not have method '{method_name}'"}, 400 - ) - except Exception as e: - logger.error( - f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" - ) - self._send_json_response( - {"error": f"Engine method '{method_name}' failed: {str(e)}"}, 500 - ) - - except Exception as e: - logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") - self._send_json_response({"error": f"Internal server error: {str(e)}"}, 500) - - def _handle_export_stats(self) -> None: - """Export training statistics from stats_tracker.""" - try: - global _engine - if _engine is None: - self._send_json_response({"error": "Engine not initialized"}, 503) - return - - # TrainEngine: reduce stats across data_parallel_group - assert isinstance(_engine, TrainEngine) - result = stats_tracker.export(reduce_group=_engine.data_parallel_group) - self._send_json_response({"status": "success", "result": result}) - - except Exception as e: - logger.error( - f"Unexpected error in export_stats: {e}\n{traceback.format_exc()}" - ) - self._send_json_response({"error": f"Internal server error: {str(e)}"}, 500) - - -def main(): - """Main entry point for the sync RPC server.""" - parser = argparse.ArgumentParser( - description="AReaL Sync RPC Server for TrainEngine" - ) - parser.add_argument("--port", type=int, required=True, help="Port to serve on") - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" - ) - - args, _ = parser.parse_known_args() - - logger.info(f"Starting sync RPC server on {args.host}:{args.port}") - - # Create and run single-threaded HTTP server - # HTTPServer is single-threaded by default (processes one request at a time) - # This ensures NCCL compatibility - server = HTTPServer((args.host, args.port), SyncRPCHandler) - - try: - server.serve_forever() - except KeyboardInterrupt: - logger.info("Shutting down sync RPC server") - server.shutdown() - finally: - global _engine - if _engine is not None: - _engine.destroy() - - -if __name__ == "__main__": - main() diff --git a/examples/single-controller/gsm8k_sft.yaml b/examples/single-controller/gsm8k_sft.yaml index 0efeeaac3..ee3ea0726 100644 --- a/examples/single-controller/gsm8k_sft.yaml +++ b/examples/single-controller/gsm8k_sft.yaml @@ -38,7 +38,7 @@ model: type: worker port_count: 1 gpu: 1 - cmd: python3 -m areal.scheduler.rpc.sync_rpc_server + cmd: python3 -m areal.scheduler.rpc.rpc_server train_dataset: batch_size: 128 From bb60c357da53394251cc73e9b5c28366bbc71d5b Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Fri, 31 Oct 2025 15:56:03 +0800 Subject: [PATCH 32/52] add grpo example --- examples/single-controller/gsm8k_grpo.py | 216 +++++++++++++++++++++ examples/single-controller/gsm8k_grpo.yaml | 171 ++++++++++++++++ 2 files changed, 387 insertions(+) create mode 100644 examples/single-controller/gsm8k_grpo.py create mode 100644 examples/single-controller/gsm8k_grpo.yaml diff --git a/examples/single-controller/gsm8k_grpo.py b/examples/single-controller/gsm8k_grpo.py new file mode 100644 index 000000000..415028966 --- /dev/null +++ b/examples/single-controller/gsm8k_grpo.py @@ -0,0 +1,216 @@ +import os +import sys + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import GRPOConfig, SGLangConfig, load_expr_config +from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta +from areal.controller.rollout_controller import RolloutController +from areal.controller.train_controller import TrainController +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.scheduler.local import LocalScheduler +from areal.utils import stats_tracker +from areal.utils.data import ( + cycle_dataloader, +) +from areal.utils.dataloader import create_dataloader +from areal.utils.device import log_gpu_stats +from areal.utils.evaluator import Evaluator +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.recover import RecoverHandler +from areal.utils.saver import Saver +from areal.utils.stats_logger import StatsLogger + + +def main(args): + config, _ = load_expr_config(args, GRPOConfig) + config: GRPOConfig + + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + # Create dataset and dataloaders + train_dataset = get_custom_dataset( + split="train", dataset_config=config.train_dataset, tokenizer=tokenizer + ) + + train_dataloader = create_dataloader( + train_dataset, + rank=0, + world_size=1, + dataset_config=config.train_dataset, + ) + + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + + # Initialize scheduler + scheduler = LocalScheduler(exp_config=config) + + # Initialize train controller + allocation_mode = AllocationMode.from_str(config.allocation_mode) + actor = TrainController(FSDPPPOActor, config=config.actor, scheduler=scheduler) + actor.initialize( + role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None + ) + + # Initialize inference engine + rollout = RolloutController( + RemoteSGLangEngine, config=config.rollout, scheduler=scheduler + ) + rollout.initialize( + role="rollout", + alloc_mode=allocation_mode, + engine_args=SGLangConfig.build_args( + sglang_config=config.sglang, + tp_size=allocation_mode.gen.tp_size, + base_gpu_id=0, + ), + ) + + weight_update_meta = WeightUpdateMeta.from_disk( + experiment_name=config.experiment_name, + trial_name=config.trial_name, + file_root=config.cluster.fileroot, + ) + actor.connect_engine(rollout, weight_update_meta) + + ref = None + if config.actor.kl_ctl > 0 and config.ref is not None: + ref = TrainController(FSDPPPOActor, config=config.ref, scheduler=scheduler) + ref.initialize( + role="ref", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None + ) + + # Run training. + saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + + try: + recover_info = recover_handler.load( + actor, + saver, + evaluator, + stats_logger, + train_dataloader, + inference_engine=rollout, + weight_update_meta=weight_update_meta, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + steps_per_epoch = len(train_dataloader) + max_steps = total_epochs * steps_per_epoch + + data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): + epoch = global_step // steps_per_epoch + step = global_step % steps_per_epoch + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=steps_per_epoch, + ) + + with stats_tracker.record_timing("rollout"): + if config.async_training: + batch = actor.prepare_batch( + train_dataloader, + workflow_path="areal.workflow.rlvr.RLVRWorkflow", + workflow_kwargs=dict( + reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), + "generated", + ), + ), + ) + else: + batch = actor.rollout_batch( + next(data_generator), + workflow_path="areal.workflow.rlvr.RLVRWorkflow", + workflow_kwargs=dict( + reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), + "generated", + ), + ), + ) + + if config.actor.recompute_logprob or config.actor.use_decoupled_loss: + with stats_tracker.record_timing("recompute_logp"): + logp = actor.compute_logp(batch) + batch["prox_logp"] = logp + log_gpu_stats("recompute logp") + + if ref is not None: + with stats_tracker.record_timing("ref_logp"): + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") + + with stats_tracker.record_timing("compute_advantage"): + batch = actor.compute_advantages(batch) + log_gpu_stats("compute advantages") + + with stats_tracker.record_timing("train_step"): + actor.ppo_update(batch) + actor.step_lr_scheduler() + log_gpu_stats("ppo update") + + # pause inference for updating weights, save, and evaluation + rollout.pause() + + with stats_tracker.record_timing("update_weights"): + actor.update_weights(weight_update_meta) + + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) + + with stats_tracker.record_timing("save"): + saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + actor, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + # Upload statistics to the logger (e.g., wandb) + stats_logger.commit(epoch, step, global_step, actor.export_stats()) + + # Resume rollout + rollout.resume() + + finally: + stats_logger.close() + rollout.destroy() + if ref is not None: + ref.destroy() + actor.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/single-controller/gsm8k_grpo.yaml b/examples/single-controller/gsm8k_grpo.yaml new file mode 100644 index 000000000..ef0b0a36b --- /dev/null +++ b/examples/single-controller/gsm8k_grpo.yaml @@ -0,0 +1,171 @@ +experiment_name: gsm8k-grpo +trial_name: trial0 + +seed: 1 +total_train_epochs: 10 +tokenizer_path: ${actor.path} +async_training: true + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /storage/openpsi/experiments + name_resolve: + type: nfs + nfs_record_root: /storage/openpsi/name_resolve + +allocation_mode: sglang.d4p1t1+d4p1t1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 16 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: + type: worker + port_count: 1 + gpu: 1 + cmd: python3 -m areal.scheduler.rpc.sync_rpc_server + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: /storage/openpsi/models/Qwen__Qwen2.5-1.5B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + backend: fsdp + group_size: ${gconfig.n_samples} + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behav_imp_weight_cap: 5.0 + dynamic_sampling: false + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + scheduling_spec: + type: worker + port_count: 1 + gpu: 1 + cmd: python3 -m areal.scheduler.rpc.sync_rpc_server + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + backend: fsdp + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: + type: worker + port_count: 1 + gpu: 1 + cmd: python3 -m areal.scheduler.rpc.sync_rpc_server + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +# datasets +train_dataset: + batch_size: 16 + shuffle: true + pin_memory: true + num_workers: 4 + path: /storage/openpsi/data/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 16 + shuffle: true + pin_memory: true + num_workers: 4 + path: /storage/openpsi/data/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +launcher: + inference_server_cpus_per_gpu: 4 + inference_server_mem_per_gpu: 32768 + trainer_cpus_per_gpu: 4 + trainer_mem_per_gpu: 32768 From ae1d6a2afa57bb0daf4b3baa97d820ec1cd1877c Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Fri, 31 Oct 2025 16:49:59 +0800 Subject: [PATCH 33/52] remove local inference engine --- areal/core/local_inf_engine.py | 548 -------------------------- areal/engine/sglang_local.py | 368 ----------------- areal/engine/vllm_local.py | 373 ------------------ areal/tests/test_inference_engines.py | 148 +++---- 4 files changed, 55 insertions(+), 1382 deletions(-) delete mode 100644 areal/core/local_inf_engine.py delete mode 100644 areal/engine/sglang_local.py delete mode 100644 areal/engine/vllm_local.py diff --git a/areal/core/local_inf_engine.py b/areal/core/local_inf_engine.py deleted file mode 100644 index f19d3ed17..000000000 --- a/areal/core/local_inf_engine.py +++ /dev/null @@ -1,548 +0,0 @@ -import asyncio -import time -import uuid -from collections.abc import Callable -from threading import Lock -from typing import Any, Protocol - -import torch.distributed as dist -from torchdata.stateful_dataloader import StatefulDataLoader - -from areal.api.cli_args import InferenceEngineConfig -from areal.api.io_struct import ( - ModelRequest, - ModelResponse, - ParamSpec, - WeightUpdateMeta, -) -from areal.api.workflow_api import RolloutWorkflow -from areal.platforms import current_platform -from areal.utils import logging, name_resolve, names - -from .workflow_executor import WorkflowExecutor - - -class LocalInfBackendProtocol(Protocol): - """Protocol defining backend-specific operations for local inference engines. - - This protocol abstracts the differences between various local inference engines - (SGLang, vLLM, etc.) by defining a common interface for: - - Creating and managing local engine instances - - Performing async generation - - Handling weight updates (both disk and distributed) - - Managing engine lifecycle - - Implementations can raise NotImplementedError for unsupported features. - """ - - def create_engine(self, engine_args: dict[str, Any]) -> Any: - """Create a local inference engine instance. - - Parameters - ---------- - engine_args : Dict[str, Any] - Arguments to pass to the engine constructor - - Returns - ------- - Any - The created engine instance - """ - ... - - async def async_generation(self, engine: Any, req: ModelRequest) -> ModelResponse: - """Perform async generation using the local engine. - - Parameters - ---------- - engine : Any - The engine instance - req : ModelRequest - The generation request containing input and parameters - - Returns - ------- - ModelResponse - The generated response with tokens, logprobs, and metadata - """ - ... - - def update_weight_disk(self, engine: Any, model_path: str) -> None: - """Update weights from disk synchronously. - - Parameters - ---------- - engine : Any - The engine instance - model_path : str - Path to the model weights on disk - """ - ... - - def update_weight_xccl( - self, - engine: Any, - meta: WeightUpdateMeta, - param_specs: list[ParamSpec], - ) -> None: - """Update weights from distributed memory via NCCL/XCCL synchronously. - - Parameters - ---------- - engine : Any - The engine instance - meta : WeightUpdateMeta - Metadata containing communication group info - param_specs : List[ParamSpec] - Specifications for parameters to be updated - """ - ... - - def init_update_weight_group( - self, engine: Any, meta: WeightUpdateMeta, rank_offset: int - ) -> None: - """Initialize weight update communication group synchronously. - - Parameters - ---------- - engine : Any - The engine instance - meta : WeightUpdateMeta - Metadata containing communication backend configuration - rank_offset : int - Rank offset for this engine in the communication group - """ - ... - - def destroy(self, engine: Any) -> None: - """Destroy the engine and release resources. - - Parameters - ---------- - engine : Any - The engine instance to destroy - """ - ... - - def pause_generation(self) -> None: - """Pause generation.""" - ... - - def continue_generation(self) -> None: - """Continue generation.""" - ... - - -class LocalInfEngine: - """ - Base implementation for local in-process inference engines. - - This class provides common functionality for running inference engines - within the same process. Backend-specific behaviors are delegated to - an injected LocalInfBackendProtocol implementation. - - Uses composition pattern - instantiate directly with a backend rather - than inheriting from this class. - - Parameters - ---------- - config : InferenceEngineConfig - Configuration for the inference engine - backend : LocalInfBackendProtocol - Backend implementation providing engine-specific behavior - """ - - def __init__(self, config: InferenceEngineConfig, backend: LocalInfBackendProtocol): - self.config = config - self.backend = backend - - self.engine = None - self.distributed_weight_update_initialized = False - self._version = 0 - - self.lock = Lock() - - self.workflow_executor: WorkflowExecutor - - def configure(self, config): - self.config = config - - def create_engine(self, engine_args: dict[str, Any] | None = None): - # Create the local engine via backend - engine_args = engine_args or {} - self.engine = self.backend.create_engine(engine_args) - - def destroy_engine(self): - if self.engine is not None: - self.backend.destroy(self.engine) - self.engine = None - - def initialize( - self, - engine_id: str | None = None, - train_data_parallel_size: int | None = None, - ): - """Initialize the engine by creating the local inference engine. - - Parameters - ---------- - engine_id : Optional[str] - Unique identifier for this engine instance - engine_args : Optional[Dict[str, Any]] - Arguments to pass to the backend engine constructor - train_data_parallel_size : int | None - Data parallel size of the training engine - """ - if engine_id is None: - if dist.is_initialized(): - engine_id = str(dist.get_rank()) - else: - engine_id = uuid.uuid4().hex - self.engine_id = engine_id - self.logger = logging.getLogger(f"[Local Inference Engine Rank {engine_id}]") - - # Initialize thread pool for non-blocking weight updates - # FIXME: develop a principled update methods with/without thread pool - - # Initialize workflow executor - self.workflow_executor = WorkflowExecutor( - config=self.config, - inference_engine=self, - ) - self.workflow_executor.initialize( - logger=self.logger, train_data_parallel_size=train_data_parallel_size - ) - - def destroy(self): - """Destroy the engine and clean up resources.""" - if getattr(self, "workflow_executor"): - self.workflow_executor.destroy() - self.workflow_executor = None - - def set_version(self, version: int): - """Set the current weight version.""" - with self.lock: - self._version = version - - def get_version(self) -> int: - """Get the current weight version.""" - with self.lock: - return self._version - - async def agenerate(self, req: ModelRequest) -> ModelResponse: - """Asynchronously generate a response for the given request. - - Parameters - ---------- - req : ModelRequest - The model request containing input data and generation parameters - - Returns - ------- - ModelResponse - The generated response from the model - """ - if self.engine is None: - raise RuntimeError( - "Local inference engine is not initialized, cannot generate." - ) - - # Create a shallow copy of the input request - # we are going to modify it in-place - req = req.copy() - - # Validate n_samples - gconfig = req.gconfig - if gconfig.n_samples != 1: - raise ValueError( - "Local inference engines do not support n_samples > 1. " - "Please call generate multiple times with n_samples = 1." - ) - - # Validate max_new_tokens - max_new_tokens = min( - gconfig.max_tokens - len(req.input_ids), gconfig.max_new_tokens - ) - if max_new_tokens <= 0: - raise RuntimeError( - f"max_new_tokens ({max_new_tokens}) is non-positive! " - f"max_tokens={gconfig.max_tokens}, prompt_len={len(req.input_ids)}, " - f"max_new_tokens={gconfig.max_new_tokens}." - ) - - # Update max_new_tokens in request - req.gconfig.max_new_tokens = max_new_tokens - - # Make request - start_time = time.perf_counter() - accumulated_output_tokens = [] - accumulated_output_logprobs = [] - accumulated_versions = [] - - # Loop until generation is complete - stop_reason = None - while ( - stop_reason not in ["stop", "tool_calls", "length"] - and len(accumulated_output_tokens) < gconfig.max_new_tokens - ): - # Handle rollout interruption - while self.workflow_executor.is_paused(): - await asyncio.sleep(0.5) - - # Call backend async_generation - response = await self.backend.async_generation(self.engine, req) - # Extract result - output_tokens = response.output_tokens - output_logprobs = response.output_logprobs - stop_reason = response.stop_reason - - # Update accumulated outputs - accumulated_output_tokens.extend(output_tokens) - accumulated_output_logprobs.extend(output_logprobs) - accumulated_versions.extend([self.get_version()] * len(output_tokens)) - - # Update request for next iteration - req.input_ids += output_tokens - req.gconfig.max_new_tokens -= len(output_tokens) - assert req.gconfig.max_new_tokens >= 0, ( - req.gconfig.max_new_tokens, - len(output_tokens), - len(req.input_ids), - ) - - # Final abort handling - if stop_reason == "abort": - # If stop_reason is "abort", the only reason we exit the loop is - # len(accumulated_output_tokens) >= gconfig.max_new_tokens - # so the actual reason is length - stop_reason = "length" - - latency = time.perf_counter() - start_time - - response = ModelResponse( - input_tokens=req.input_ids[ - : len(req.input_ids) - len(accumulated_output_tokens) - ], - input_images=req.image_data, - output_tokens=accumulated_output_tokens, - output_logprobs=accumulated_output_logprobs, - output_versions=accumulated_versions, - stop_reason=stop_reason, - latency=latency, - ttft=latency, # Simplified for non-streaming - tokenizer=req.tokenizer, - processor=req.processor, - ) - return response - - def init_weights_update_group(self, meta: WeightUpdateMeta) -> None: - assert meta.type == current_platform.communication_backend - assert not self.distributed_weight_update_initialized, ( - "Weight update group already initialized." - ) - - if self.engine is None: - raise RuntimeError( - "Local inference engine is not initialized, " - "cannot init weight update group." - ) - # FIXME: get the real rank_offset from local process rank and tp size - rank_offset = 1 - self.backend.init_update_weight_group(self.engine, meta, rank_offset) - self.logger.info( - f"Initialized {current_platform.communication_backend.upper()} group " - f"for distributed weight update for {meta.nccl_group_name}." - ) - self.distributed_weight_update_initialized = True - - def update_weights_from_distributed( - self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] - ): - assert meta.type == current_platform.communication_backend - - if self.engine is None: - raise RuntimeError( - "Local inference engine is not initialized, cannot update weights." - ) - - self.backend.update_weight_xccl(self.engine, meta, param_specs) - - def update_weights_from_disk(self, meta: WeightUpdateMeta): - assert meta.type == "disk" - - if self.engine is None: - raise RuntimeError( - "Local inference engine is not initialized, cannot update weights." - ) - - # Validate experiment and trial names - if self.config.experiment_name is None or self.config.trial_name is None: - raise RuntimeError( - "Experiment and trial names must be set for disk-based weight updates." - ) - - # Wait for training engine to signal that weights are ready - update_name = names.update_weights_from_disk( - self.config.experiment_name, - self.config.trial_name, - str(self.get_version()), - ) - save_timestamp = float(name_resolve.wait(update_name, timeout=120)) - load_timestamp = time.time() - - self.logger.info( - f"Begin update weights from {meta.path}, " - f"responded in {(load_timestamp - save_timestamp) * 1000:.2f} ms" - ) - - # Update weights from disk via backend - self.backend.update_weight_disk(self.engine, str(meta.path)) - - self.logger.info( - f"Loading weights done in {(time.time() - load_timestamp) * 1000:.2f} ms" - ) - - def submit( - self, - data: dict[str, Any], - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> None: - """Submit a request to the inference engine and return immediately. - - Parameters - ---------- - data : Dict[str, Any] - The input data for rollout - workflow : RolloutWorkflow, optional - The workflow instance to run - workflow_builder : Callable, optional - A builder to create a workflow instance - should_accept : Callable, optional - A function to decide whether to accept a trajectory - """ - return self.workflow_executor.submit( - data, - workflow=workflow, - workflow_builder=workflow_builder, - should_accept=should_accept, - ) - - def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: - """Wait for a specified number of requests to complete. - - Parameters - ---------- - count : int - The number of accepted trajectories to wait for - timeout : float, optional - Timeout in seconds - - Returns - ------- - Dict[str, Any] - A concatenated batch of trajectories - """ - return self.workflow_executor.wait(count, timeout=timeout) - - def wait_quiet( - self, count: int, timeout: float | None = None - ) -> dict[str, Any] | None: - try: - return self.workflow_executor.wait(count, timeout=timeout) - except TimeoutError: - return "NO_RESULT" - - def rollout_batch( - self, - data: list[dict[str, Any]], - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> dict[str, Any]: - """Submit a batch of requests and wait for results. - - Parameters - ---------- - data : List[Dict[str, Any]] - A list of input data dictionaries for rollout - workflow : RolloutWorkflow, optional - The workflow instance to run - workflow_builder : Callable, optional - A builder to create a workflow instance - should_accept : Callable, optional - A function to decide whether to accept a trajectory - - Returns - ------- - Dict[str, Any] - A concatenated batch of trajectory results - """ - return self.workflow_executor.rollout_batch( - data=data, - workflow=workflow, - workflow_builder=workflow_builder, - should_accept=should_accept, - ) - - def prepare_batch( - self, - dataloader: StatefulDataLoader, - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ): - """Asynchronously submit and wait until a full batch is ready. - - Parameters - ---------- - dataloader : StatefulDataLoader - The data loader to pull data from - workflow : RolloutWorkflow, optional - The workflow instance to run - workflow_builder : Callable, optional - A builder to create a workflow instance - should_accept : Callable, optional - A function to decide whether to accept a trajectory - - Returns - ------- - Dict[str, Any] - A full batch of trajectory results - """ - return self.workflow_executor.prepare_batch( - dataloader=dataloader, - workflow=workflow, - workflow_builder=workflow_builder, - should_accept=should_accept, - ) - - def pause(self): - """Pause request submission for async rollout. - - Used during evaluation to prevent data over generation. - """ - return self.workflow_executor.pause() - - def resume(self): - """Resume request submission for async rollout.""" - return self.workflow_executor.resume() - - def pause_generation(self): - """Pause request submission for async rollout.""" - try: - self.backend.pause_generation() - except NotImplementedError: - self.logger.warning("Backend does not support pause operation") - - # The above http request may require some time to be scheduled and executed. - # The following line waits until all requests are indeed dropped. - time.sleep(self.config.pause_grace_period) - - def continue_generation(self): - """Resume request submission for async rollout.""" - try: - self.backend.continue_generation() - except NotImplementedError: - self.logger.warning("Backend does not support resume operation") diff --git a/areal/engine/sglang_local.py b/areal/engine/sglang_local.py deleted file mode 100644 index d7befda08..000000000 --- a/areal/engine/sglang_local.py +++ /dev/null @@ -1,368 +0,0 @@ -import time -from collections.abc import Callable -from concurrent.futures import Future -from typing import Any - -from torchdata.stateful_dataloader import StatefulDataLoader - -from areal.api.cli_args import InferenceEngineConfig -from areal.api.engine_api import InferenceEngine -from areal.api.io_struct import ( - ModelRequest, - ModelResponse, - ParamSpec, - WeightUpdateMeta, -) -from areal.api.workflow_api import RolloutWorkflow -from areal.core.local_inf_engine import LocalInfEngine -from areal.platforms import current_platform - - -class SGLangLocalBackend: - """SGLang-specific backend implementation for local inference. - - This backend wraps SGLang's native Engine API for in-process inference. - """ - - def create_engine(self, engine_args: dict[str, Any]) -> Any: - """Create a local SGLang engine instance. - - Parameters - ---------- - engine_args : Dict[str, Any] - Arguments to pass to sglang.Engine constructor - - Returns - ------- - Any - The created SGLang Engine instance - """ - import sglang as sgl - - engine = sgl.Engine(**engine_args) - return engine - - async def async_generation(self, engine: Any, req: ModelRequest) -> ModelResponse: - """Perform async generation using the local SGLang engine. - - Parameters - ---------- - engine : Any - The SGLang Engine instance - req : ModelRequest - The generation request containing input and parameters - - Returns - ------- - ModelResponse - The generated response with tokens, logprobs, and metadata - """ - # Prepare request payload - gconfig = req.gconfig - stop_token_ids = gconfig.stop_token_ids - - sampling_params = { - "top_p": gconfig.top_p, - "top_k": gconfig.top_k, - "max_new_tokens": gconfig.max_new_tokens, - "temperature": 0.0 if gconfig.greedy else gconfig.temperature, - "stop_token_ids": stop_token_ids, - "frequency_penalty": gconfig.frequency_penalty, - } - - if gconfig.stop: - sampling_params["stop"] = gconfig.stop - - # Make request - start_time = time.perf_counter() - - # Call SGLang's async_generate method - outputs = await engine.async_generate( - input_ids=req.input_ids, - sampling_params=sampling_params, - return_logprob=True, - ) - - # Parse response - meta_info = outputs["meta_info"] - finish_reason = meta_info["finish_reason"] - stop_reason = finish_reason["type"] - stop_message = finish_reason.get("message", "") - - # Handle early abort - if stop_reason == "abort" and stop_message.startswith("Abort before prefill"): - latency = time.perf_counter() - start_time - return ModelResponse( - input_tokens=req.input_ids, - input_images=req.image_data, - output_tokens=[], - output_logprobs=[], - output_versions=[], - stop_reason=stop_reason, - latency=latency, - ttft=latency, - tokenizer=req.tokenizer, - processor=req.processor, - ) - - # Extract output tokens and logprobs - output_tokens = [x[1] for x in meta_info["output_token_logprobs"]] - output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]] - - latency = time.perf_counter() - start_time - - return ModelResponse( - input_tokens=req.input_ids, - input_images=req.image_data, - output_tokens=output_tokens, - output_logprobs=output_logprobs, - output_versions=[], # Will be filled by LocalInfEngine - stop_reason=stop_reason, - latency=latency, - ttft=latency, - tokenizer=req.tokenizer, - processor=req.processor, - ) - - def update_weight_disk(self, engine: Any, model_path: str) -> None: - """Update weights from disk synchronously. - - Parameters - ---------- - engine : Any - The SGLang Engine instance - model_path : str - Path to the model weights on disk - """ - # otherwise will encounter" eventloop is already running" issue - # def _run_in_thread(): - # print(11111111111111) - # # Call SGLang's update_weights_from_disk method - # try: - # cur_loop = asyncio.get_event_loop() - # except RuntimeError: - # cur_loop = None - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - # try: - # # Call SGLang's update_weights_from_distributed method - # from sglang.srt.managers.io_struct import ( - # UpdateWeightFromDiskReqInput, - # ) - engine.update_weights_from_disk(model_path=model_path) - # obj = UpdateWeightFromDiskReqInput( - # model_path=model_path, - # abort_all_requests=False, - # ) - - # loop = asyncio.get_running_loop() - # future = asyncio.run_coroutine_threadsafe( - # engine.tokenizer_manager.update_weights_from_disk(obj, None), - # loop - # ) - # return future.result() # This blocks until complete - # print(2222222222, flush=True) - # finally: - # asyncio.set_event_loop(cur_loop) - # loop.close() - - # from concurrent.futures import ThreadPoolExecutor - - # with ThreadPoolExecutor() as executor: - # future = executor.submit(_run_in_thread) - # _ = future.result() - - def update_weight_xccl( - self, - engine: Any, - meta: WeightUpdateMeta, - param_specs: list[ParamSpec], - ) -> None: - """Update weights from distributed memory via NCCL/XCCL synchronously. - - Parameters - ---------- - engine : Any - The SGLang Engine instance - meta : WeightUpdateMeta - Metadata containing communication group info - param_specs : List[ParamSpec] - Specifications for parameters to be updated - """ - # Call SGLang's update_weights_from_distributed method - engine.update_weights_from_distributed( - names=[pspec.name for pspec in param_specs], - dtypes=[pspec.dtype for pspec in param_specs], - shapes=[pspec.shape for pspec in param_specs], - group_name=meta.nccl_group_name, - ) - - def init_update_weight_group( - self, engine: Any, meta: WeightUpdateMeta, rank_offset: int - ) -> None: - """Initialize weight update communication group synchronously. - - Parameters - ---------- - engine : Any - The SGLang Engine instance - meta : WeightUpdateMeta - Metadata containing communication backend configuration - rank_offset : int - Rank offset for this engine in the communication group - """ - assert meta.alloc_mode is not None - if meta.alloc_mode.gen.pp_size != 1: - raise NotImplementedError( - "NCCL weight update with PP size > 1 is not implemented yet." - ) - - # Call SGLang's init_weights_update_group method - engine.init_weights_update_group( - master_address=meta.nccl_master_address, - master_port=str(meta.nccl_master_port), - rank_offset=rank_offset, - world_size=meta.alloc_mode.gen.world_size + 1, - backend=current_platform.communication_backend, - group_name=meta.nccl_group_name, - ) - - def destroy(self, engine: Any) -> None: - """Destroy the engine and release resources. - - Parameters - ---------- - engine : Any - The SGLang Engine instance to destroy - """ - # SGLang engines typically don't need explicit cleanup - # but we include this for consistency with the protocol - if hasattr(engine, "shutdown"): - engine.shutdown() - - -class LocalSGLangEngine(InferenceEngine): - """SGLang local inference engine. - - This class delegates all functionality to LocalInfEngine with - an SGLangLocalBackend implementation. It maintains the same public API. - - Parameters - ---------- - config : InferenceEngineConfig - Configuration for the inference engine - """ - - def __init__(self, config: InferenceEngineConfig): - self.config = config - # Pure composition - create internal engine with SGLang backend - self._engine = LocalInfEngine(config, SGLangLocalBackend()) - - def configure(self, config): - self.config = config - self._engine.configure(config) - - def create_engine(self, engine_args): - return self._engine.create_engine(engine_args) - - def destroy_engine(self): - self._engine.destroy_engine() - - def initialize( - self, - engine_id: str | None = None, - train_data_parallel_size: int | None = None, - ): - """Initialize the engine by creating the local SGLang engine. - - Parameters - ---------- - engine_id : Optional[str] - Unique identifier for this engine instance - engine_args : Optional[Dict[str, Any]] - Arguments to pass to sglang.Engine constructor - train_data_parallel_size : int | None - Data parallel size of the training engine - """ - return self._engine.initialize(engine_id, train_data_parallel_size) - - def destroy(self): - """Destroy the engine and clean up resources.""" - return self._engine.destroy() - - def set_version(self, version: int): - """Set the current weight version.""" - return self._engine.set_version(version) - - def get_version(self) -> int: - """Get the current weight version.""" - return self._engine.get_version() - - async def agenerate(self, req: ModelRequest) -> ModelResponse: - """Asynchronously generate a response for the given request.""" - return await self._engine.agenerate(req) - - def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: - """Initialize the weight update process group.""" - return self._engine.init_weights_update_group(meta) - - def update_weights_from_distributed( - self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] - ) -> Future[None]: - """Update weights from distributed memory.""" - return self._engine.update_weights_from_distributed(meta, param_specs) - - def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: - """Update weights from disk.""" - return self._engine.update_weights_from_disk(meta) - - def submit( - self, - data: dict[str, Any], - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> None: - """Submit a request to the inference engine.""" - return self._engine.submit(data, workflow, workflow_builder, should_accept) - - def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: - """Wait for a specified number of requests to complete.""" - return self._engine.wait(count, timeout) - - def rollout_batch( - self, - data: list[dict[str, Any]], - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> dict[str, Any]: - """Submit a batch of requests and wait for results.""" - return self._engine.rollout_batch( - data, workflow, workflow_builder, should_accept - ) - - def prepare_batch( - self, - dataloader: StatefulDataLoader, - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ): - """Asynchronously submit and wait until a full batch is ready.""" - return self._engine.prepare_batch( - dataloader, workflow, workflow_builder, should_accept - ) - - def pause(self): - """Pause request submission for async rollout.""" - return self._engine.pause() - - def resume(self): - """Resume request submission for async rollout.""" - return self._engine.resume() - - def wait_quiet( - self, count: int, timeout: float | None = None - ) -> dict[str, Any] | None: - return self._engine.wait_quiet(count=count, timeout=timeout) diff --git a/areal/engine/vllm_local.py b/areal/engine/vllm_local.py deleted file mode 100644 index 18908cd86..000000000 --- a/areal/engine/vllm_local.py +++ /dev/null @@ -1,373 +0,0 @@ -import asyncio -import time -import uuid -from collections.abc import Callable -from concurrent.futures import Future -from typing import Any - -from torchdata.stateful_dataloader import StatefulDataLoader - -from areal.api.cli_args import InferenceEngineConfig -from areal.api.engine_api import InferenceEngine -from areal.api.io_struct import ( - ModelRequest, - ModelResponse, - ParamSpec, - WeightUpdateMeta, -) -from areal.api.workflow_api import RolloutWorkflow -from areal.core.local_inf_engine import LocalInfEngine -from areal.platforms import current_platform - - -class VLLMLocalBackend: - """vLLM-specific backend implementation for local inference. - - This backend wraps vLLM's native AsyncLLMEngine API for in-process inference. - """ - - def create_engine(self, engine_args: dict[str, Any]) -> Any: - """Create a local vLLM engine instance. - - Parameters - ---------- - engine_args : Dict[str, Any] - Arguments to pass to vLLM AsyncLLMEngine constructor - - Returns - ------- - Any - The created vLLM AsyncLLMEngine instance - """ - from vllm import AsyncEngineArgs, AsyncLLMEngine - - engine_args.pop("host", None) - engine_args.pop("port", None) - engine_args.pop("uvicorn_log_level", None) - - engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) - return engine - - async def async_generation(self, engine: Any, req: ModelRequest) -> ModelResponse: - """Perform async generation using the local vLLM engine. - - Parameters - ---------- - engine : Any - The vLLM AsyncLLMEngine instance - req : ModelRequest - The generation request containing input and parameters - - Returns - ------- - ModelResponse - The generated response with tokens, logprobs, and metadata - """ - from vllm import SamplingParams - - # Prepare request payload - gconfig = req.gconfig - stop_token_ids = gconfig.stop_token_ids - - sampling_params = SamplingParams( - top_p=gconfig.top_p, - top_k=gconfig.top_k, - max_tokens=gconfig.max_new_tokens, - temperature=0.0 if gconfig.greedy else gconfig.temperature, - stop_token_ids=stop_token_ids, - logprobs=0, # Request logprobs - ) - - # Make request - start_time = time.perf_counter() - - # Generate unique request ID - request_id = uuid.uuid4().hex - - # Call vLLM's generate method which returns an async generator - from vllm.inputs.data import TokensPrompt - - results_generator = engine.generate( - prompt=TokensPrompt(prompt_token_ids=req.input_ids), - sampling_params=sampling_params, - request_id=request_id, - ) - - # Iterate through the generator to get the final result - final_output = None # RequestOutput - async for request_output in results_generator: - final_output = request_output - - # Parse response - if final_output is None: - latency = time.perf_counter() - start_time - return ModelResponse( - input_tokens=req.input_ids, - input_images=req.image_data, - output_tokens=[], - output_logprobs=[], - output_versions=[], - stop_reason="abort", - latency=latency, - ttft=latency, - tokenizer=req.tokenizer, - processor=req.processor, - ) - - # Extract first completion output - assert len(final_output.outputs) == 1 - completion_output = final_output.outputs[0] - stop_reason = completion_output.finish_reason - - # Extract output tokens from token_ids - output_tokens = completion_output.token_ids - - # Extract logprobs - vLLM returns logprobs as a list of dicts - output_logprobs = [] - for token_logprobs, token_id in zip(completion_output.logprobs, output_tokens): - output_logprobs.append(token_logprobs[token_id].logprob) - - latency = time.perf_counter() - start_time - - return ModelResponse( - input_tokens=req.input_ids, - input_images=req.image_data, - output_tokens=output_tokens, - output_logprobs=output_logprobs, - output_versions=[], # Will be filled by LocalInfEngine - stop_reason=stop_reason, - latency=latency, - ttft=latency, - tokenizer=req.tokenizer, - processor=req.processor, - ) - - def update_weight_disk(self, engine: Any, model_path: str) -> None: - """Update weights from disk synchronously. - - Parameters - ---------- - engine : Any - The vLLM AsyncLLMEngine instance - model_path : str - Path to the model weights on disk - """ - loop = asyncio.new_event_loop() - loop.run_until_complete( - engine.collective_rpc("areal_injected_update_weight", model_path) - ) - return None - - def update_weight_xccl( - self, - engine: Any, - meta: WeightUpdateMeta, - param_specs: list[ParamSpec], - ) -> None: - """Update weights from distributed memory via NCCL/XCCL synchronously. - - Parameters - ---------- - engine : Any - The vLLM AsyncLLMEngine instance - meta : WeightUpdateMeta - Metadata containing communication group info - param_specs : List[ParamSpec] - Specifications for parameters to be updated - """ - loop = asyncio.new_event_loop() - task = engine.collective_rpc( - "set_weight_meta", - args=( - [pspec.name for pspec in param_specs], - [pspec.dtype for pspec in param_specs], - [pspec.shape for pspec in param_specs], - ), - ) - loop.run_until_complete(task) - loop.run_until_complete( - engine.collective_rpc("areal_injected_update_weight_xccl") - ) - return None - - def init_update_weight_group( - self, engine: Any, meta: WeightUpdateMeta, rank_offset: int - ) -> None: - """Initialize weight update communication group synchronously. - - Parameters - ---------- - engine : Any - The vLLM AsyncLLMEngine instance - meta : WeightUpdateMeta - Metadata containing communication backend configuration - rank_offset : int - Rank offset for this engine in the communication group - """ - task = engine.collective_rpc( - "init_update_weight_group", - args=( - meta.nccl_master_address, - str(meta.nccl_master_port), - rank_offset, - meta.alloc_mode.gen.world_size + 1, - current_platform.communication_backend, - meta.nccl_group_name, - ), - ) - loop = asyncio.new_event_loop() - loop.run_until_complete(task) - return None - - def destroy(self, engine: Any) -> None: - """Destroy the engine and release resources. - - Parameters - ---------- - engine : Any - The vLLM AsyncLLMEngine instance to destroy - """ - # vLLM engines typically don't need explicit cleanup - # but we include this for consistency with the protocol - if hasattr(engine, "shutdown"): - engine.shutdown() - - def pause_generation(self): - raise NotImplementedError() - - def continue_generation(self): - raise NotImplementedError() - - -class LocalvLLMEngine(InferenceEngine): - """vLLM local inference engine. - - This class delegates all functionality to LocalInfEngine with - a VLLMLocalBackend implementation. It maintains the same public API. - - Note: vLLM does not support weight updates, so update_weights_from_disk - and update_weights_from_distributed will raise NotImplementedError. - - Parameters - ---------- - config : InferenceEngineConfig - Configuration for the inference engine - """ - - def __init__(self, config: InferenceEngineConfig): - self.config = config - # Pure composition - create internal engine with vLLM backend - self._engine = LocalInfEngine(config, VLLMLocalBackend()) - - def configure(self, config): - self.config = config - self._engine.configure(config) - - def create_engine(self, engine_args): - return self._engine.create_engine(engine_args) - - def destroy_engine(self): - self._engine.destroy_engine() - - def initialize( - self, - engine_id: str | None = None, - train_data_parallel_size: int | None = None, - ): - """Initialize the engine by creating the local vLLM engine. - - Parameters - ---------- - engine_id : Optional[str] - Unique identifier for this engine instance - engine_args : Optional[Dict[str, Any]] - Arguments to pass to vLLM AsyncLLMEngine constructor - train_data_parallel_size : int | None - Data parallel size of the training engine - """ - return self._engine.initialize(engine_id, train_data_parallel_size) - - def destroy(self): - """Destroy the engine and clean up resources.""" - return self._engine.destroy() - - def set_version(self, version: int): - """Set the current weight version.""" - return self._engine.set_version(version) - - def get_version(self) -> int: - """Get the current weight version.""" - return self._engine.get_version() - - async def agenerate(self, req: ModelRequest) -> ModelResponse: - """Asynchronously generate a response for the given request.""" - return await self._engine.agenerate(req) - - def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: - """Initialize the weight update process group. - - Note: Not supported by vLLM. - """ - return self._engine.init_weights_update_group(meta) - - def update_weights_from_distributed( - self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] - ) -> Future[None]: - """Update weights from distributed memory. - - Note: Not supported by vLLM. - """ - return self._engine.update_weights_from_distributed(meta, param_specs) - - def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]: - """Update weights from disk. - - Note: Not supported by vLLM. - """ - return self._engine.update_weights_from_disk(meta) - - def submit( - self, - data: dict[str, Any], - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> None: - """Submit a request to the inference engine.""" - return self._engine.submit(data, workflow, workflow_builder, should_accept) - - def wait(self, count: int, timeout: float | None = None) -> dict[str, Any]: - """Wait for a specified number of requests to complete.""" - return self._engine.wait(count, timeout) - - def rollout_batch( - self, - data: list[dict[str, Any]], - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ) -> dict[str, Any]: - """Submit a batch of requests and wait for results.""" - return self._engine.rollout_batch( - data, workflow, workflow_builder, should_accept - ) - - def prepare_batch( - self, - dataloader: StatefulDataLoader, - workflow: RolloutWorkflow | None = None, - workflow_builder: Callable | None = None, - should_accept: Callable | None = None, - ): - """Asynchronously submit and wait until a full batch is ready.""" - return self._engine.prepare_batch( - dataloader, workflow, workflow_builder, should_accept - ) - - def pause(self): - """Pause request submission for async rollout.""" - return self._engine.pause() - - def resume(self): - """Resume request submission for async rollout.""" - return self._engine.resume() diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index 649cdec7d..bb66bdf41 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -1,4 +1,4 @@ -"""Unified test suite for inference engines (vLLM and SGLang, both local and remote).""" +"""Test suite for remote inference engines (vLLM and SGLang).""" import os import subprocess @@ -45,27 +45,11 @@ def _dummy_reward_fn(*args, **kwargs): @pytest.fixture( - params=[ - # ("vllm", "remote"), - # ("vllm", "local"), - ("sglang", "remote"), - # ("sglang", "local"), - ], - ids=[ - # "vllm-remote", - # "vllm-local", - "sglang-remote", - # "sglang-local", - ], + params=[("vllm", "remote"), ("sglang", "remote")], + ids=["vllm-remote", "sglang-remote"], ) def inference_engine(request): - """Unified fixture that provides any inference engine (vLLM/SGLang, local/remote). - - This fixture: - 1. Launches the appropriate server (for remote) or prepares engine args (for local) - 2. Yields engine metadata for test initialization - 3. Cleans up resources after all tests complete - """ + """Fixture for remote inference engines only (vLLM and SGLang).""" backend, mode = request.param # Skip if vLLM is not installed @@ -74,13 +58,15 @@ def inference_engine(request): from areal.utils import seeding - expr_name = f"test_{mode}_{backend}_engine" + expr_name = f"test_remote_{backend}_engine" trial_name = "trial_0" seeding.set_random_seed(1, expr_name) port, dist_port = network.find_free_ports(2) host = network.gethostip() + + # Configure SGLang sglang_config = SGLangConfig( skip_tokenizer_init=True, model_path=MODEL_PATH, @@ -94,6 +80,8 @@ def inference_engine(request): port=port, dist_init_addr=f"{host}:{dist_port}", ) + + # Configure vLLM vllm_config = vLLMConfig( skip_tokenizer_init=False, model=MODEL_PATH, @@ -105,85 +93,59 @@ def inference_engine(request): host=host, port=port, ) + config = InferenceEngineConfig( experiment_name=expr_name, trial_name=trial_name, ) - # Initialize engine based on backend and mode - if mode == "remote": - # Launch server - - if backend == "vllm": - from areal.engine.vllm_remote import RemotevLLMEngine - - cmd = vLLMConfig.build_cmd_from_args(vllm_args) - engine_class = RemotevLLMEngine - else: # sglang - from areal.engine.sglang_remote import RemoteSGLangEngine - - cmd = SGLangConfig.build_cmd_from_args(sglang_args) - engine_class = RemoteSGLangEngine - - # Launch process - cmd = cmd.replace("\\\n", " ").replace("\\", " ") - process = subprocess.Popen( - cmd.split(), - text=True, - stdout=sys.stdout, - stderr=sys.stdout, - ) - base_url = f"http://{host}:{port}" - tik = time.time() - while time.time() - tik < RUN_SERVER_TIMEOUT: - if check_server_health(base_url): - break - time.sleep(1) - if time.time() - tik > RUN_SERVER_TIMEOUT: - process.terminate() - raise RuntimeError(f"{backend.upper()} server launch failed") - - # Set environment for remote engine - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host}:{port}" - - engine = engine_class(config) - - yield { - "engine": engine, - "backend": backend, - "mode": mode, - "expr_name": expr_name, - "trial_name": trial_name, - "host": host, - "port": port, - } - - # Cleanup + # Launch remote server and initialize engine + if backend == "vllm": + from areal.engine.vllm_remote import RemotevLLMEngine + + cmd = vLLMConfig.build_cmd_from_args(vllm_args) + engine_class = RemotevLLMEngine + else: # sglang + from areal.engine.sglang_remote import RemoteSGLangEngine + + cmd = SGLangConfig.build_cmd_from_args(sglang_args) + engine_class = RemoteSGLangEngine + + # Launch process + cmd = cmd.replace("\\\n", " ").replace("\\", " ") + process = subprocess.Popen( + cmd.split(), + text=True, + stdout=sys.stdout, + stderr=sys.stdout, + ) + base_url = f"http://{host}:{port}" + tik = time.time() + while time.time() - tik < RUN_SERVER_TIMEOUT: + if check_server_health(base_url): + break + time.sleep(1) + if time.time() - tik > RUN_SERVER_TIMEOUT: process.terminate() + raise RuntimeError(f"{backend.upper()} server launch failed") + + # Set environment for remote engine + os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host}:{port}" + + engine = engine_class(config) + + yield { + "engine": engine, + "backend": backend, + "mode": mode, + "expr_name": expr_name, + "trial_name": trial_name, + "host": host, + "port": port, + } - else: # local - if backend == "vllm": - from areal.engine.vllm_local import LocalvLLMEngine - - engine_args = vllm_args - engine_class = LocalvLLMEngine - else: # sglang - from areal.engine.sglang_local import LocalSGLangEngine - - engine_args = sglang_args - engine_class = LocalSGLangEngine - - engine = engine_class(config) - engine.create_engine(engine_args=engine_args) - - yield { - "engine": engine, - "backend": backend, - "mode": mode, - "expr_name": expr_name, - "trial_name": trial_name, - } - engine.destroy_engine() + # Cleanup + process.terminate() # ============================================================================ From a25d37892dc0090d0faa01d958a6996c91e533aa Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Fri, 31 Oct 2025 16:59:33 +0800 Subject: [PATCH 34/52] minor revert --- areal/core/__init__.py | 6 ------ areal/core/async_task_runner.py | 19 +++++-------------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/areal/core/__init__.py b/areal/core/__init__.py index ddc08c8cf..28fa01204 100644 --- a/areal/core/__init__.py +++ b/areal/core/__init__.py @@ -1,9 +1,5 @@ """Core components for AREAL.""" -from .local_inf_engine import ( - LocalInfBackendProtocol, - LocalInfEngine, -) from .remote_inf_engine import ( RemoteInfBackendProtocol, RemoteInfEngine, @@ -15,8 +11,6 @@ ) __all__ = [ - "LocalInfBackendProtocol", - "LocalInfEngine", "RemoteInfBackendProtocol", "RemoteInfEngine", "StalenessManager", diff --git a/areal/core/async_task_runner.py b/areal/core/async_task_runner.py index 66520f0ae..f9cb2f63d 100644 --- a/areal/core/async_task_runner.py +++ b/areal/core/async_task_runner.py @@ -170,7 +170,7 @@ def __init__( self.max_queue_size = max_queue_size self.poll_wait_time = poll_wait_time self.poll_sleep_time = poll_sleep_time - self._enable_tracing = enable_tracing + self.enable_tracing = enable_tracing # Thread control self.exiting = threading.Event() @@ -188,22 +188,13 @@ def __init__( self.result_cache: list[_TimedResult[T]] = [] # Thread exception handling - self._lock = threading.Lock() + self._thread_exception_lock = threading.Lock() self._thread_exception: Exception | None = None # Will be set in initialize() self.logger = None self.thread: threading.Thread | None = None - def set_enable_tracing(self, enabled: bool): - with self._lock: - self._enable_tracing = enabled - - @property - def enable_tracing(self): - with self._lock: - return self._enable_tracing - def initialize(self, logger=None): """Initialize and start the background thread. @@ -240,7 +231,7 @@ def _check_thread_health(self): RuntimeError If the background thread has died due to an exception. """ - with self._lock: + with self._thread_exception_lock: if self._thread_exception is not None: raise RuntimeError( "AsyncTaskRunner thread has died due to an exception. " @@ -256,7 +247,7 @@ def _run_thread(self): uvloop.run(self._run_async_loop()) except Exception as e: # Store exception for thread-safe access - with self._lock: + with self._thread_exception_lock: self._thread_exception = e if self.logger: self.logger.error( @@ -348,7 +339,7 @@ async def _run_async_loop(self): ) if self.enable_tracing and self.logger: self.logger.info( - f"AsyncTaskRunner: Completed task ID: {tid}. " + f"AsyncTaskRunner: Completed task {tid}. " f"Running: {len(running_tasks)}" ) except queue.Full: From 99fe517a15362c712ca5a631f71d429968d64f9d Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Fri, 31 Oct 2025 17:02:59 +0800 Subject: [PATCH 35/52] revert realhf --- realhf/api/cli_args.py | 2 +- realhf/api/core/system_api.py | 6 +++--- realhf/apps/main.py | 4 ++-- realhf/experiments/async_exp/async_rl_exp.py | 12 ++++++------ realhf/experiments/common/common.py | 6 +++--- realhf/scheduler/client.py | 2 +- realhf/scheduler/slurm/client.py | 4 ++-- realhf/system/controller.py | 6 +++--- 8 files changed, 21 insertions(+), 21 deletions(-) diff --git a/realhf/api/cli_args.py b/realhf/api/cli_args.py index 0a9934028..6de76c710 100644 --- a/realhf/api/cli_args.py +++ b/realhf/api/cli_args.py @@ -980,7 +980,7 @@ class BaseExperimentConfig: partition: str = field( default="dev", metadata={"help": "SLURM partition for running the experiment."} ) - scheduling_strategy: str = field( + schedule_strategy: str = field( default="empty_first", metadata={"help": "Resource scheduling strategy."} ) wandb: WandBConfig = field( diff --git a/realhf/api/core/system_api.py b/realhf/api/core/system_api.py index 47c695792..ea30213fe 100644 --- a/realhf/api/core/system_api.py +++ b/realhf/api/core/system_api.py @@ -34,7 +34,7 @@ class ExpStatus(Enum): @dataclasses.dataclass -class SchedulingSpec: +class Scheduling: # TODO: add partition cpu: int gpu: int @@ -173,7 +173,7 @@ class MasterWorker: @dataclasses.dataclass class TasksGroup: count: int - scheduling: SchedulingSpec + scheduling: Scheduling @dataclasses.dataclass @@ -458,7 +458,7 @@ class Experiment: """Base class for defining the procedure of an experiment.""" def scheduling_setup(self) -> ExperimentScheduling: - """Returns the SchedulingSpec of all workers.""" + """Returns the Scheduling of all workers.""" raise NotImplementedError() def initial_setup(self) -> ExperimentConfig | List[ExperimentConfig]: diff --git a/realhf/apps/main.py b/realhf/apps/main.py index 5165d1d24..e1ca08764 100644 --- a/realhf/apps/main.py +++ b/realhf/apps/main.py @@ -47,7 +47,7 @@ def _submit_workers( job_environs = {**environs, **sch_cfg.scheduling.env_vars} cmd = sched_client.remote_worker_cmd(expr_name, trial_name, debug, worker_type) - logger.debug(f"SchedulingSpec worker {worker_type}, {scheduling_configs}") + logger.debug(f"Scheduling worker {worker_type}, {scheduling_configs}") nodelist = sch_cfg.scheduling.nodelist exclude = sch_cfg.scheduling.exclude @@ -335,7 +335,7 @@ def main(): choices=["local", "slurm", "ray"], ) subparser.add_argument( - "--scheduling_strategy", + "--schedule_strategy", default="empty_first", choices=["empty_first", "allocated_first"], help="Schedule strategy for scheduler. Currently only effective in slurm mode. " diff --git a/realhf/experiments/async_exp/async_rl_exp.py b/realhf/experiments/async_exp/async_rl_exp.py index 313606400..3c23f13d1 100755 --- a/realhf/experiments/async_exp/async_rl_exp.py +++ b/realhf/experiments/async_exp/async_rl_exp.py @@ -34,7 +34,7 @@ GserverManager, ModelWorker, RolloutWorker, - SchedulingSpec, + Scheduling, TasksGroup, ) from realhf.api.quickstart.device_mesh import RPCAllocation @@ -90,7 +90,7 @@ def scheduling_setup(self) -> ExperimentScheduling: return ExperimentScheduling( master_worker=TasksGroup( count=1, - scheduling=SchedulingSpec( + scheduling=Scheduling( cpu=self.cpus_per_master_worker, gpu=0, mem=self.mem_per_master_worker, @@ -101,7 +101,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), model_worker=TasksGroup( count=train_world_size, - scheduling=SchedulingSpec( + scheduling=Scheduling( cpu=self.cpus_per_model_worker, gpu=1, mem=self.mem_per_model_worker, @@ -112,7 +112,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), generation_server=TasksGroup( count=gen_world_size // gen_tp_size, - scheduling=SchedulingSpec( + scheduling=Scheduling( cpu=self.cpus_per_generation_server, gpu=gen_tp_size, mem=self.mem_per_generation_server, @@ -123,7 +123,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), gserver_manager=TasksGroup( count=1, - scheduling=SchedulingSpec( + scheduling=Scheduling( cpu=self.cpus_per_gserver_manager, gpu=0, mem=self.mem_per_gserver_manager, @@ -134,7 +134,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), rollout_worker=TasksGroup( count=self.n_rollout_workers or train_world_size, - scheduling=SchedulingSpec( + scheduling=Scheduling( cpu=self.cpus_per_rollout_worker, gpu=0, mem=self.mem_per_rollout_worker, diff --git a/realhf/experiments/common/common.py b/realhf/experiments/common/common.py index 2507fe896..370f16fd1 100644 --- a/realhf/experiments/common/common.py +++ b/realhf/experiments/common/common.py @@ -34,7 +34,7 @@ ExperimentConfig, ExperimentScheduling, ModelWorker, - SchedulingSpec, + Scheduling, TasksGroup, ) from realhf.api.quickstart.device_mesh import ( @@ -163,7 +163,7 @@ def scheduling_setup(self) -> ExperimentScheduling: return ExperimentScheduling( master_worker=TasksGroup( count=1, - scheduling=SchedulingSpec( + scheduling=Scheduling( cpu=self.cpus_per_master_worker, gpu=0, mem=self.mem_per_master_worker, @@ -174,7 +174,7 @@ def scheduling_setup(self) -> ExperimentScheduling: ), model_worker=TasksGroup( count=self.n_nodes * self.n_gpus_per_node, - scheduling=SchedulingSpec( + scheduling=Scheduling( cpu=self.cpus_per_model_worker, gpu=1, mem=self.mem_per_model_worker, diff --git a/realhf/scheduler/client.py b/realhf/scheduler/client.py index 894f0e66e..3662b7c24 100644 --- a/realhf/scheduler/client.py +++ b/realhf/scheduler/client.py @@ -160,7 +160,7 @@ def make(args: "BaseExperimentConfig", **kwargs) -> SchedulerClient: evaluator = kwargs.get("evaluator", None) return SlurmSchedulerClient( args, - args.scheduling_strategy, + args.schedule_strategy, evaluator, job_group_id, job_group_index, diff --git a/realhf/scheduler/slurm/client.py b/realhf/scheduler/slurm/client.py index e094ee6f4..772b60138 100644 --- a/realhf/scheduler/slurm/client.py +++ b/realhf/scheduler/slurm/client.py @@ -81,14 +81,14 @@ class SlurmSchedulerClient(SchedulerClient): def __init__( self, args, - scheduling_strategy: str, + schedule_strategy: str, evaluator: Optional[AutomaticEvaluator], job_group_id: str, job_group_index: int, ): super().__init__(args) - self.__schedule_strategy = scheduling_strategy + self.__schedule_strategy = schedule_strategy self.__pending_jobs: Dict[str, SlurmLaunchInfo] = dict() self.__committed_jobs: Dict[str, SlurmLaunchInfo] = dict() diff --git a/realhf/system/controller.py b/realhf/system/controller.py index 33bc3075e..0e8b982ee 100644 --- a/realhf/system/controller.py +++ b/realhf/system/controller.py @@ -127,7 +127,7 @@ def __check_consistent_scheduling( setup: system_api.ExperimentConfig, verbose=False, ): - # SchedulingSpec and connecting to workers. + # Scheduling and connecting to workers. workers_configs = [ (k, getattr(setup, k), getattr(scheduling, k)) for k in WORKER_TYPES @@ -142,7 +142,7 @@ def __check_consistent_scheduling( raise ValueError( f"Configuration and scheduling mismatch. " f"Number of worker configurations: {len(worker_setups)}, " - f"SchedulingSpec configs: {schedules}." + f"Scheduling configs: {schedules}." ) for name, config, schedule in workers_configs: @@ -153,7 +153,7 @@ def __check_consistent_scheduling( ) if len(config) != count: logger.error( - "SchedulingSpec and config mismatch, interrupting all workers." + "Scheduling and config mismatch, interrupting all workers." ) self.interrupt() raise IndexError( From 7e133b881aace5a28cf2a9128f3d0fa61afacc09 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Sun, 2 Nov 2025 13:30:51 +0800 Subject: [PATCH 36/52] minor config fix --- examples/single-controller/gsm8k_grpo.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/single-controller/gsm8k_grpo.yaml b/examples/single-controller/gsm8k_grpo.yaml index ef0b0a36b..63107f19b 100644 --- a/examples/single-controller/gsm8k_grpo.yaml +++ b/examples/single-controller/gsm8k_grpo.yaml @@ -28,7 +28,7 @@ rollout: type: worker port_count: 1 gpu: 1 - cmd: python3 -m areal.scheduler.rpc.sync_rpc_server + cmd: python3 -m areal.scheduler.rpc.rpc_server gconfig: n_samples: 4 @@ -81,7 +81,7 @@ actor: type: worker port_count: 1 gpu: 1 - cmd: python3 -m areal.scheduler.rpc.sync_rpc_server + cmd: python3 -m areal.scheduler.rpc.rpc_server ref: experiment_name: ${experiment_name} @@ -101,7 +101,7 @@ ref: type: worker port_count: 1 gpu: 1 - cmd: python3 -m areal.scheduler.rpc.sync_rpc_server + cmd: python3 -m areal.scheduler.rpc.rpc_server # SGLang sglang: From 73912a8fd898098343792dd6036301fc93ecc449 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Fri, 31 Oct 2025 23:18:47 +0800 Subject: [PATCH 37/52] merge tests --- areal/api/cli_args.py | 75 ++++---- areal/api/engine_api.py | 30 --- areal/api/io_struct.py | 4 + areal/core/remote_inf_engine.py | 3 +- areal/experimental/tests/test_openai.py | 4 +- areal/launcher/sglang_server.py | 18 +- areal/launcher/vllm_server.py | 11 +- areal/tests/test_fsdp_engine_nccl.py | 7 +- areal/tests/test_sglang_engine.py | 226 ---------------------- areal/tests/test_vllm_engine.py | 244 ------------------------ notebook/math_reflection_en.ipynb | 1 - notebook/math_reflection_zh.ipynb | 1 - notebook/search_agent_zh.ipynb | 1 - 13 files changed, 62 insertions(+), 563 deletions(-) delete mode 100644 areal/tests/test_sglang_engine.py delete mode 100644 areal/tests/test_vllm_engine.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index f7ec28f57..97b050fef 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -3,6 +3,7 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path +from typing import Any import uvloop import yaml @@ -531,6 +532,24 @@ class PPOCriticConfig(TrainEngineConfig): ) +def get_py_cmd(module: str, args: dict[str, Any]): + # convert to flags + cmd = ["python3", "-m", module] + for k, v in args.items(): + if v is None or v is False or v == "" or (isinstance(v, list) and not v): + continue + flag = f"--{k.replace('_', '-')}" + if v is True: + cmd.append(flag) + elif isinstance(v, list): + cmd.append(flag) + cmd.extend(map(str, v)) + else: + cmd.append(flag) + cmd.append(str(v)) + return cmd + + @dataclass class vLLMConfig: """Configuration for vLLM runtime. Refer to: @@ -591,6 +610,10 @@ def build_args( ) return args + @staticmethod + def build_cmd_from_args(args: dict[str, Any]): + return get_py_cmd("areal.thirdparty.vllm.areal_vllm_server", args) + @staticmethod def build_cmd( vllm_config: "vLLMConfig", @@ -608,18 +631,7 @@ def build_cmd( port=port, dist_init_addr=dist_init_addr, ) - # convert to flags - flags = [] - for k, v in args.items(): - if v is None or v is False or v == "": - continue - if v is True: - flags.append(f"--{k.replace('_', '-')}") - elif isinstance(v, list): - flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}") - else: - flags.append(f"--{k.replace('_', '-')} {v}") - return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}" + return vLLMConfig.build_cmd_from_args(args) @dataclass @@ -717,28 +729,19 @@ def build_cmd( node_rank=node_rank, ) - # convert to flags - flags = [] - for k, v in args.items(): - if is_version_less("sglang", "0.4.10.post2") and "max_loaded_loras" in k: - continue - if v is None or v is False or v == "": - continue - if v is True: - flags.append(f"--{k.replace('_', '-')}") - elif isinstance(v, list): - flags.append(f"--{k.replace('_', '-')} {' '.join(map(str, v))}") - else: - flags.append(f"--{k.replace('_', '-')} {v}") - return f"python3 -m sglang.launch_server {' '.join(flags)}" + return SGLangConfig.build_cmd_from_args(args) + + @staticmethod + def build_cmd_from_args(args: dict[str, Any]): + return get_py_cmd("sglang.launch_server", args) @staticmethod def build_args( sglang_config: "SGLangConfig", - tp_size, - base_gpu_id, - host, - port, + tp_size: int, + base_gpu_id: int, + host: str | None = None, + port: str | None = None, dist_init_addr: str | None = None, n_nodes: int = 1, node_rank: int = 0, @@ -754,19 +757,17 @@ def build_args( enable_multithread_load=sglang_config.enable_multithread_load, enable_fast_load=sglang_config.enable_fast_load, ) - args.pop("enable_multithread_load", None) - args.pop("enable_fast_load", None) args["model_loader_extra_config"] = json.dumps( model_loader_extra_config, separators=(",", ":") ) + args.pop("enable_multithread_load", None) + args.pop("enable_fast_load", None) # Map "all-linear" to "all" if "lora_target_modules" in args and args["lora_target_modules"]: args["lora_target_modules"] = [ x.replace("-linear", "") for x in args["lora_target_modules"] ] args = dict( - host=host, - port=port, # Model and tokenizer tokenizer_path=sglang_config.model_path, tokenizer_mode="auto", @@ -784,8 +785,14 @@ def build_args( dist_init_addr=dist_init_addr, **args, ) + if host is not None: + args["host"] = host + if port is not None: + args["port"] = port if not pkg_version.is_version_greater_or_equal("sglang", "0.4.9.post2"): raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.") + if is_version_less("sglang", "0.4.10.post2"): + args.pop("max_loaded_loras", None) return args diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 216d5c02c..e554f4e38 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -1,7 +1,6 @@ import abc from collections.abc import Callable from concurrent.futures import Future -from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional import torch @@ -21,23 +20,6 @@ from areal.api.workflow_api import RolloutWorkflow -@dataclass -class Scheduling: - cpu: int - gpu: int - mem: int - nodelist: str | None = None - exclude: str | None = None - partition: str | None = None - container_image: str | None = None - type: str | None = None - env_vars: dict[str, str] = field(default_factory=dict) - # time utils from "https://slurm.schedmd.com/sbatch.html" - time_limit: str | None = None # see "--time" option for format - begin: str | None = None # see "--begin" option for format - deadline: str | None = None # see "--deadline" option for format - - class TrainEngine(abc.ABC): def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): """Initialize PyTorch distributed communication groups. @@ -138,18 +120,6 @@ def parallelism_group(self) -> dist.ProcessGroup: """ raise NotImplementedError() - def get_scheduling_config(self) -> Scheduling: - """Get the scheduling configuration for the engine. - - This includes configuration such as container image, CPU/GPU/memory size. - - Returns - ------- - Scheduling - The scheduling configuration for the engine - """ - raise NotImplementedError() - def destroy(self): """Destroy the engine and release GPU memory of models.""" diff --git a/areal/api/io_struct.py b/areal/api/io_struct.py index 0f9552fc9..4e6ade735 100644 --- a/areal/api/io_struct.py +++ b/areal/api/io_struct.py @@ -114,6 +114,8 @@ class WeightUpdateMeta: use_lora: bool = False + clear_checkpoint_after_load: bool = True + @classmethod def from_disk( cls, @@ -122,6 +124,7 @@ def from_disk( file_root: str, name: str = "default", use_lora: bool = False, + clear_checkpoint_after_load: bool = True, ) -> "WeightUpdateMeta": from areal.utils.saver import Saver @@ -133,6 +136,7 @@ def from_disk( type="disk", path=path, use_lora=use_lora, + clear_checkpoint_after_load=clear_checkpoint_after_load, ) @classmethod diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py index 7eae12c59..4cb6067d5 100644 --- a/areal/core/remote_inf_engine.py +++ b/areal/core/remote_inf_engine.py @@ -606,7 +606,8 @@ def callback(fut): # Update LoRA state if this was a LoRA update if meta.use_lora: self.lora_initialized = True - shutil.rmtree(meta.path, ignore_errors=True) + if meta.clear_checkpoint_after_load: + shutil.rmtree(meta.path, ignore_errors=True) fut.add_done_callback(callback) diff --git a/areal/experimental/tests/test_openai.py b/areal/experimental/tests/test_openai.py index 3dc545256..722dd8ae8 100644 --- a/areal/experimental/tests/test_openai.py +++ b/areal/experimental/tests/test_openai.py @@ -48,10 +48,8 @@ def sglang_server(): dist_init_addr=f"{HOST}:{DIST_PORT}", ) # Launch process - cmd = cmd.replace("\\\n", " ").replace("\\", " ") process = subprocess.Popen( - cmd.split(), - text=True, + cmd, stdout=sys.stdout, stderr=sys.stdout, ) diff --git a/areal/launcher/sglang_server.py b/areal/launcher/sglang_server.py index 56fef2eaf..6c04e30c9 100644 --- a/areal/launcher/sglang_server.py +++ b/areal/launcher/sglang_server.py @@ -7,7 +7,6 @@ import uuid from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -from typing import Optional import psutil import requests @@ -68,29 +67,26 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N pass -def launch_server_cmd(command: str) -> subprocess.Popen: +def launch_server_cmd(command: list[str]) -> subprocess.Popen: """ Execute a shell command and return its process handle. """ # Replace newline continuations and split the command string. - command = command.replace("\\\n", " ").replace("\\", " ") - logger.info(f"Launch command: {command}") - parts = command.split() + logger.info(f"Launch command: {' '.join(command)}") _env = os.environ.copy() # To avoid DirectoryNotEmpty error caused by triton triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH) unique_triton_cache_path = os.path.join(triton_cache_path, str(uuid.uuid4())) _env["TRITON_CACHE_PATH"] = unique_triton_cache_path return subprocess.Popen( - parts, - text=True, + command, env=_env, stdout=sys.stdout, stderr=subprocess.STDOUT, ) -def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None: +def wait_for_server(base_url: str, timeout: int | None = None) -> None: """Wait for the server to be ready by polling the /v1/models endpoint. Args: @@ -137,9 +133,9 @@ def run(self): gpus_per_server = self.allocation_mode.gen_instance_size cross_nodes = False if gpus_per_server > self.n_gpus_per_node: - assert ( - gpus_per_server % self.n_gpus_per_node == 0 - ), "Cross-nodes SGLang only supports utilizing all gpus in one node" + assert gpus_per_server % self.n_gpus_per_node == 0, ( + "Cross-nodes SGLang only supports utilizing all gpus in one node" + ) cross_nodes = True node_rank = int(os.environ["AREAL_SGLANG_MULTI_NODE_RANK"]) master_addr = os.environ["AREAL_SGLANG_MULTI_NODE_MASTER_ADDR"] diff --git a/areal/launcher/vllm_server.py b/areal/launcher/vllm_server.py index ddbfed5f8..70a4b4e55 100644 --- a/areal/launcher/vllm_server.py +++ b/areal/launcher/vllm_server.py @@ -24,14 +24,14 @@ logger = logging.getLogger("vLLMServer Wrapper") -def launch_server_cmd(command: str, custom_env: dict | None = None) -> subprocess.Popen: +def launch_server_cmd( + command: list[str], custom_env: dict | None = None +) -> subprocess.Popen: """ Execute a shell command and return its process handle. """ # Replace newline continuations and split the command string. - command = command.replace("\\\n", " ").replace("\\", " ") - logger.info(f"Launch command: {command}") - parts = command.split() + logger.info(f"Launch command: {' '.join(command)}") _env = os.environ.copy() # To avoid DirectoryNotEmpty error caused by triton triton_cache_path = _env.get("TRITON_CACHE_PATH", TRITON_CACHE_PATH) @@ -44,8 +44,7 @@ def launch_server_cmd(command: str, custom_env: dict | None = None) -> subproces if custom_env is not None: _env.update(custom_env) return subprocess.Popen( - parts, - text=True, + command, env=_env, stdout=sys.stdout, stderr=subprocess.STDOUT, diff --git a/areal/tests/test_fsdp_engine_nccl.py b/areal/tests/test_fsdp_engine_nccl.py index 9aa0d49bf..df36d0230 100644 --- a/areal/tests/test_fsdp_engine_nccl.py +++ b/areal/tests/test_fsdp_engine_nccl.py @@ -57,14 +57,11 @@ def sglang_server_nccl(): port=PORT, dist_init_addr=f"{HOST}:{DIST_PORT}", ) - full_command = f"{cmd} --port {PORT}" - full_command = full_command.replace("\\\n", " ").replace("\\", " ") os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - print(f"full_command to start sglang server: {full_command}", flush=True) + print(f"full_command to start sglang server: {cmd}", flush=True) process = subprocess.Popen( - full_command.split(), - text=True, + cmd, stdout=sys.stdout, stderr=sys.stdout, ) diff --git a/areal/tests/test_sglang_engine.py b/areal/tests/test_sglang_engine.py deleted file mode 100644 index a3d1b5b02..000000000 --- a/areal/tests/test_sglang_engine.py +++ /dev/null @@ -1,226 +0,0 @@ -import os -import subprocess -import sys -import time - -import pytest -import requests - -from areal.api.cli_args import ( - GenerationHyperparameters, - InferenceEngineConfig, - SGLangConfig, -) -from areal.api.io_struct import WeightUpdateMeta -from areal.utils import network -from areal.utils.data import get_batch_size -from areal.utils.hf_utils import load_hf_tokenizer - -EXPR_NAME = "test_sglang_engine" -TRIAL_NAME = "trial_0" -MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" -if not os.path.exists(MODEL_PATH): - MODEL_PATH = "Qwen/Qwen3-0.6B" -PORT, DIST_PORT = network.find_free_ports(2) -HOST = network.gethostip() -# set a large timeout since we may need to download the model from hub -RUN_SERVER_TIMEOUT = 180 - - -def check_server_health(base_url): - try: - response = requests.get(f"{base_url}/health", timeout=30) - return response.status_code == 200 - except requests.exceptions.RequestException: - return False - - -@pytest.fixture(scope="module") -def sglang_server(): - from areal.utils import seeding - - seeding.set_random_seed(1, EXPR_NAME) - cmd = SGLangConfig.build_cmd( - sglang_config=SGLangConfig( - skip_tokenizer_init=True, - model_path=MODEL_PATH, - mem_fraction_static=0.3, - ), - host=HOST, - port=PORT, - tp_size=1, - base_gpu_id=0, - dist_init_addr=f"{HOST}:{DIST_PORT}", - ) - # Launch process - cmd = cmd.replace("\\\n", " ").replace("\\", " ") - process = subprocess.Popen( - cmd.split(), - text=True, - stdout=sys.stdout, - stderr=sys.stdout, - ) - base_url = f"http://{HOST}:{PORT}" - tik = time.time() - while time.time() - tik < RUN_SERVER_TIMEOUT: - if check_server_health(base_url): - break - time.sleep(1) - if time.time() - tik > RUN_SERVER_TIMEOUT: - raise RuntimeError("server launch failed") - yield - process.terminate() - - -def _dummy_reward_fn(*args, **kwargs): - return 1.0 - - -@pytest.mark.parametrize("n_samples", [1, 2, 4]) -def test_remote_sglang_rollout(sglang_server, n_samples): - from areal.engine.sglang_remote import RemoteSGLangEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - max_concurrent_rollouts=2, - consumer_batch_size=2, - ) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - engine = RemoteSGLangEngine(config) - engine.initialize() - - gconfig = GenerationHyperparameters( - max_new_tokens=16, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - result = engine.rollout_batch([data] * 2, workflow=workflow) - assert isinstance(result, dict) - bs = get_batch_size(result) - assert bs == 2 * n_samples - engine.destroy() - - -@pytest.mark.parametrize("ofp", [0, 1, 4]) -@pytest.mark.parametrize("bs", [2]) -@pytest.mark.parametrize("n_samples", [2, 1]) -def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples): - from areal.engine.sglang_remote import RemoteSGLangEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - consumer_batch_size=bs, - max_head_offpolicyness=ofp, - enable_rollout_tracing=True, - ) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - engine = RemoteSGLangEngine(config) - engine.initialize() - - gconfig = GenerationHyperparameters( - max_new_tokens=2, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - if ofp < 1: - # Due to controlled offpolicyness, not all requests are committed - with pytest.raises(TimeoutError): - engine.wait(count=bs * 2, timeout=10) - else: - result = engine.wait(count=bs * 2, timeout=10) - assert result["attention_mask"].shape[0] == bs * 2 * n_samples - - # Update model version - engine.set_version(1) - print("Updated model version", flush=True) - - # submit again - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - if ofp < 2: - # Due to controlled offpolicyness, not all requests are committed - with pytest.raises(TimeoutError): - engine.wait(count=bs * 4, timeout=5) - else: - # 2 * bs samples haved been retrived above - results = engine.wait(count=bs * 2, timeout=5) - assert results["attention_mask"].shape[0] == bs * 2 * n_samples - - # exit - engine.destroy() - - -def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server): - # setup FSDP engine - from areal.api.cli_args import OptimizerConfig, TrainEngineConfig - from areal.api.io_struct import FinetuneSpec - from areal.engine.fsdp_engine import FSDPEngine - - os.environ["WORLD_SIZE"] = "1" - os.environ["RANK"] = "0" - os.environ["LOCAL_RANK"] = "0" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "7777" - - engine_config = TrainEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - path=MODEL_PATH, - optimizer=OptimizerConfig(), - ) - engine = FSDPEngine(engine_config) - engine.create_process_group() - ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) - engine.initialize(None, ft_spec) - engine.model_version = 100 - - # setup name resolve - import areal.utils.name_resolve as name_resolve - from areal.api.cli_args import NameResolveConfig - - nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") - name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) - name_resolve.reconfigure(name_resolve_config) - # initialize SGLang remote engine - from areal.api.cli_args import InferenceEngineConfig - from areal.engine.sglang_remote import RemoteSGLangEngine - - config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - inf_engine = RemoteSGLangEngine(config) - inf_engine.initialize() - inf_engine.set_version(100) - # test update weights - path = tmp_path_factory.mktemp("upload_weights_from_disk") - update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) - engine.connect_engine(inf_engine, update_weight_meta) - engine.set_version(100) - engine.update_weights(update_weight_meta) - inf_engine.destroy() diff --git a/areal/tests/test_vllm_engine.py b/areal/tests/test_vllm_engine.py deleted file mode 100644 index c560eb918..000000000 --- a/areal/tests/test_vllm_engine.py +++ /dev/null @@ -1,244 +0,0 @@ -import os -import subprocess -import sys -import time - -import pytest -import requests - -from areal.api.cli_args import ( - GenerationHyperparameters, - InferenceEngineConfig, - vLLMConfig, -) -from areal.api.io_struct import WeightUpdateMeta -from areal.utils import network -from areal.utils.data import get_batch_size -from areal.utils.hf_utils import load_hf_tokenizer -from areal.utils.pkg_version import is_available - -EXPR_NAME = "test_vllm_engine" -TRIAL_NAME = "trial_0" -MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" -if not os.path.exists(MODEL_PATH): - MODEL_PATH = "Qwen/Qwen3-0.6B" -PORT, DIST_PORT = network.find_free_ports(2) -HOST = network.gethostip() -# set a large timeout since we may need to download the model from hub -RUN_SERVER_TIMEOUT = 180 - -IS_VLLM_INSTALLED = is_available("vllm") - - -def check_server_health(base_url): - try: - response = requests.get(f"{base_url}/health", timeout=30) - return response.status_code == 200 - except requests.exceptions.RequestException: - return False - - -@pytest.fixture(scope="module") -def vllm_server(): - from areal.utils import seeding - - seeding.set_random_seed(1, EXPR_NAME) - cmd = vLLMConfig.build_cmd( - vllm_config=vLLMConfig( - skip_tokenizer_init=False, - model=MODEL_PATH, - gpu_memory_utilization=0.1, - max_model_len=4096, - ), - host=HOST, - port=PORT, - tp_size=1, - pp_size=1, - dist_init_addr=f"{HOST}:{DIST_PORT}", - ) - # Launch process - cmd = cmd.replace("\\\n", " ").replace("\\", " ") - process = subprocess.Popen( - cmd.split(), - text=True, - stdout=sys.stdout, - stderr=sys.stdout, - ) - base_url = f"http://{HOST}:{PORT}" - tik = time.time() - while time.time() - tik < RUN_SERVER_TIMEOUT: - if check_server_health(base_url): - break - time.sleep(1) - if time.time() - tik > RUN_SERVER_TIMEOUT: - raise RuntimeError("server launch failed") - yield - process.terminate() - - -def _dummy_reward_fn(*args, **kwargs): - return 1.0 - - -@pytest.mark.skipif( - not IS_VLLM_INSTALLED, reason="Skip the test because vllm is not installed." -) -@pytest.mark.parametrize("n_samples", [1, 2, 4]) -@pytest.mark.slow -@pytest.mark.ci -def test_remote_vllm_rollout(vllm_server, n_samples): - from areal.engine.vllm_remote import RemotevLLMEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - max_concurrent_rollouts=2, - consumer_batch_size=2, - ) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - engine = RemotevLLMEngine(config) - engine.initialize() - - gconfig = GenerationHyperparameters( - max_new_tokens=16, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - result = engine.rollout_batch([data] * 2, workflow=workflow) - assert isinstance(result, dict) - bs = get_batch_size(result) - assert bs == 2 * n_samples - engine.destroy() - - -@pytest.mark.skipif( - not IS_VLLM_INSTALLED, reason="Skip the test because vllm is not installed." -) -@pytest.mark.parametrize("ofp", [1, 4, 16]) -@pytest.mark.parametrize("bs", [2, 4]) -@pytest.mark.parametrize("n_samples", [2, 1]) -@pytest.mark.slow -@pytest.mark.ci -def test_remote_vllm_staleness_control(vllm_server, bs, ofp, n_samples): - from areal.engine.vllm_remote import RemotevLLMEngine - from areal.workflow.rlvr import RLVRWorkflow - - config = InferenceEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - consumer_batch_size=bs, - max_head_offpolicyness=ofp, - ) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - engine = RemotevLLMEngine(config) - engine.initialize() - - gconfig = GenerationHyperparameters( - max_new_tokens=2, greedy=False, n_samples=n_samples - ) - tokenizer = load_hf_tokenizer(MODEL_PATH) - - workflow = RLVRWorkflow( - reward_fn=_dummy_reward_fn, - gconfig=gconfig, - tokenizer=tokenizer, - enable_thinking=False, - ) - data = { - "messages": [{"role": "user", "content": "Hello, how are you?"}], - } - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - if ofp < 1: - # Due to controlled offpolicyness, not all requests are committed - with pytest.raises(TimeoutError): - engine.wait(count=bs * 2, timeout=10) - else: - result = engine.wait(count=bs * 2, timeout=10) - assert result["attention_mask"].shape[0] == bs * 2 * n_samples - - # Update model version - engine.set_version(1) - print("Updated model version", flush=True) - - # submit again - for _ in range(bs * 2): - engine.submit(data, workflow=workflow) - - if ofp < 2: - # Due to controlled offpolicyness, not all requests are committed - with pytest.raises(TimeoutError): - engine.wait(count=bs * 4, timeout=5) - else: - # 2 * bs samples haved been retrived above - results = engine.wait(count=bs * 2, timeout=5) - assert results["attention_mask"].shape[0] == bs * 2 * n_samples - - # exit - engine.destroy() - - -@pytest.mark.skipif( - not IS_VLLM_INSTALLED, reason="Skip the test because vllm is not installed." -) -@pytest.mark.slow -@pytest.mark.ci -def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, vllm_server): - # setup FSDP engine - from areal.api.cli_args import OptimizerConfig, TrainEngineConfig - from areal.api.io_struct import FinetuneSpec - from areal.engine.fsdp_engine import FSDPEngine - - os.environ["WORLD_SIZE"] = "1" - os.environ["RANK"] = "0" - os.environ["LOCAL_RANK"] = "0" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "7777" - - engine_config = TrainEngineConfig( - experiment_name=EXPR_NAME, - trial_name=TRIAL_NAME, - path=MODEL_PATH, - optimizer=OptimizerConfig(), - ) - engine = FSDPEngine(engine_config) - engine.create_process_group() - ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) - engine.initialize(None, ft_spec) - engine.model_version = 100 - - # setup name resolve - import areal.utils.name_resolve as name_resolve - from areal.api.cli_args import NameResolveConfig - - nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") - name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) - name_resolve.reconfigure(name_resolve_config) - # initialize vLLM remote engine - from areal.api.cli_args import InferenceEngineConfig - from areal.engine.vllm_remote import RemotevLLMEngine - - config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" - inf_engine = RemotevLLMEngine(config) - inf_engine.initialize() - inf_engine.set_version(100) - # test update weights - path = tmp_path_factory.mktemp("areal_update_weights") - update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) - engine.connect_engine(inf_engine, update_weight_meta) - engine.set_version(100) - engine.update_weights(update_weight_meta) - inf_engine.destroy() diff --git a/notebook/math_reflection_en.ipynb b/notebook/math_reflection_en.ipynb index 34cfad1a4..2cb1fa669 100644 --- a/notebook/math_reflection_en.ipynb +++ b/notebook/math_reflection_en.ipynb @@ -169,7 +169,6 @@ ")\n", "sglang_process = subprocess.Popen(\n", " sglang_cmd,\n", - " shell=True,\n", " stdout=sys.stdout,\n", " stderr=sys.stderr,\n", ")" diff --git a/notebook/math_reflection_zh.ipynb b/notebook/math_reflection_zh.ipynb index 40af42915..0bbc9902f 100644 --- a/notebook/math_reflection_zh.ipynb +++ b/notebook/math_reflection_zh.ipynb @@ -160,7 +160,6 @@ ")\n", "sglang_process = subprocess.Popen(\n", " sglang_cmd,\n", - " shell=True,\n", " stdout=sys.stdout,\n", " stderr=sys.stderr,\n", ")" diff --git a/notebook/search_agent_zh.ipynb b/notebook/search_agent_zh.ipynb index b1ebf0628..41f63b15e 100644 --- a/notebook/search_agent_zh.ipynb +++ b/notebook/search_agent_zh.ipynb @@ -136,7 +136,6 @@ ")\n", "sglang_process = subprocess.Popen(\n", " sglang_cmd,\n", - " shell=True,\n", " stdout=sys.stdout,\n", " stderr=sys.stderr,\n", ")\n", From a822cb28cde784c1ac39d5a21816de40f47864b7 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Fri, 31 Oct 2025 23:35:01 +0800 Subject: [PATCH 38/52] fix docstring --- areal/launcher/sglang_server.py | 2 +- areal/launcher/vllm_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/areal/launcher/sglang_server.py b/areal/launcher/sglang_server.py index 6c04e30c9..efe9b2805 100644 --- a/areal/launcher/sglang_server.py +++ b/areal/launcher/sglang_server.py @@ -69,7 +69,7 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N def launch_server_cmd(command: list[str]) -> subprocess.Popen: """ - Execute a shell command and return its process handle. + Launch inference server in a new process and return its process handle. """ # Replace newline continuations and split the command string. logger.info(f"Launch command: {' '.join(command)}") diff --git a/areal/launcher/vllm_server.py b/areal/launcher/vllm_server.py index 70a4b4e55..1d8b12be1 100644 --- a/areal/launcher/vllm_server.py +++ b/areal/launcher/vllm_server.py @@ -28,7 +28,7 @@ def launch_server_cmd( command: list[str], custom_env: dict | None = None ) -> subprocess.Popen: """ - Execute a shell command and return its process handle. + Launch inference server in a new process and return its process handle. """ # Replace newline continuations and split the command string. logger.info(f"Launch command: {' '.join(command)}") From 6e6288413aa70d28a302be13a345ce47c6a59b1d Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Sat, 1 Nov 2025 14:39:08 +0800 Subject: [PATCH 39/52] add test --- areal/tests/test_inference_engines.py | 298 ++++++++++++++++++++++++++ 1 file changed, 298 insertions(+) create mode 100644 areal/tests/test_inference_engines.py diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py new file mode 100644 index 000000000..072a6b07c --- /dev/null +++ b/areal/tests/test_inference_engines.py @@ -0,0 +1,298 @@ +"""Test suite for remote inference engines (vLLM and SGLang).""" + +import os +import subprocess +import sys +import time + +import pytest +import requests + +from areal.api.cli_args import ( + GenerationHyperparameters, + InferenceEngineConfig, + SGLangConfig, + vLLMConfig, +) +from areal.api.io_struct import WeightUpdateMeta +from areal.utils import network +from areal.utils.data import get_batch_size +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.pkg_version import is_available + +MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-0.6B/" +if not os.path.exists(MODEL_PATH): + MODEL_PATH = "Qwen/Qwen3-0.6B" + +# set a large timeout since we may need to download the model from hub +RUN_SERVER_TIMEOUT = 180 + +IS_VLLM_INSTALLED = is_available("vllm") + + +def check_server_health(base_url): + """Check if the server is healthy and ready to accept requests.""" + try: + response = requests.get(f"{base_url}/health", timeout=30) + return response.status_code == 200 + except requests.exceptions.RequestException: + return False + + +def _dummy_reward_fn(*args, **kwargs): + """Dummy reward function for testing.""" + return 1.0 + + +@pytest.fixture(params=["vllm", "sglang"], scope="module") +def inference_engine(request): + """Fixture for remote inference engines only (vLLM and SGLang).""" + backend = request.param + + # Skip if vLLM is not installed + if backend == "vllm" and not IS_VLLM_INSTALLED: + pytest.skip("vLLM is not installed") + + from areal.utils import seeding + + expr_name = f"test_remote_{backend}_engine" + trial_name = "trial_0" + + seeding.set_random_seed(1, expr_name) + + port, dist_port = network.find_free_ports(2) + host = network.gethostip() + + # Configure SGLang + sglang_config = SGLangConfig( + skip_tokenizer_init=True, + model_path=MODEL_PATH, + mem_fraction_static=0.1, + ) + sglang_args = SGLangConfig.build_args( + sglang_config=sglang_config, + tp_size=1, + base_gpu_id=0, + host=host, + port=port, + dist_init_addr=f"{host}:{dist_port}", + ) + + # Configure vLLM + vllm_config = vLLMConfig( + skip_tokenizer_init=False, + model=MODEL_PATH, + gpu_memory_utilization=0.1, + ) + vllm_args = vLLMConfig.build_args( + vllm_config=vllm_config, + tp_size=1, + host=host, + port=port, + ) + + # Launch remote server and initialize engine + if backend == "vllm": + from areal.engine.vllm_remote import RemotevLLMEngine + + cmd = vLLMConfig.build_cmd_from_args(vllm_args) + engine_class = RemotevLLMEngine + else: # sglang + from areal.engine.sglang_remote import RemoteSGLangEngine + + cmd = SGLangConfig.build_cmd_from_args(sglang_args) + engine_class = RemoteSGLangEngine + + # Launch process + process = subprocess.Popen( + cmd, + stdout=sys.stdout, + stderr=sys.stdout, + ) + base_url = f"http://{host}:{port}" + tik = time.time() + while time.time() - tik < RUN_SERVER_TIMEOUT: + if check_server_health(base_url): + break + time.sleep(1) + if time.time() - tik > RUN_SERVER_TIMEOUT: + process.terminate() + raise RuntimeError(f"{backend.upper()} server launch failed") + + # Set environment for remote engine + os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host}:{port}" + + yield { + "engine_class": engine_class, + "expr_name": expr_name, + "trial_name": trial_name, + "host": host, + "port": port, + } + + # Cleanup + process.terminate() + + +# ============================================================================ +# Unified Tests +# ============================================================================ + + +@pytest.mark.parametrize("n_samples", [1, 2, 4]) +@pytest.mark.slow +@pytest.mark.ci +def test_rollout(inference_engine, n_samples): + """Test engine rollout with different sample sizes.""" + from areal.workflow.rlvr import RLVRWorkflow + + config = InferenceEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + max_concurrent_rollouts=2, + consumer_batch_size=2, + enable_rollout_tracing=True, + ) + + engine = inference_engine["engine_class"](config) + engine.initialize() + + gconfig = GenerationHyperparameters( + max_new_tokens=16, greedy=False, n_samples=n_samples + ) + tokenizer = load_hf_tokenizer(MODEL_PATH) + + workflow = RLVRWorkflow( + reward_fn=_dummy_reward_fn, + gconfig=gconfig, + tokenizer=tokenizer, + enable_thinking=False, + ) + + data = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + result = engine.rollout_batch([data] * 2, workflow=workflow) + assert isinstance(result, dict) + bs = get_batch_size(result) + assert bs == 2 * n_samples + engine.destroy() + + +@pytest.mark.parametrize("ofp", [0, 1, 4, 16]) +@pytest.mark.parametrize("bs", [2, 4]) +@pytest.mark.parametrize("n_samples", [2, 1]) +@pytest.mark.slow +@pytest.mark.ci +def test_staleness_control(inference_engine, bs, ofp, n_samples): + """Test engine staleness control mechanism.""" + from areal.workflow.rlvr import RLVRWorkflow + + config = InferenceEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + consumer_batch_size=bs, + max_head_offpolicyness=ofp, + enable_rollout_tracing=True, + ) + + engine = inference_engine["engine_class"](config) + engine.initialize() + + gconfig = GenerationHyperparameters( + max_new_tokens=2, greedy=False, n_samples=n_samples + ) + tokenizer = load_hf_tokenizer(MODEL_PATH) + + workflow = RLVRWorkflow( + reward_fn=_dummy_reward_fn, + gconfig=gconfig, + tokenizer=tokenizer, + enable_thinking=False, + ) + data = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + for _ in range(bs * 2): + engine.submit(data, workflow=workflow) + + if ofp < 1: + # Due to controlled offpolicyness, not all requests are committed + with pytest.raises(TimeoutError): + engine.wait(count=bs * 2, timeout=10) + else: + result = engine.wait(count=bs * 2, timeout=10) + assert result["attention_mask"].shape[0] == bs * 2 * n_samples + + # Update model version + engine.set_version(1) + print("Updated model version", flush=True) + + # submit again + for _ in range(bs * 2): + engine.submit(data, workflow=workflow) + + if ofp < 2: + # Due to controlled offpolicyness, not all requests are committed + with pytest.raises(TimeoutError): + engine.wait(count=bs * 4, timeout=5) + else: + # 2 * bs samples haved been retrived above + results = engine.wait(count=bs * 2, timeout=5) + assert results["attention_mask"].shape[0] == bs * 2 * n_samples + + engine.destroy() + + +@pytest.mark.slow +@pytest.mark.ci +def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, inference_engine): + """Test disk-based weight updates from FSDP engine to inference engine.""" + + # setup FSDP engine + from areal.api.cli_args import OptimizerConfig, TrainEngineConfig + from areal.api.io_struct import FinetuneSpec + from areal.engine.fsdp_engine import FSDPEngine + + os.environ["WORLD_SIZE"] = "1" + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "7777" + + engine_config = TrainEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + path=MODEL_PATH, + optimizer=OptimizerConfig(), + ) + train_engine = FSDPEngine(engine_config) + train_engine.create_process_group() + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) + train_engine.initialize(None, ft_spec) + train_engine.model_version = 100 + + # setup name resolve + import areal.utils.name_resolve as name_resolve + from areal.api.cli_args import NameResolveConfig + + nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") + name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) + name_resolve.reconfigure(name_resolve_config) + + config = InferenceEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + ) + # initialize inference engine + inf_engine = inference_engine["engine_class"](config) + inf_engine.initialize() + inf_engine.set_version(100) + + # test update weights + path = tmp_path_factory.mktemp("update_weights_from_disk") + update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) + train_engine.connect_engine(inf_engine, update_weight_meta) + train_engine.set_version(100) + train_engine.update_weights(update_weight_meta) + inf_engine.destroy() From 98d2c8d2bfcb7faf8bbb5a6e27f6decec6ea08b9 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Fri, 31 Oct 2025 23:21:14 +0800 Subject: [PATCH 40/52] fix format --- areal/launcher/sglang_server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/areal/launcher/sglang_server.py b/areal/launcher/sglang_server.py index efe9b2805..e60c71659 100644 --- a/areal/launcher/sglang_server.py +++ b/areal/launcher/sglang_server.py @@ -133,9 +133,10 @@ def run(self): gpus_per_server = self.allocation_mode.gen_instance_size cross_nodes = False if gpus_per_server > self.n_gpus_per_node: - assert gpus_per_server % self.n_gpus_per_node == 0, ( - "Cross-nodes SGLang only supports utilizing all gpus in one node" - ) + if gpus_per_server % self.n_gpus_per_node != 0: + raise ValueError( + "Cross-nodes SGLang only supports utilizing all gpus in one node" + ) cross_nodes = True node_rank = int(os.environ["AREAL_SGLANG_MULTI_NODE_RANK"]) master_addr = os.environ["AREAL_SGLANG_MULTI_NODE_MASTER_ADDR"] From 12cc12e2fe08913e19f3fa6d90f6928bc82ed2c3 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Sun, 2 Nov 2025 22:11:21 +0800 Subject: [PATCH 41/52] shorter ctx len for test --- areal/tests/test_inference_engines.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index 072a6b07c..724258c2c 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -67,7 +67,8 @@ def inference_engine(request): sglang_config = SGLangConfig( skip_tokenizer_init=True, model_path=MODEL_PATH, - mem_fraction_static=0.1, + mem_fraction_static=0.2, + context_length=128, ) sglang_args = SGLangConfig.build_args( sglang_config=sglang_config, @@ -82,7 +83,8 @@ def inference_engine(request): vllm_config = vLLMConfig( skip_tokenizer_init=False, model=MODEL_PATH, - gpu_memory_utilization=0.1, + gpu_memory_utilization=0.2, + max_model_len=128, ) vllm_args = vLLMConfig.build_args( vllm_config=vllm_config, From 3ba98e6af89c2dad4f20f5d8c2a33d52eceec130 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Sun, 2 Nov 2025 22:14:14 +0800 Subject: [PATCH 42/52] add adv norm in grpo test --- areal/tests/grpo/config.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/areal/tests/grpo/config.yaml b/areal/tests/grpo/config.yaml index 374975fa1..46c35a740 100644 --- a/areal/tests/grpo/config.yaml +++ b/areal/tests/grpo/config.yaml @@ -60,6 +60,9 @@ actor: gradient_clipping: 1.0 warmup_steps_proportion: 0.001 group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch ref: experiment_name: ${experiment_name} From 204b1fd134e90007655554a4449cf14f4b6229ea Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Mon, 3 Nov 2025 10:44:56 +0800 Subject: [PATCH 43/52] update test to use local path --- areal/engine/megatron_engine.py | 3 --- areal/platforms/cuda.py | 2 -- areal/tests/test_estimate_num_params.py | 7 +++++++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 66402aec1..393b2f118 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -192,9 +192,6 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None # TODO: Change engine_api.py and FSDPEngine API to seperate create_process_group # from engine initialize when moving out of experimental. self.parallel_strategy = self._make_parallel_strategy(parallel_strategy) - # Required by NCCL weight update group for SGLang - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher # NOTE: device_id **SHOULD NOT** be passed into init_process_group, # otherwise initializing the NCCL weight update group will be wrong! diff --git a/areal/platforms/cuda.py b/areal/platforms/cuda.py index 953ecac2c..190d5f0f3 100644 --- a/areal/platforms/cuda.py +++ b/areal/platforms/cuda.py @@ -55,8 +55,6 @@ def get_custom_env_vars(cls) -> dict: # "RAY_DEBUG": "legacy" "TORCHINDUCTOR_COMPILE_THREADS": "2", "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", - "NCCL_CUMEM_ENABLE": "0", # https://github.com/NVIDIA/nccl/issues/1234 - "NCCL_NVLS_ENABLE": "0", } return env_vars diff --git a/areal/tests/test_estimate_num_params.py b/areal/tests/test_estimate_num_params.py index 2f2f48dac..db6a5cc81 100644 --- a/areal/tests/test_estimate_num_params.py +++ b/areal/tests/test_estimate_num_params.py @@ -31,6 +31,13 @@ def test_estimate_num_params(model_name_or_path): mpu.initialize_model_parallel() tensor_parallel.model_parallel_cuda_manual_seed(0) + # use local model if possible + local_path = os.path.join( + "/storage/openpsi/models", model_name_or_path.replace("/", "__") + ) + if os.path.exists(local_path): + model_name_or_path = local_path + bridge = mbridge.AutoBridge.from_pretrained(model_name_or_path) hf_config, tf_config = make_hf_and_mcore_config( model_name_or_path, dtype=torch.bfloat16, bridge=bridge From 95a08acf66b8930e48e610c3feba3007c4ca4050 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Mon, 3 Nov 2025 13:04:36 +0800 Subject: [PATCH 44/52] resource cleanup in tests --- areal/tests/test_examples.py | 62 +++++++++++--------- areal/tests/test_fsdp_engine_nccl.py | 81 +++++++++++++++++---------- areal/tests/test_inference_engines.py | 51 ++++++++++------- areal/tests/test_megatron_engine.py | 6 +- 4 files changed, 120 insertions(+), 80 deletions(-) diff --git a/areal/tests/test_examples.py b/areal/tests/test_examples.py index c726e3251..4323e811b 100644 --- a/areal/tests/test_examples.py +++ b/areal/tests/test_examples.py @@ -728,35 +728,43 @@ def test_search_agent_deepresearch(tmp_path_factory): stderr=sys.stderr, env=_env, ) - time.sleep(20) - loop = asyncio.get_event_loop() - return_code, success = loop.run_until_complete( - run_example( - example_file, - config_name, - "allocation_mode=sglang:d1+megatron:d1", - "gconfig.n_samples=1", - "gconfig.max_new_tokens=128", - "actor.mb_spec.max_tokens_per_mb=2048", - "train_dataset.batch_size=4", - f"train_dataset.path={dataset_path}", - f"cluster.fileroot={str(experiments_path)}", - f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", - f"actor.path={model_path}", - "n_trajs=2", - "max_tokens_per_trajectory=1024", - "max_llm_calls_per_run=2", - f"judge_engine.experiment_name={llm_judge_exp_name}", - f"judge_engine.trial_name={llm_judge_trial_name}", - ) - ) - if not success: - raise RuntimeError( - f"Search Agent DeepResearch example failed, return_code={return_code}" + try: + time.sleep(20) + + loop = asyncio.get_event_loop() + return_code, success = loop.run_until_complete( + run_example( + example_file, + config_name, + "allocation_mode=sglang:d1+megatron:d1", + "gconfig.n_samples=1", + "gconfig.max_new_tokens=128", + "actor.mb_spec.max_tokens_per_mb=2048", + "train_dataset.batch_size=4", + f"train_dataset.path={dataset_path}", + f"cluster.fileroot={str(experiments_path)}", + f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", + f"actor.path={model_path}", + "n_trajs=2", + "max_tokens_per_trajectory=1024", + "max_llm_calls_per_run=2", + f"judge_engine.experiment_name={llm_judge_exp_name}", + f"judge_engine.trial_name={llm_judge_trial_name}", + ) ) - llm_judge_proc.terminate() - llm_judge_proc.wait(5) + if not success: + raise RuntimeError( + f"Search Agent DeepResearch example failed, return_code={return_code}" + ) + finally: + # Ensure cleanup happens even if test fails + llm_judge_proc.terminate() + try: + llm_judge_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + llm_judge_proc.kill() + llm_judge_proc.wait() @pytest.mark.multi_gpu diff --git a/areal/tests/test_fsdp_engine_nccl.py b/areal/tests/test_fsdp_engine_nccl.py index df36d0230..78125f90f 100644 --- a/areal/tests/test_fsdp_engine_nccl.py +++ b/areal/tests/test_fsdp_engine_nccl.py @@ -67,15 +67,27 @@ def sglang_server_nccl(): ) base_url = f"http://{HOST}:{PORT}" tik = time.time() - while time.time() - tik < RUN_SERVER_TIMEOUT: - if check_server_health(base_url): - break - time.sleep(1) - if time.time() - tik > RUN_SERVER_TIMEOUT: + try: + while time.time() - tik < RUN_SERVER_TIMEOUT: + if check_server_health(base_url): + break + time.sleep(1) + if time.time() - tik > RUN_SERVER_TIMEOUT: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + raise RuntimeError("server launch failed") + yield + finally: process.terminate() - raise RuntimeError("server launch failed") - yield - process.terminate() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + process.kill() + process.wait() def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl): @@ -97,26 +109,33 @@ def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server optimizer=OptimizerConfig(), ) engine = FSDPEngine(engine_config) - engine.create_process_group() - ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) - engine.initialize(None, ft_spec) - - # Initialize RemoteSGLangEngine - config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) - config.server_addrs = [f"{HOST}:{PORT}"] - remote_engine = RemoteSGLangEngine(config) - remote_engine.initialize() - - # Get WeightUpdateMeta - meta = WeightUpdateMeta.from_fsdp_xccl( - AllocationMode.from_str("sglang.d1p1t1+d1p1t1"), - nccl_group_name=GROUP_NAME, - ) - - engine.connect_engine(remote_engine, meta) - - # Broadcast weights - engine.update_weights(meta) - print("uploaded weights to remote engine", flush=True) - remote_engine.destroy() - engine.destroy() + remote_engine = None + try: + engine.create_process_group() + ft_spec = FinetuneSpec( + total_train_epochs=1, dataset_size=100, train_batch_size=2 + ) + engine.initialize(None, ft_spec) + + # Initialize RemoteSGLangEngine + config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) + config.server_addrs = [f"{HOST}:{PORT}"] + remote_engine = RemoteSGLangEngine(config) + remote_engine.initialize() + + # Get WeightUpdateMeta + meta = WeightUpdateMeta.from_fsdp_xccl( + AllocationMode.from_str("sglang.d1p1t1+d1p1t1"), + nccl_group_name=GROUP_NAME, + ) + + engine.connect_engine(remote_engine, meta) + + # Broadcast weights + engine.update_weights(meta) + print("uploaded weights to remote engine", flush=True) + finally: + # Cleanup in reverse order + if remote_engine is not None: + remote_engine.destroy() + engine.destroy() diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index 724258c2c..b9082c48c 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -113,27 +113,38 @@ def inference_engine(request): ) base_url = f"http://{host}:{port}" tik = time.time() - while time.time() - tik < RUN_SERVER_TIMEOUT: - if check_server_health(base_url): - break - time.sleep(1) - if time.time() - tik > RUN_SERVER_TIMEOUT: + try: + while time.time() - tik < RUN_SERVER_TIMEOUT: + if check_server_health(base_url): + break + time.sleep(1) + if time.time() - tik > RUN_SERVER_TIMEOUT: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + raise RuntimeError(f"{backend.upper()} server launch failed") + + # Set environment for remote engine + os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host}:{port}" + + yield { + "engine_class": engine_class, + "expr_name": expr_name, + "trial_name": trial_name, + "host": host, + "port": port, + } + finally: + # Cleanup - ensure process is fully terminated process.terminate() - raise RuntimeError(f"{backend.upper()} server launch failed") - - # Set environment for remote engine - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host}:{port}" - - yield { - "engine_class": engine_class, - "expr_name": expr_name, - "trial_name": trial_name, - "host": host, - "port": port, - } - - # Cleanup - process.terminate() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + process.kill() + process.wait() # ============================================================================ diff --git a/areal/tests/test_megatron_engine.py b/areal/tests/test_megatron_engine.py index 8014c4757..4110ded60 100644 --- a/areal/tests/test_megatron_engine.py +++ b/areal/tests/test_megatron_engine.py @@ -89,8 +89,10 @@ def engine(): engine.initialize(addr=None, ft_spec=ft_spec, parallel_strategy=alloc_mode.train) logger.info(f"mcore GPTModel initialized: {engine.model}") log_gpu_stats("initialize") - yield engine - engine.destroy() + try: + yield engine + finally: + engine.destroy() def test_simple_forward(engine, mock_input): From d0dfad72e615f5e425743861ed8c94cb2f9e8d6c Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Mon, 3 Nov 2025 13:06:18 +0800 Subject: [PATCH 45/52] fix vllm pp --- areal/tests/test_inference_engines.py | 1 + 1 file changed, 1 insertion(+) diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index b9082c48c..caaea0ce4 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -89,6 +89,7 @@ def inference_engine(request): vllm_args = vLLMConfig.build_args( vllm_config=vllm_config, tp_size=1, + pp_size=1, host=host, port=port, ) From 67499161383f18b28b745ce9bd9ac6b8ee1ee2b2 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Mon, 3 Nov 2025 16:24:02 +0800 Subject: [PATCH 46/52] fix --- areal/engine/megatron_engine.py | 1 - areal/tests/test_megatron_engine.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 393b2f118..a43de5d1c 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -189,7 +189,6 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None if parallel_strategy is None: parallel_strategy = ParallelStrategy() assert not dist.is_initialized() - # TODO: Change engine_api.py and FSDPEngine API to seperate create_process_group # from engine initialize when moving out of experimental. self.parallel_strategy = self._make_parallel_strategy(parallel_strategy) # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher diff --git a/areal/tests/test_megatron_engine.py b/areal/tests/test_megatron_engine.py index 4110ded60..75eeb736d 100644 --- a/areal/tests/test_megatron_engine.py +++ b/areal/tests/test_megatron_engine.py @@ -63,7 +63,8 @@ def mock_loss_fn(logits: torch.Tensor, input_data: dict) -> torch.Tensor: return torch.mean(logits) -@pytest.fixture(scope="module") +# Cannot use a "module" scope since process groups can only be initialized once. +@pytest.fixture def engine(): logger.info(f"megatron.core version={get_version('megatron.core')}") os.environ.update( @@ -93,6 +94,7 @@ def engine(): yield engine finally: engine.destroy() + engine.destroy_process_groups() def test_simple_forward(engine, mock_input): From 52921f2bd5c19c8b9f6c458f8a63bef3a4a40f30 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Tue, 4 Nov 2025 12:27:06 +0800 Subject: [PATCH 47/52] . --- areal/tests/test_inference_engines.py | 65 +++++++++++-------- .../tests/test_megatron_engine_distributed.py | 2 + areal/tests/test_train_engine.py | 7 +- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index caaea0ce4..8695bc772 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -282,31 +282,40 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, inference_engine ) train_engine = FSDPEngine(engine_config) train_engine.create_process_group() - ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) - train_engine.initialize(None, ft_spec) - train_engine.model_version = 100 - - # setup name resolve - import areal.utils.name_resolve as name_resolve - from areal.api.cli_args import NameResolveConfig - - nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") - name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root) - name_resolve.reconfigure(name_resolve_config) - - config = InferenceEngineConfig( - experiment_name=inference_engine["expr_name"], - trial_name=inference_engine["trial_name"], - ) - # initialize inference engine - inf_engine = inference_engine["engine_class"](config) - inf_engine.initialize() - inf_engine.set_version(100) - - # test update weights - path = tmp_path_factory.mktemp("update_weights_from_disk") - update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) - train_engine.connect_engine(inf_engine, update_weight_meta) - train_engine.set_version(100) - train_engine.update_weights(update_weight_meta) - inf_engine.destroy() + inf_engine = None + try: + ft_spec = FinetuneSpec( + total_train_epochs=1, dataset_size=100, train_batch_size=2 + ) + train_engine.initialize(None, ft_spec) + train_engine.model_version = 100 + + # setup name resolve + import areal.utils.name_resolve as name_resolve + from areal.api.cli_args import NameResolveConfig + + nfs_record_root = tmp_path_factory.mktemp("nfs_record_path") + name_resolve_config = NameResolveConfig( + type="nfs", nfs_record_root=nfs_record_root + ) + name_resolve.reconfigure(name_resolve_config) + + config = InferenceEngineConfig( + experiment_name=inference_engine["expr_name"], + trial_name=inference_engine["trial_name"], + ) + # initialize inference engine + inf_engine = inference_engine["engine_class"](config) + inf_engine.initialize() + inf_engine.set_version(100) + + # test update weights + path = tmp_path_factory.mktemp("update_weights_from_disk") + update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) + train_engine.connect_engine(inf_engine, update_weight_meta) + train_engine.set_version(100) + train_engine.update_weights(update_weight_meta) + finally: + train_engine.destroy() + if inf_engine is not None: + inf_engine.destroy() diff --git a/areal/tests/test_megatron_engine_distributed.py b/areal/tests/test_megatron_engine_distributed.py index 3e739f626..0e60df0ee 100644 --- a/areal/tests/test_megatron_engine_distributed.py +++ b/areal/tests/test_megatron_engine_distributed.py @@ -77,6 +77,7 @@ def test_qwen3moe_expert_parallel(tmp_path_factory): @pytest.mark.multi_gpu +@pytest.mark.slow def test_qwen3_dcp_save_load(tmp_path_factory): if current_platform.device_count() < 8: pytest.skip("DCP save load requires 8 GPUs to run") @@ -90,6 +91,7 @@ def test_qwen3_dcp_save_load(tmp_path_factory): @pytest.mark.multi_gpu +@pytest.mark.slow def test_qwen3moe_dcp_save_load(tmp_path_factory): if current_platform.device_count() < 8: pytest.skip("Qwen3 MoE DCP save load requires 8 GPUs to run") diff --git a/areal/tests/test_train_engine.py b/areal/tests/test_train_engine.py index a8606bd16..fceefb80d 100644 --- a/areal/tests/test_train_engine.py +++ b/areal/tests/test_train_engine.py @@ -72,7 +72,7 @@ def mock_loss_fn(logits: torch.Tensor, input_data: dict) -> torch.Tensor: return torch.mean(logits) -@pytest.fixture(scope="module", params=["fsdp"]) +@pytest.fixture(params=["fsdp"]) def engine(request): os.environ.update( { @@ -86,7 +86,10 @@ def engine(request): engine = get_engine(request.param, MODEL_PATH) print(f"✓ {request.param.upper()} Engine created successfully") - yield engine + try: + yield engine + finally: + engine.destroy() @torch.no_grad() From 9258e2e691b61c09c9eddf8e212428b5bf756f76 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Tue, 4 Nov 2025 12:28:12 +0800 Subject: [PATCH 48/52] . --- areal/tests/test_estimate_num_params.py | 1 + areal/tests/test_fsdp_engine_nccl.py | 2 ++ areal/tests/test_inference_engines.py | 3 +++ 3 files changed, 6 insertions(+) diff --git a/areal/tests/test_estimate_num_params.py b/areal/tests/test_estimate_num_params.py index db6a5cc81..fc8cfddbf 100644 --- a/areal/tests/test_estimate_num_params.py +++ b/areal/tests/test_estimate_num_params.py @@ -57,3 +57,4 @@ def test_estimate_num_params(model_name_or_path): finally: mpu.destroy_model_parallel() dist.destroy_process_group() + assert not dist.is_initialized() diff --git a/areal/tests/test_fsdp_engine_nccl.py b/areal/tests/test_fsdp_engine_nccl.py index 78125f90f..314868a57 100644 --- a/areal/tests/test_fsdp_engine_nccl.py +++ b/areal/tests/test_fsdp_engine_nccl.py @@ -5,6 +5,7 @@ import pytest import requests +import torch.distributed as dist from areal.api.alloc_mode import AllocationMode from areal.api.cli_args import ( @@ -139,3 +140,4 @@ def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server if remote_engine is not None: remote_engine.destroy() engine.destroy() + assert not dist.is_initialized() diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index caaea0ce4..3a8fb7982 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -7,6 +7,7 @@ import pytest import requests +import torch.distributed as dist from areal.api.cli_args import ( GenerationHyperparameters, @@ -191,6 +192,7 @@ def test_rollout(inference_engine, n_samples): bs = get_batch_size(result) assert bs == 2 * n_samples engine.destroy() + assert not dist.is_initialized() @pytest.mark.parametrize("ofp", [0, 1, 4, 16]) @@ -256,6 +258,7 @@ def test_staleness_control(inference_engine, bs, ofp, n_samples): assert results["attention_mask"].shape[0] == bs * 2 * n_samples engine.destroy() + assert not dist.is_initialized() @pytest.mark.slow From 4443c9b0c6b34cfba6e16e0e86be92cd33158740 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Tue, 4 Nov 2025 12:29:09 +0800 Subject: [PATCH 49/52] . --- areal/tests/test_inference_engines.py | 1 + 1 file changed, 1 insertion(+) diff --git a/areal/tests/test_inference_engines.py b/areal/tests/test_inference_engines.py index a86ef4a3a..829943330 100644 --- a/areal/tests/test_inference_engines.py +++ b/areal/tests/test_inference_engines.py @@ -322,3 +322,4 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, inference_engine train_engine.destroy() if inf_engine is not None: inf_engine.destroy() + assert not dist.is_initialized() From ac6a11af6d393d2b050c1927488fb633ed0b57e6 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Tue, 4 Nov 2025 12:33:31 +0800 Subject: [PATCH 50/52] add assertion --- areal/tests/test_megatron_engine.py | 3 ++- areal/tests/test_packed_vs_padded_consistency.py | 3 +++ areal/tests/test_train_engine.py | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/areal/tests/test_megatron_engine.py b/areal/tests/test_megatron_engine.py index 75eeb736d..8f861cc91 100644 --- a/areal/tests/test_megatron_engine.py +++ b/areal/tests/test_megatron_engine.py @@ -5,6 +5,7 @@ import pytest import torch +import torch.distributed as dist from transformers import AutoTokenizer from areal.api.alloc_mode import AllocationMode @@ -94,7 +95,7 @@ def engine(): yield engine finally: engine.destroy() - engine.destroy_process_groups() + assert not dist.is_initialized() def test_simple_forward(engine, mock_input): diff --git a/areal/tests/test_packed_vs_padded_consistency.py b/areal/tests/test_packed_vs_padded_consistency.py index 6427fd318..b6cb1ad89 100644 --- a/areal/tests/test_packed_vs_padded_consistency.py +++ b/areal/tests/test_packed_vs_padded_consistency.py @@ -3,6 +3,7 @@ import pytest import torch +import torch.distributed as dist from torch.testing import assert_close from areal.api.cli_args import TrainEngineConfig @@ -103,6 +104,7 @@ def test_llm_consistency(model_path, mock_padded_llm_data): assert_close(x1, x2, atol=2e-1, rtol=2e-1) finally: engine.destroy() + assert not dist.is_initialized() QWEN25_VL_PATH = "/storage/openpsi/models/Qwen2.5-VL-3B-Instruct" @@ -272,3 +274,4 @@ def test_vlm_consistency(model_path): assert_close(x1, x2, atol=2e-1, rtol=2e-1) finally: engine.destroy() + assert not dist.is_initialized() diff --git a/areal/tests/test_train_engine.py b/areal/tests/test_train_engine.py index fceefb80d..1a5032782 100644 --- a/areal/tests/test_train_engine.py +++ b/areal/tests/test_train_engine.py @@ -5,6 +5,7 @@ import pytest import torch +import torch.distributed as dist from transformers import AutoTokenizer from areal.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig @@ -90,6 +91,7 @@ def engine(request): yield engine finally: engine.destroy() + assert not dist.is_initialized() @torch.no_grad() From 54fce97930a268c831e59af25b0224f3f5e83d8b Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Tue, 4 Nov 2025 14:07:13 +0800 Subject: [PATCH 51/52] revert and fix --- areal/api/cli_args.py | 2 +- areal/api/engine_api.py | 10 ++-------- areal/controller/__init__.py | 2 ++ areal/core/remote_inf_engine.py | 33 ++++++++++++++++++++++++--------- areal/core/workflow_executor.py | 3 --- areal/engine/sglang_remote.py | 29 ++++++++++++----------------- areal/engine/vllm_remote.py | 22 +++++++++++++--------- docs/cli_reference.md | 2 +- 8 files changed, 55 insertions(+), 48 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index affdfbd93..c4311148f 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -871,7 +871,7 @@ class InferenceEngineConfig: ) queue_size: None | int = field( default=None, - metadata={"help": "(Deprecated) Input/Output queue size for async rollout."}, + metadata={"help": "Input/Output queue size for async rollout."}, ) consumer_batch_size: int = field( default=1, diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 82575906e..c9cde9316 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -21,9 +21,6 @@ class TrainEngine(abc.ABC): - def configure(self, config): - raise NotImplementedError() - def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): """Initialize PyTorch distributed communication groups. @@ -321,13 +318,10 @@ def forward( class InferenceEngine(abc.ABC): - def configure(self, config): - raise NotImplementedError() - - def create_engine(self, engine_args: dict[str, Any]): + def launch_server(self, server_args: dict[str, Any]) -> str: raise NotImplementedError() - def destroy_engine(self): + def teardown_server(self): raise NotImplementedError() def initialize(self, *args, **kwargs): diff --git a/areal/controller/__init__.py b/areal/controller/__init__.py index d6905531d..aae841297 100644 --- a/areal/controller/__init__.py +++ b/areal/controller/__init__.py @@ -2,8 +2,10 @@ from areal.controller.batch import DistributedBatchMemory from areal.controller.rollout_controller import RolloutController +from areal.controller.train_controller import TrainController __all__ = [ "DistributedBatchMemory", "RolloutController", + "TrainController", ] diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py index 6d66186c8..0340172bf 100644 --- a/areal/core/remote_inf_engine.py +++ b/areal/core/remote_inf_engine.py @@ -2,6 +2,7 @@ import os import random import shutil +import subprocess import time import uuid from collections.abc import Callable @@ -31,6 +32,7 @@ from areal.utils import logging, name_resolve, names from areal.utils.http import arequest_with_retry, get_default_connector from areal.utils.launcher import wait_llm_server_addrs +from areal.utils.network import find_free_ports, gethostip from .workflow_executor import WorkflowExecutor @@ -51,6 +53,8 @@ class RemoteInfBackendProtocol(Protocol): Implementations can raise NotImplementedError for unsupported features. """ + def launch_server(self, server_args) -> subprocess.Popen: ... + def build_generation_request( self, req: ModelRequest, with_lora: bool ) -> HttpRequest: @@ -229,15 +233,25 @@ def __init__( self.workflow_executor: WorkflowExecutor - def configure(self, config): - self.config = config - - def create_engine(self, engine_args): - # remote inference engine does not need to create an engine - return - - def destroy_engine(self): - return + self.server_process: subprocess.Popen | None = None + + def launch_server(self, server_args): + server_args["host"] = host_ip = gethostip() + server_args["port"] = server_port = find_free_ports(1)[0] + self.server_process = self.backend.launch_server(server_args=server_args) + address = f"{host_ip}:{server_port}" + self._wait_for_server(address) + return address + + def teardown_server(self): + if self.server_process is not None: + self.server_process.terminate() + try: + self.server_process.wait(timeout=5) + except TimeoutError: + self.server_process.kill() + self.server_process.wait() + self.server_process = None def _wait_for_server(self, address): """Wait for a server to become healthy.""" @@ -339,6 +353,7 @@ def destroy(self): if getattr(self, "executor"): self.executor.shutdown() self.executor = None + self.teardown_server() def set_version(self, version): """Set the current weight version.""" diff --git a/areal/core/workflow_executor.py b/areal/core/workflow_executor.py index 184eec7a2..eb668c15f 100644 --- a/areal/core/workflow_executor.py +++ b/areal/core/workflow_executor.py @@ -223,9 +223,6 @@ class _RolloutResult: request_id: int | None = None -TASK_RUNNER_MAX_QSIZE = 4096 - - class WorkflowExecutor: """Executor for asynchronous workflow-based rollout generation. diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 37d973176..ad998e0b4 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -1,4 +1,4 @@ -import os +import subprocess from collections.abc import Callable from concurrent.futures import Future from typing import Any, Optional @@ -18,14 +18,19 @@ ) from areal.api.workflow_api import RolloutWorkflow from areal.core import RemoteInfEngine -from areal.launcher.sglang_server import launch_server_cmd, wait_for_server +from areal.launcher.sglang_server import launch_server_cmd from areal.platforms import current_platform -from areal.utils.network import find_free_ports, gethostip class SGLangBackend: """SGLang-specific backend implementation for remote inference.""" + def launch_server(self, server_args: dict[str, Any]) -> subprocess.Popen: + # FIXME: avoid circular import + + cmd = SGLangConfig.build_cmd_from_args(server_args) + return launch_server_cmd(cmd) + def build_generation_request( self, req: ModelRequest, with_lora: bool ) -> HttpRequest: @@ -191,21 +196,11 @@ def __init__(self, config: InferenceEngineConfig): # Pure composition - create internal engine with SGLang backend self._engine = RemoteInfEngine(config, SGLangBackend()) - def create_engine(self, engine_args): - engine_args["host"] = host_ip = gethostip() - engine_args["port"] = server_port = find_free_ports(1)[0] - cmd = SGLangConfig.build_cmd_from_args(engine_args) - self.server_process = launch_server_cmd(cmd) - wait_for_server(f"http://{host_ip}:{server_port}") - print(f"SGLang server launched at: http://{host_ip}:{server_port}") - os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{host_ip}:{server_port}" - - def configure(self, config): - self.config = config - self._engine.configure(config) + def launch_server(self, server_args: dict[str, Any]): + return self._engine.launch_server(server_args) - def destroy_engine(self, *args, **kwargs): - return self._engine.destroy_engine(*args, **kwargs) + def teardown_server(self): + return self._engine.teardown_server() def initialize( self, diff --git a/areal/engine/vllm_remote.py b/areal/engine/vllm_remote.py index 754ea3ee3..65d1afa6e 100644 --- a/areal/engine/vllm_remote.py +++ b/areal/engine/vllm_remote.py @@ -1,10 +1,11 @@ +import subprocess from collections.abc import Callable from concurrent.futures import Future from typing import Any, Optional from torchdata.stateful_dataloader import StatefulDataLoader -from areal.api.cli_args import InferenceEngineConfig +from areal.api.cli_args import InferenceEngineConfig, vLLMConfig from areal.api.engine_api import InferenceEngine from areal.api.io_struct import ( HttpGenerationResult, @@ -23,6 +24,13 @@ class VLLMBackend: """vLLM-specific backend implementation for remote inference.""" + def launch_server(self, server_args: dict[str, Any]) -> subprocess.Popen: + # FIXME: avoid circular import + from areal.launcher.vllm_server import launch_server_cmd + + cmd = vLLMConfig.build_cmd_from_args(server_args) + return launch_server_cmd(cmd) + def build_generation_request( self, req: ModelRequest, with_lora: bool ) -> HttpRequest: @@ -159,15 +167,11 @@ def __init__(self, config: InferenceEngineConfig): # Pure composition - create internal engine with vLLM backend self._engine = RemoteInfEngine(config, VLLMBackend()) - def configure(self, config): - self.config = config - self._engine.configure(config) - - def create_engine(self, *args, **kwargs): - return self._engine.create_engine(*args, **kwargs) + def launch_server(self, server_args: dict[str, Any]): + return self._engine.launch_server(server_args) - def destroy_engine(self, *args, **kwargs): - return self._engine.destroy_engine(*args, **kwargs) + def teardown_server(self): + return self._engine.teardown_server() def initialize( self, diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 3f2526a6d..b8f8faf3f 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -474,7 +474,7 @@ Configuration for inference servers, including offpolicyness control. | `experiment_name` | string \| None | `None` | - | | `trial_name` | string \| None | `None` | - | | `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | (Deprecated) Input/Output queue size for async rollout. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | | `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | | `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | | `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | From 749e0475d7b01186f85d3c9721d63ef74ad91e12 Mon Sep 17 00:00:00 2001 From: garrett4wade Date: Mon, 10 Nov 2025 20:45:29 +0800 Subject: [PATCH 52/52] minor revert --- areal/api/engine_api.py | 3 - areal/core/remote_inf_engine.py | 2 - areal/engine/ppo/actor.py | 100 +++++++++++++++++--------------- pyproject.toml | 1 - 4 files changed, 53 insertions(+), 53 deletions(-) diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index db459f9e0..230dfe4c1 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -34,9 +34,6 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None """ raise NotImplementedError() - def destroy_process_group(self): - raise NotImplementedError() - def initialize(self, *args, **kwargs): """Initialize environments for distributed training and load models. diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py index 95b56467c..0277d330e 100644 --- a/areal/core/remote_inf_engine.py +++ b/areal/core/remote_inf_engine.py @@ -249,8 +249,6 @@ def __init__( self.workflow_executor: WorkflowExecutor self.local_server_processes: list[LocalInfServerInfo] = [] - self.server_process: subprocess.Popen | None = None - def _wait_for_server(self, address): """Wait for a server to become healthy.""" base_url = f"http://{address}" diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index 7d37175f1..5b2cd804c 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -20,7 +20,6 @@ ppo_actor_loss_fn, reward_overlong_penalty, ) -from areal.utils.perf_tracer import trace_perf logger = logging.getLogger(__name__) @@ -67,7 +66,6 @@ def __init__(self, config: PPOActorConfig, engine: TrainEngine): logger.info(f" eps_clip: {config.eps_clip}") logger.info(f" group_size: {config.group_size}") - @trace_perf("ppo_actor.compute_logp", category="compute") @torch.no_grad() def compute_logp( self, @@ -89,7 +87,6 @@ def calc_logprobs(logits, input_data): aggregate_fn=lambda xs: torch.cat(xs, dim=-1), ) - @trace_perf("ppo_actor.compute_advantages", category="compute") def compute_advantages(self, data: dict[str, Any]) -> dict[str, Any]: bs = data["input_ids"].shape[0] max_seqlen = data["input_ids"].shape[1] @@ -186,30 +183,28 @@ def compute_advantages(self, data: dict[str, Any]) -> dict[str, Any]: return data - @trace_perf("ppo_actor.ppo_update", category="compute") - @stats_tracker.scope_func_wrapper("ppo_actor") - def ppo_update(self, data: dict[str, Any]) -> None: - with stats_tracker.scope("dynamic_sampling"): - if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: - data, sampling_stat = dynamic_sampling(data, self.group_size) - stats_tracker.scalar(**sampling_stat) + def ppo_update(self, data: dict[str, Any]) -> list[dict[str, float]]: + if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: + data, sampling_stat = dynamic_sampling(data, self.group_size) attn_mask = data["attention_mask"] loss_mask = data["loss_mask"] reward_score = data["rewards"] seqlens = attn_mask.sum(-1) + all_stats = [] ########## Logging code starts ########## - with stats_tracker.scope("ppo_actor"): - result_denominators = { - "correct_n_seqs": (reward_score > 0).bool(), - "incorrect_n_seqs": (reward_score <= 0).bool(), - } - if self.config.log_agent_stats: - assert "begin_of_trajectory" in data, ( + result_denominators = { + "correct_n_seqs": (reward_score > 0).bool(), + "incorrect_n_seqs": (reward_score <= 0).bool(), + } + if self.config.log_agent_stats: + if "begin_of_trajectory" not in data: + raise RuntimeError( "'begin_of_trajectory' is expected to log agent statistics" ) - assert len(self.config.log_agent_stats_keys) > 0, ( + if len(self.config.log_agent_stats_keys) == 0: + raise RuntimeError( "`log_agent_stats_keys` should not be empty when log_agent_stats=True" ) agent_denominator = (data["begin_of_trajectory"] > 0).bool() @@ -235,6 +230,7 @@ def ppo_update(self, data: dict[str, Any]) -> None: ) stats_tracker.stat(**stats, denominator="n_valid_tokens") + prompt_lens = [] prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1) seq_stats = dict( no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(), @@ -258,11 +254,18 @@ def ppo_update(self, data: dict[str, Any]) -> None: if self.config.log_agent_stats: stats_tracker.stat( - correct_seq_len=seqlens.float(), denominator="correct_n_seqs" - ) - stats_tracker.stat( - incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs" + **{k: data[k].float() for k in self.config.log_agent_stats_keys}, + denominator="agent", ) + + global_stats = stats_tracker.export( + reduce_group=self.engine.data_parallel_group + ) + for k in global_denominators: + keys = list(global_stats.keys()) + for k2 in keys: + if k2.endswith(k): + global_stats.pop(k2) ########## Logging code ends ########## for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]: @@ -273,24 +276,27 @@ def ppo_update(self, data: dict[str, Any]) -> None: data, mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches), ) - - with stats_tracker.scope("update"): - for mb in mb_inputs.mbs: - train_stat = self.engine.train_batch( - mb, - loss_fn=functools.partial( - grpo_loss_fn, - temperature=self.temperature, - eps_clip=self.config.eps_clip, - eps_clip_higher=self.config.eps_clip_higher, - c_clip=self.config.c_clip, - behav_imp_weight_cap=self.config.behav_imp_weight_cap, - m2_threshold=self.m2_threshold, - importance_sampling_level=self.config.importance_sampling_level, - ), - loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), - ) - stats_tracker.scalar(**train_stat) + for mb in mb_inputs.mbs: + train_stat = self.engine.train_batch( + mb, + loss_fn=functools.partial( + grpo_loss_fn, + temperature=self.temperature, + eps_clip=self.config.eps_clip, + eps_clip_higher=self.config.eps_clip_higher, + c_clip=self.config.c_clip, + behav_imp_weight_cap=self.config.behav_imp_weight_cap, + m2_threshold=self.m2_threshold, + importance_sampling_level=self.config.importance_sampling_level, + ), + loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), + ) + stats_tracker.scalar(**train_stat) + all_stats.append( + stats_tracker.export(reduce_group=self.engine.data_parallel_group) + ) + all_stats[0].update(global_stats) + return all_stats class FSDPPPOActor(FSDPEngine): @@ -303,11 +309,11 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None: return self.actor.compute_logp(*args, **kwargs) @torch.no_grad() - def compute_advantages(self, *args, **kwargs): + def compute_advantages(self, *args, **kwargs) -> dict[str, Any]: return self.actor.compute_advantages(*args, **kwargs) - def ppo_update(self, *args, **kwargs) -> None: - self.actor.ppo_update(*args, **kwargs) + def ppo_update(self, *args, **kwargs) -> list[dict[str, float]]: + return self.actor.ppo_update(*args, **kwargs) class MegatronPPOActor(MegatronEngine): @@ -320,11 +326,11 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None: return self.actor.compute_logp(*args, **kwargs) @torch.no_grad() - def compute_advantages(self, *args, **kwargs) -> None: - self.actor.compute_advantages(*args, **kwargs) + def compute_advantages(self, *args, **kwargs) -> dict[str, Any]: + return self.actor.compute_advantages(*args, **kwargs) - def ppo_update(self, *args, **kwargs) -> None: - self.actor.ppo_update(*args, **kwargs) + def ppo_update(self, *args, **kwargs) -> list[dict[str, float]]: + return self.actor.ppo_update(*args, **kwargs) def grpo_loss_fn( diff --git a/pyproject.toml b/pyproject.toml index 2bfbb4502..32906b0fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,7 +209,6 @@ filterwarnings = [ "ignore::UserWarning:transformers.*", ] markers = [ - "integration: marks tests as integration tests (real processes, slower)", "slow: mark test as slow, expected to cost more than 30 seconds and will not run in CI by default.", "ci: mark test as must-run in CI (only marked for slow tests).", "gpu: mark test that uses a single GPU",