diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 8255e046a..a8acdd9ed 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -372,6 +372,52 @@ class MegatronEngineConfig: recompute_modules: list[str] | None = None +@dataclass +class SchedulingStrategy: + type: str = field( + default="separation", metadata={"choices": ["separation", "colocation"]} + ) + target: str | None = field( + default=None, metadata={"help": "The target role to be colocated with"} + ) + + +@dataclass +class SchedulingSpec: + cpu: int = field(default=0, metadata={"help": "Number of CPU cores required"}) + gpu: int = field(default=0, metadata={"help": "Number of GPU units required"}) + mem: int = field(default=0, metadata={"help": "Amount of memory (GB) required"}) + port_count: int = field(default=2, metadata={"help": "Number of ports to expose"}) + image: str = field( + default="", metadata={"help": "Docker/Singularity container image to use"} + ) + type: str = field( + default="worker", + metadata={ + "help": "Task type (e.g., worker, engine)", + "choices": ["worker", "engine"], + }, + ) + env_vars: dict[str, str] = field( + default_factory=dict, + metadata={"help": "Environment variables for the container"}, + ) + # cmd + cmd: str | None = field( + default=None, + metadata={ + "help": "Command to execute inside the container. Defaults to AReaL's RPC server." + }, + ) + # slurm configurations from "https://slurm.schedmd.com/sbatch.html" + nodelist: str | None = None + exclude: str | None = None + partition: str | None = None + time_limit: str | None = None # see "--time" option for format + begin: str | None = None # see "--begin" option for format + deadline: str | None = None # see "--deadline" option for format + + @dataclass class TrainEngineConfig: """Core configuration for model training, including optimization and backend settings.""" @@ -442,6 +488,13 @@ class TrainEngineConfig: default="lora", metadata={"help": "peft method type. Only LoRA is supported for now."}, ) + scheduling_spec: SchedulingSpec = field( + default_factory=lambda: SchedulingSpec( + cmd="python -m areal.scheduler.rpc.rpc_server" + ), + metadata={"help": "train engine schedule specs"}, + ) + scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy) @dataclass @@ -924,6 +977,13 @@ class InferenceEngineConfig: "help": "The grace period after calling /pause_generation. Wait until all requests have been dropped." }, ) + scheduling_spec: SchedulingSpec = field( + default_factory=lambda: SchedulingSpec( + cmd="python -m areal.scheduler.rpc.rpc_server" + ), + metadata={"help": "inference engine schedule specs"}, + ) + scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy) @dataclass diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index 5d2076212..7ad616ccb 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -1,68 +1,173 @@ import abc from dataclasses import dataclass, field +from typing import Any + +from areal.api.cli_args import SchedulingSpec, SchedulingStrategy @dataclass class Worker: + """ + Represents a worker process in the distributed system. + + Attributes: + id: Unique identifier for the worker (e.g., "rollout/0", "actor/1"). + ip: IP address where the worker is running. + ports: List of port numbers (as strings) allocated to this worker for RPC communication. 
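+        worker_ports: Port numbers (as strings) allocated to the worker's RPC server.
+        engine_ports: Port numbers (as strings) reserved for the engine process
+            colocated with the worker on the same machine.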
+ """ + id: str + # worker and engine deploy on the same machine, so ip are the same ip: str - ports: list[str] = field(default_factory=list) + worker_ports: list[str] = field(default_factory=list) + engine_ports: list[str] = field(default_factory=list) @dataclass -class ContainerSpec: - cpu: int = 0 - gpu: int = 0 - mem: int = 0 - container_image: str = "" - cmd: str = "" - env_vars: dict[str, str] = field(default_factory=dict) - port_count: int = 2 +class Job: + replicas: int = 0 + tasks: list[SchedulingSpec] = field(default_factory=list) + scheduling_strategy: SchedulingStrategy | None = None + role: str = "" -@dataclass -class ScheduleStrategy: - type: str = "" - uid: str = "" +class Scheduler(abc.ABC): + """ + Abstract base class for schedulers that manage distributed worker processes. + A scheduler is responsible for: + - Creating and managing worker processes/containers. + - Allocating resources (GPUs, ports, memory). + - Creating and managing engine instances on workers. + - Facilitating RPC calls to engine methods. + """ -@dataclass -class SchedulingConfig: - replicas: int = 0 - specs: list[ContainerSpec] = field(default_factory=list) - schedule_strategy: ScheduleStrategy | None = None - role: str = "" + @abc.abstractmethod + def create_workers(self, job: Job, *args, **kwargs) -> list[str]: + """ + Create and start worker processes for a specific role. + Args: + scheduler_config: Configuration specifying replicas, resources, and scheduling strategy. + *args: Additional positional arguments (implementation-specific). + **kwargs: Additional keyword arguments (implementation-specific). -class Scheduler(abc.ABC): - def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str: + Returns: + List of worker IDs created (e.g., ["rollout/0", "rollout/1"]). + + Raises: + WorkerCreationError: If worker creation fails. + ValueError: If scheduler_config is invalid. """ - Start workers, return job id + raise NotImplementedError() + + @abc.abstractmethod + def get_workers(self, role: str, timeout: int | None = None) -> list[Worker]: + """ + Wait for workers to be ready and return their information. + + This method blocks until all workers for the specified role are ready + to accept RPC requests, or until the timeout is reached. + + Args: + role: Role name to query (e.g., "rollout", "actor"). + timeout: Maximum time to wait in seconds. None means use the default timeout. + + Returns: + List of Worker objects containing worker ID, IP address, and allocated ports. + + Raises: + WorkerNotFoundError: If no workers exist for the specified role. + WorkerFailedError: If any worker process has failed. + WorkerTimeoutError: If timeout is exceeded while waiting for workers. """ + raise NotImplementedError() - def get_workers(self, worker_key, timeout=None) -> list[Worker]: + @abc.abstractmethod + def delete_workers(self, role: str | None = None): """ - Wait and return worker list, including scheduling results such as ip and engine ports - (worker id, ip, ports) + Stop and clean up worker processes. + + Args: + role: Specific role to delete. If None, all workers are deleted. + + Raises: + WorkerNotFoundError: If the specified role doesn't exist. + + Note: + This method should gracefully terminate workers and clean up resources. + It should not raise an exception if workers have already stopped. 
""" raise NotImplementedError() - def delete_workers(self): - """stop all workers + @abc.abstractmethod + async def create_engine(self, worker_id: str, engine: str, *args, **kwargs) -> Any: + """ + Create an engine instance on a remote worker. + + The engine parameter is a string import path (e.g., "areal.engine.ppo.actor.FSDPPPOActor") + that will be dynamically imported and instantiated on the worker. + + Args: + worker_id: ID of the worker to create the engine on (e.g., "rollout/0"). + engine: Import path to the engine class (e.g., "areal.engine.ppo.actor.FSDPPPOActor"). + *args: Positional arguments passed to engine initialization. + **kwargs: Keyword arguments passed to engine initialization. - Raises exception if there is no such job, but passes if the job - has stopped either successfully or not. + Returns: + Result from engine initialization. + + Raises: + WorkerNotFoundError: If the specified worker doesn't exist. + WorkerFailedError: If the worker process has failed. + EngineCreationError: If engine creation or initialization fails. """ raise NotImplementedError() - async def create_engine(self, worker_id, engine_obj, *args, **kwargs): + @abc.abstractmethod + def call_engine(self, worker_id: str, method: str, *args, **kwargs) -> Any: """ - Create engine instance remotely + Call a method on an engine instance running on a worker (data plane operation). + + This is the synchronous version. Use `async_call_engine` for async operations. + + Args: + worker_id: ID of the worker hosting the engine (e.g., "rollout/0"). + method: Name of the method to call on the engine. + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + Result from the engine method call. + + Raises: + WorkerNotFoundError: If the specified worker doesn't exist. + WorkerFailedError: If the worker process has failed. + EngineCallError: If the method call fails. """ raise NotImplementedError() - def call_engine(self, worker_id, method, *args, **kwargs): + @abc.abstractmethod + async def async_call_engine( + self, worker_id: str, method: str, *args, **kwargs + ) -> Any: """ - Data plane call + Async version of call_engine for calling engine methods asynchronously. + + This is useful for concurrent operations or when integrating with async frameworks. + + Args: + worker_id: ID of the worker hosting the engine (e.g., "rollout/0"). + method: Name of the method to call on the engine. + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + Result from the engine method call. + + Raises: + WorkerNotFoundError: If the specified worker doesn't exist. + WorkerFailedError: If the worker process has failed. + EngineCallError: If the method call fails. 
""" raise NotImplementedError() diff --git a/areal/controller/__init__.py b/areal/controller/__init__.py new file mode 100644 index 000000000..aae841297 --- /dev/null +++ b/areal/controller/__init__.py @@ -0,0 +1,11 @@ +"""Controller components for managing distributed training and inference.""" + +from areal.controller.batch import DistributedBatchMemory +from areal.controller.rollout_controller import RolloutController +from areal.controller.train_controller import TrainController + +__all__ = [ + "DistributedBatchMemory", + "RolloutController", + "TrainController", +] diff --git a/areal/controller/batch.py b/areal/controller/batch.py index 66df87573..53cd3a776 100644 --- a/areal/controller/batch.py +++ b/areal/controller/batch.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Any import torch from torch import Tensor @@ -9,6 +9,7 @@ convert_list_to_dict, validate_dict_dataset, ) +from areal.utils.data import concat_padded_tensors from areal.utils.datapack import ffd_allocate from areal.utils.errors import FrameworkError @@ -17,7 +18,7 @@ class DistributedBatchMemory(DistributedBatch): dataset = None @classmethod - def from_dict(cls, dict_dataset: Dict[str, Union[Tensor, Any]]): + def from_dict(cls, dict_dataset: dict[str, Tensor | Any]): """Create a DistributedBatchMemory from dictionary format dataset. Parameters @@ -36,7 +37,7 @@ def from_dict(cls, dict_dataset: Dict[str, Union[Tensor, Any]]): return instance @classmethod - def from_list(cls, list_dataset: List[Dict[str, Union[Tensor, Any]]]): + def from_list(cls, list_dataset: list[dict[str, Tensor | Any]]): """Create a DistributedBatchMemory from list format dataset. Parameters @@ -103,9 +104,9 @@ def chunk_by_ffd( List of DistributedBatchMemory objects """ total_size = self._get_total_size() - assert ( - total_size % group_size == 0 - ), "tensor length must be devided by group_size" + assert total_size % group_size == 0, ( + "tensor length must be devided by group_size" + ) # Handle seqlen calculation for both tensor and scalar types if "seqlen" in self.dataset.keys(): @@ -209,7 +210,7 @@ def _get_total_size(self) -> int: # For scalar values, assume it's a single sample return 1 - def get_data(self) -> Dict[str, Union[torch.Tensor, Any]]: + def get_data(self) -> dict[str, torch.Tensor | Any]: """Get all data from the DistributedBatchMemory. 
Returns @@ -253,24 +254,7 @@ def concat(data: list["DistributedBatchMemory"]) -> "DistributedBatchMemory": batch.dataset = {} return batch - merged_data = {} - for batch in data: - for k, v in batch.dataset.items(): - if k in merged_data: - if isinstance(merged_data[k], torch.Tensor) and isinstance( - v, torch.Tensor - ): - merged_data[k] = torch.cat([merged_data[k], v], dim=0) - elif isinstance(merged_data[k], list) and isinstance(v, list): - merged_data[k] = merged_data[k] + v - else: - # Handle mixed types or scalar values - if isinstance(merged_data[k], list): - merged_data[k].append(v) - else: - merged_data[k] = [merged_data[k], v] - else: - merged_data[k] = v + merged_data = concat_padded_tensors([k.dataset for k in data]) result = DistributedBatchMemory.__new__(DistributedBatchMemory) result.dataset = merged_data return result @@ -319,7 +303,7 @@ def __setitem__(self, key, value): self.dataset[key] = value else: raise FrameworkError( - "FrameworkError", "DistributedBatchMemoryError", f"key must be str" + "FrameworkError", "DistributedBatchMemoryError", "key must be str" ) def __delitem__(self, key): diff --git a/areal/controller/rollout_controller.py b/areal/controller/rollout_controller.py new file mode 100644 index 000000000..36b3d40a6 --- /dev/null +++ b/areal/controller/rollout_controller.py @@ -0,0 +1,735 @@ +from __future__ import annotations + +import asyncio +import queue +import random +import time +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import InferenceEngineConfig +from areal.api.controller_api import DistributedBatch +from areal.api.engine_api import InferenceEngine +from areal.api.io_struct import ModelRequest, ModelResponse, ParamSpec, WeightUpdateMeta +from areal.api.scheduler_api import Job, Scheduler, Worker +from areal.controller.batch import DistributedBatchMemory +from areal.core.async_task_runner import AsyncTaskRunner, TaskQueueFullError +from areal.core.staleness_manager import StalenessManager +from areal.utils import logging +from areal.utils.data import cycle_dataloader + +CREATE_WORKER_TIMEOUT = 60.0 +TASK_RUNNER_MAX_QSIZE = 32768 + + +@dataclass +class _RemoteRolloutTaskInput: + data: dict[str, Any] + workflow_path: str + workflow_kwargs: dict[str, Any] + should_accept_path: str | None = None + + +class RolloutController: + """A centralized controller that manages multiple distributed InferenceEngine workers for rollout generation. + + RolloutController orchestrates distributed inference workloads by scheduling and + dispatching requests across multiple concurrent InferenceEngine instances. It provides + intelligent load balancing, staleness control, and capacity management to optimize + rollout generation efficiency. + + Key features: + - Distributed request scheduling and load balancing across workers + - Centralized staleness and capacity control for consistent performance + - Asynchronous rollout generation with configurable acceptance criteria + - Data aggregation from heterogeneously loaded workers + + The controller handles workload imbalances inherent in rollout generation, where + different workers may produce varying amounts of data depending on the complexity + of their assigned tasks. Generated data is stored locally on workers and aggregated + into `DistributedBatch` objects for seamless integration with TrainController. 
+ + Implementation details: + - Launches local inference engines on workers via scheduler + - Schedules requests to specific engines via round-robin + - Delegates actual execution to AsyncTaskRunner + - Aggregates results from workers into DistributedBatch + + Parameters + ---------- + inf_engine : type[InferenceEngine] + The inference engine class (not instance) to instantiate on each worker + config : InferenceEngineConfig + Configuration for inference engines + scheduler : Scheduler + Scheduler for worker management + """ + + def __init__( + self, + inf_engine: type[InferenceEngine], + config: InferenceEngineConfig, + scheduler: Scheduler, + ): + self.inf_engine = inf_engine + self.config = config + self.scheduler = scheduler + + # Worker management + self.workers: list[Worker] = [] # List of Worker objects from scheduler + self._worker_role: str + + # Round-robin scheduling + self._current_worker_idx = 0 + + # Async task execution + self.runner: AsyncTaskRunner | None = None + + # Logging + self.logger = None + + # State + self._version = 0 + + # Staleness management + self.staleness_manager: StalenessManager | None = None + + self._pending_results: list[dict[str, Any]] = [] + self._pending_inputs: list[_RemoteRolloutTaskInput] = [] + + def initialize( + self, + role: str, + alloc_mode: AllocationMode, + engine_args: dict[str, Any], + *args, + **kwargs, + ): + """Initialize environments and launch the background thread for asynchronous distributed inference. + + For remote inference engines, this serves as a client and connects to the inference servers. + For local inference engines, this creates an LLM engine on the local GPU. + + Parameters + ---------- + alloc_mode : AllocationMode + The allocation mode configuration for distributed setup + *args + Variable length argument list passed to engine initialization + **kwargs + Arbitrary keyword arguments passed to engine initialization + """ + self.logger = logging.getLogger("[RolloutController]") + + # Get scheduling config from kwargs or use defaults + self._worker_role = role + self.config.scheduling_spec.cpu *= alloc_mode.gen_instance_size + self.config.scheduling_spec.mem *= alloc_mode.gen_instance_size + self.config.scheduling_spec.gpu = alloc_mode.gen_instance_size + job = Job( + replicas=alloc_mode.gen.dp_size, + tasks=[self.config.scheduling_spec for _ in range(alloc_mode.gen.dp_size)], + scheduling_strategy=self.config.scheduling_strategy, + role=self._worker_role, + ) + + # Use asyncio.run to call async scheduler methods synchronously + asyncio.run( + self._async_initialize( + job, + engine_args, + *args, + **kwargs, + ) + ) + + # Initialize AsyncTaskRunner for task execution + self.runner = AsyncTaskRunner( + max_queue_size=TASK_RUNNER_MAX_QSIZE, + enable_tracing=self.config.enable_rollout_tracing, + ) + self.runner.initialize(logger=self.logger) + + # Initialize staleness manager for global capacity control + max_concurrent_rollouts = ( + self.config.max_concurrent_rollouts or self.config.consumer_batch_size + ) + consumer_batch_size = self.config.consumer_batch_size + + self.staleness_manager = StalenessManager( + max_concurrent_rollouts=max_concurrent_rollouts, + consumer_batch_size=consumer_batch_size, + max_staleness=self.config.max_head_offpolicyness, + ) + + async def _async_initialize( + self, job: Job, engine_args: dict[str, Any], *args, **kwargs + ): + # Create workers via scheduler + self.logger.info("Creating workers via scheduler...") + worker_ids = self.scheduler.create_workers(job=job) + 
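+        # create_workers returns the IDs of the spawned workers
+        # (e.g. ["rollout/0", "rollout/1"]); get_workers below blocks
+        # until they are ready to accept RPC requests.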
self.logger.info(f"Workers created: {worker_ids}") + + # Wait for workers to be ready + self.logger.info("Waiting for workers to be ready...") + self.workers = self.scheduler.get_workers(role=job.role) + self.logger.info(f"Workers ready: {[w.id for w in self.workers]}") + + # Get engine class path for dynamic import on workers + engine_class = self.inf_engine + engine_path = f"{engine_class.__module__}.{engine_class.__name__}" + + # Create and initialize engines on workers + self.logger.info("Creating engines...") + tasks = [ + self.scheduler.create_engine( + worker_id=worker.id, + engine=engine_path, + config=self.config, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + self.logger.info("Engine created on all workers!") + + self.logger.info("Calling engine initialization...") + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, method="create_engine", engine_args=engine_args + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, method="initialize", *args, **kwargs + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + self.logger.info("All engines are initialized...") + + def destroy(self): + """Destroy the engine and release GPU memory for the local inference engine. + + This method cleans up all resources including workers, task runner, and thread pool. + """ + self.logger.info("Destroying RolloutController...") + + # Destroy task runner + if self.runner is not None: + self.runner.destroy() + self.runner = None + + # Delete workers via scheduler + try: + self.scheduler.delete_workers(role=self._worker_role) + self.logger.info("Workers deleted") + except Exception as e: + self.logger.error(f"Error deleting workers: {e}") + + self.workers.clear() + + self.logger.info("RolloutController destroyed") + + def get_capacity(self) -> int: + """Get current available capacity for new rollouts based on staleness. + + Returns + ------- + int + Number of new rollout slots available based on staleness constraints + """ + version = self.get_version() # Use controller's global version + return self.staleness_manager.get_capacity(version) + + def _choose_worker(self) -> Worker: + """Choose a worker for the next request using round-robin scheduling. + + Returns + ------- + Worker + The chosen worker object + """ + + worker = self.workers[self._current_worker_idx] + self._current_worker_idx = (self._current_worker_idx + 1) % len(self.workers) + return worker + + def submit( + self, + data: dict[str, Any], + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> None: + """Submit a request to the inference engine and return immediately. + + Should be used together with subsequent `wait`. + + Parameters + ---------- + data : dict[str, Any] + The input data for rollout. Used by the user's customized workflow implementation. + workflow_path : str + The fully qualified path to the workflow class (e.g., "module.submodule.WorkflowClass"). + The workflow will be dynamically imported on the worker. + workflow_kwargs : dict[str, Any] + Keyword arguments to pass to the workflow constructor. + should_accept_path : str | None, optional + The fully qualified path to a function used to decide whether to accept a specific + trajectory (dynamic filtering). The function should take a complete trajectory + output by the workflow and return a bool, by default None. 
+ """ + # Add to pending queue (will be submitted when capacity allows) + self._pending_inputs.append( + _RemoteRolloutTaskInput( + data=data, + workflow_kwargs=workflow_kwargs, + workflow_path=workflow_path, + should_accept_path=should_accept_path, + ) + ) + + async def _wait_callback(self, worker: Worker): + # Wait for a generation to return + result = "NO_RESULT" + tik = time.time() + while result == "NO_RESULT" and time.time() - tik < self.config.request_timeout: + result = await self.scheduler.async_call_engine( + worker.id, "wait_quiet", count=1, timeout=1, max_retries=1 + ) + + # The RPCServer will return None if the + # trajectory is rejected. + if result is not None: + self.staleness_manager.on_rollout_accepted() + if self.config.enable_rollout_tracing: + stat = self.staleness_manager.get_stats() + self.logger.info( + f"Finish and accept rollout. " + f"Submit: {stat.submitted}, " + f"running: {stat.running}, " + f"accepted: {stat.accepted}." + ) + return result + else: + self.staleness_manager.on_rollout_rejected() + if self.config.enable_rollout_tracing: + stat = self.staleness_manager.get_stats() + self.logger.info( + f"Finish but reject rollout. " + f"Submit: {stat.submitted}, " + f"running: {stat.running}, " + f"accepted: {stat.accepted}." + ) + return None + + def _commit_one_to_runner(self): + """Commit one pending input to task runner with staleness tracking.""" + task_input = self._pending_inputs.pop(0) + + # Choose worker via round-robin + worker = self._choose_worker() + + self.scheduler.call_engine( + worker.id, + "submit", + data=task_input.data, + workflow_path=task_input.workflow_path, + workflow_kwargs=task_input.workflow_kwargs, + should_accept_path=task_input.should_accept_path, + ) + + # Submit a wait callback to AsyncTaskRunner + try: + self.runner.submit( + self._wait_callback, + worker, + ) + except TaskQueueFullError: + raise queue.Full("Input queue full") + + # Notify staleness manager AFTER successful submission + self.staleness_manager.on_rollout_submitted() + if self.config.enable_rollout_tracing: + stat = self.staleness_manager.get_stats() + self.logger.info( + f"Submit rollout. " + f"Submit: {stat.submitted}, " + f"running: {stat.running}, " + f"accepted: {stat.accepted}." + ) + + def wait(self, count: int, timeout: float | None = None) -> DistributedBatch: + """Wait for a specified number of requests to complete, with a timeout. + + Should be used together with preceding `submit`. + + Parameters + ---------- + count : int + The number of accepted trajectories to wait for + timeout : float | None, optional + Timeout in seconds. Exceeding the timeout will raise a `TimeoutError`, by default None + + Returns + ------- + DistributedBatch + A concatenated batch of trajectories + + Raises + ------ + TimeoutError + If the timeout is exceeded before enough trajectories are collected + """ + ####################################################### + # The following logic is copied from WorkflowExecutor # + ####################################################### + start_time = time.perf_counter() + timeout = timeout or float(7 * 24 * 3600) + + # Keep requesting results from runner until we have enough accepted + # (non-None) results. Use short timeout (1 second) for each wait call + # to allow periodic checking. 
This matches original behavior where + # wait() would poll and allow prepare_batch() to continue + while True: + # Submit pending inputs + # Check capacity before submitting + capacity = self.get_capacity() + # Submit pending tasks + for _ in range(capacity): + if len(self._pending_inputs) == 0: + break + self._commit_one_to_runner() + + if len(self._pending_results) >= count: + break + + elapsed = time.perf_counter() - start_time + remaining_timeout = timeout - elapsed + + if remaining_timeout <= 0: + raise TimeoutError( + f"Timed out waiting for {count} rollouts, only received " + f"{len(self._pending_results)}." + ) + + # Try to get at least the number we still need, but request at least 1 + # Note: runner.wait() might return fewer due to rejections (None results) + needed = max(1, count - len(self._pending_results)) + + try: + # Use short timeout to allow periodic returns (matches original + # polling behavior) + batch = self.runner.wait( + count=needed, timeout=min(0.1, remaining_timeout) + ) + + # Filter out None results (rejected trajectories) + # runner.wait() returns List[T] where T can be None for + # rejected rollouts + accepted = [result for result in batch if result is not None] + self._pending_results.extend(accepted) + except TimeoutError: + pass + + if self.config.enable_rollout_tracing: + self.logger.info("Rollout results are ready!") + + # Extract requested number of results + results = self._pending_results[:count] + self._pending_results = self._pending_results[count:] + + # Shuffle for randomness (helps with data diversity) + random.shuffle(results) + + # Convert to DistributedBatch + if len(results) == 0: + return DistributedBatchMemory.from_dict({}) + + return DistributedBatchMemory.concat( + [DistributedBatchMemory.from_dict(r) for r in results] + ) + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> DistributedBatch: + """Submit a batch of requests to the inference engine and wait for the results. + + Parameters + ---------- + data : list[dict[str, Any]] + A list of input data dictionaries for rollout + workflow_path : str + The fully qualified path to the workflow class (e.g., "module.submodule.WorkflowClass") + workflow_kwargs : dict[str, Any] + Keyword arguments to pass to the workflow constructor + should_accept_path : str | None, optional + The fully qualified path to a function to decide whether to accept a trajectory, by default None + + Returns + ------- + DistributedBatch + A concatenated batch of trajectory results + """ + # Submit all requests + for item in data: + self.submit( + item, + workflow_kwargs=workflow_kwargs, + workflow_path=workflow_path, + should_accept_path=should_accept_path, + ) + + # Wait for all results + return self.wait(count=len(data)) + + def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> DistributedBatch: + """Asynchronously submit and wait until a full batch is ready with controlled staleness. 
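+
+        Data is pulled from `dataloader` and submitted while the staleness-controlled
+        capacity (see `get_capacity`) permits; the call returns once
+        `dataloader.batch_size` accepted trajectories have been collected.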
+ + Parameters + ---------- + dataloader : StatefulDataLoader + The data loader to pull data from for batch preparation + workflow_path : str + The fully qualified path to the workflow class (e.g., "module.submodule.WorkflowClass") + workflow_kwargs : dict[str, Any] + Keyword arguments to pass to the workflow constructor + should_accept_path : str | None, optional + The fully qualified path to a function to decide whether to accept a trajectory, by default None + + Returns + ------- + DistributedBatch + A full batch of trajectory results with controlled staleness + """ + ####################################################### + # The following logic is copied from WorkflowExecutor # + ####################################################### + if not hasattr(self, "data_generator"): + self.data_generator = cycle_dataloader(dataloader) + assert dataloader.batch_size is not None + while True: + # Submit at least two batches to allow maximum overlap + if ( + self.get_capacity() + dataloader.batch_size > 0 + and self.runner.get_input_queue_size() + dataloader.batch_size + < self.runner.max_queue_size + ): + data = next(self.data_generator) + for item in data: + try: + self.submit( + item, + workflow_kwargs=workflow_kwargs, + workflow_path=workflow_path, + should_accept_path=should_accept_path, + ) + except queue.Full: + # Capacity exhausted during batch submission, stop and wait + break + try: + return self.wait(dataloader.batch_size, timeout=1) + except TimeoutError: + pass + + async def agenerate(self, req: ModelRequest) -> ModelResponse: + """Asynchronously generate a response for the given request. + + This method provides direct access to the inference engine's generation capabilities + for single requests, bypassing the workflow system. + + Parameters + ---------- + req : ModelRequest + The model request containing input data and generation parameters + + Returns + ------- + ModelResponse + The generated response from the model + """ + # Choose worker and delegate + worker = self._choose_worker() + + # Call agenerate on engine via scheduler + return await self.scheduler.async_call_engine( + worker_id=worker.id, + method="agenerate", + req=req, + ) + + async def init_weights_update_group(self, meta: WeightUpdateMeta) -> None: + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="init_weights_update_group", + meta=meta, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + + async def update_weights_from_distributed( + self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] + ): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="update_weights_from_distributed", + meta=meta, + param_specs=param_specs, + max_retries=1, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + + async def update_weights_from_disk(self, meta: WeightUpdateMeta): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="update_weights_from_disk", + meta=meta, + max_retries=1, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + + def set_version(self, version: int) -> None: + """Set the current weight version in the inference engine. + + This updates the version number across all workers, which is used for + staleness tracking in online training scenarios. 
+ + Parameters + ---------- + version : int + The weight version number to set + """ + self._version = version + for worker in self.workers: + try: + self.scheduler.call_engine( + worker_id=worker.id, + method="set_version", + version=version, + max_retries=1, + ) + except Exception as e: + self.logger.error(f"Error setting version for worker {worker.id}: {e}") + + def get_version(self) -> int: + """Get the current weight version in the inference engine. + + Returns + ------- + int + The current weight version number + """ + return self._version + + def pause(self): + """Pause request submission for async rollout. + + Used during evaluation to prevent data over-generation across all workers. + """ + for worker in self.workers: + try: + self.scheduler.call_engine( + worker_id=worker.id, + method="pause", + ) + except Exception as e: + self.logger.error(f"Error pausing worker {worker.id}: {e}") + + def resume(self): + """Resume request submission for async rollout across all workers.""" + for worker in self.workers: + try: + self.scheduler.call_engine( + worker_id=worker.id, + method="resume", + ) + except Exception as e: + self.logger.error(f"Error resuming worker {worker.id}: {e}") + + def export_stats(self): + async def _call_all(): + tasks = [ + self.scheduler.async_call_engine( + worker=worker, + method="export_stats", + ) + for worker in self.workers + ] + return await asyncio.gather(*tasks) + + # Stats + all_raw_stats = asyncio.run(_call_all()) + stats = {} + exported = set() + for raw_stats in all_raw_stats: + for k in raw_stats: + if k in exported: + continue + data = sum([s[1].get(k, []) for s in all_raw_stats], []) + if len(data) == 0: + continue + stats[k] = sum(data) / len(data) + exported.add(k) + return stats + + def register_callback_to_all_worker( + self, method: str, callback: Callable, **kwargs + ): + """Register a callback function for the specified method across all workers. + + Partial rollout API. After successful registration, the controller will poll + and call the specified method in a background thread. When the return value + is obtained, it will be used as a parameter to call the `callback` function. + + Parameters + ---------- + method : str + The name of the method to register the callback for + callback : Callable + The callback function to be called with the method's return value + **kwargs + Additional keyword arguments for the callback registration + + Raises + ------ + NotImplementedError + This method is not yet implemented + """ + raise NotImplementedError() + + def abort_all_requests(self) -> None: + """Abort all ongoing requests in the inference engine. + + Partial rollout API for canceling all queued and in-progress requests across + all workers. 
+ + Raises + ------ + NotImplementedError + This method is not yet implemented + """ + raise NotImplementedError() diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index f440bf8b5..d6d6e80a4 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -1,3 +1,781 @@ -# Placeholder for train controller logic +import asyncio +import shutil +from collections.abc import Callable +from copy import deepcopy +from datetime import datetime +from typing import Any + +import torch +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.alloc_mode import ParallelStrategy +from areal.api.cli_args import TrainEngineConfig +from areal.api.controller_api import DistributedBatch +from areal.api.engine_api import TrainEngine +from areal.api.io_struct import ( + AllocationMode, + FinetuneSpec, + SaveLoadMeta, + WeightUpdateMeta, +) +from areal.api.scheduler_api import Job, Scheduler, Worker +from areal.controller.batch import DistributedBatchMemory +from areal.controller.rollout_controller import RolloutController +from areal.platforms import current_platform +from areal.utils import logging, name_resolve, names +from areal.utils.network import find_free_ports + +logger = logging.getLogger("TrainController") + + class TrainController: - pass + def __init__( + self, + train_engine: type[TrainEngine], + config: TrainEngineConfig, + scheduler: Scheduler, + ): + self.train_engine = train_engine + self.config = config + self.scheduler = scheduler + + self.alloc_mode: AllocationMode + self.workers: list[Worker] = [] + self.workers_is_dp_head: list[bool] = [] # Only DP head workers + self.parallel_strategy: ParallelStrategy | None = None + + self.rollout: RolloutController = None + self.weight_update_group_initialized = False + + self._worker_role: str + self.logger = None + + def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): + # A dummy method. Process group will be created during `initialize` + pass + + def initialize( + self, + role: str, + alloc_mode: AllocationMode, + ft_spec: FinetuneSpec, + **kwargs, + ): + """Initialize environments for distributed training and load models. + + This method should be called after `create_process_group`. 
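+
+        One worker is created per training rank via the scheduler. Torchrun-style
+        environment variables (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR/MASTER_PORT)
+        are injected into each task, the training engine is instantiated and
+        initialized on every worker, and data-parallel head workers are recorded
+        for later dispatch.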
+ + Parameters + ---------- + role : str + Role identifier for the workers + alloc_mode : AllocationMode + Allocation mode configuration for distributed setup + ft_spec : FinetuneSpec + Finetune specification for model initialization + **kwargs + Additional keyword arguments passed to engine initialization + """ + self.logger = logging.getLogger("[TrainController]") + + # Store configuration + self._worker_role = role + self.alloc_mode = alloc_mode + + if alloc_mode.gen_backend == "sglang": + self.config.scheduling_spec.env_vars["NCCL_CUMEM_ENABLE"] = "0" + self.config.scheduling_spec.env_vars["NCCL_NVLS_ENABLE"] = "0" + + self.parallel_strategy = alloc_mode.train + + # Create job for scheduler + job = Job( + replicas=alloc_mode.train.world_size, + tasks=[ + deepcopy(self.config.scheduling_spec) + for _ in range(alloc_mode.train.world_size) + ], + scheduling_strategy=self.config.scheduling_strategy, + role=self._worker_role, + ) + # Create environment variables to mimic torchrun + # FIXME: here master_addr and master_port only work in the local setting + port = find_free_ports(1)[0] + for i, task in enumerate(job.tasks): + task.env_vars["RANK"] = str(i) + task.env_vars["WORLD_SIZE"] = str(alloc_mode.train.world_size) + task.env_vars["LOCAL_RANK"] = str( + 0 + ) # because we have only set 1 CUDA_VISIBLE_DEVICES for each process + # TODO: find a real master addr with scheduler + task.env_vars["MASTER_ADDR"] = "localhost" + task.env_vars["MASTER_PORT"] = str(port) + + # Create workers via scheduler + self.logger.info("Creating workers via scheduler...") + worker_ids = self.scheduler.create_workers(job=job) + self.logger.info(f"Workers created: {worker_ids}") + + # Wait for workers to be ready + self.logger.info("Waiting for workers to be ready...") + self.workers = self.scheduler.get_workers(role=job.role) + self.logger.info(f"Workers ready: {[w.id for w in self.workers]}") + + # Get engine class path for dynamic import on workers + engine_class = self.train_engine + engine_path = f"{engine_class.__module__}.{engine_class.__name__}" + + # Create and initialize engines on workers + self._run_async_task(self._async_create_engines(engine_path)) + self._run_async_task(self._async_initialize_engines(ft_spec, **kwargs)) + + # Identify DP head workers + self._identify_dp_heads() + + self.logger.info("TrainController initialization complete") + + def _run_async_task(self, task): + def _run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete(task) + finally: + new_loop.close() + + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor() as executor: + future = executor.submit(_run_in_thread) + return future.result() + + async def _async_create_engines(self, engine_path: str): + # Create engines on workers + self.logger.info("Creating engines on workers...") + tasks = [ + self.scheduler.create_engine( + worker_id=worker.id, + engine=engine_path, + config=self.config, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + self.logger.info("Engines created on all workers!") + + async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): + # Initialize engines + self.logger.info("Calling engine initialization...") + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, + method="create_process_group", + parallel_strategy=self.parallel_strategy, + _should_bcast=False, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + tasks = [ + 
self.scheduler.async_call_engine( + worker_id=worker.id, + method="initialize", + ft_spec=ft_spec, + _should_bcast=False, + **kwargs, + ) + for worker in self.workers + ] + await asyncio.gather(*tasks) + self.logger.info("All engines are initialized!") + + def _identify_dp_heads(self): + """Identify which workers are DP heads by querying their DP rank.""" + self.logger.info("Identifying DP head workers...") + + # Query all workers for their DP rank + async def _get_dp_head(): + tasks = [ + self.scheduler.async_call_engine( + worker_id=worker.id, method="is_data_parallel_head" + ) + for worker in self.workers + ] + return await asyncio.gather(*tasks) + + self.workers_is_dp_head = self._run_async_task(_get_dp_head()) + + def destroy(self): + """Destroy the controller and release GPU memory of models. + + Cleans up all resources including workers, engines, and internal state. + """ + self.logger.info("Destroying TrainController...") + + # Delete workers via scheduler + try: + self.scheduler.delete_workers(role=self._worker_role) + self.logger.info("Workers deleted") + except Exception as e: + self.logger.error(f"Error deleting workers: {e}") + + # Clear worker lists + self.workers.clear() + self.workers_is_dp_head.clear() + + self.logger.info("TrainController destroyed") + + def _custom_function_call(self, method: str, *args, **kwargs): + """Dispatch method call to appropriate workers based on input type. + + If any argument is a DistributedBatch, split data. Call only DP heads. + """ + dp_split_args, dp_split_kwargs = self._dispatch_inputs(*args, **kwargs) + results = self._run_async_task( + self._call_with_dispatched_inputs(method, dp_split_args, dp_split_kwargs) + ) + # Only remain data from DP head. + results = [r for idx, r in enumerate(results) if self.workers_is_dp_head[idx]] + return self._merge_results(results, method) + + async def _async_custom_function_call(self, method: str, *args, **kwargs): + dp_split_args, dp_split_kwargs = self._dispatch_inputs(*args, **kwargs) + results = await self._call_with_dispatched_inputs( + method, dp_split_args, dp_split_kwargs + ) + # Only remain data from DP head. + results = [r for idx, r in enumerate(results) if self.workers_is_dp_head[idx]] + return self._merge_results(results, method) + + def _dispatch_inputs(self, *args, **kwargs): + # Find and split DistributedBatch arguments + split_args = [] + for arg in args: + if isinstance(arg, DistributedBatch): + # Split across DP groups + split_args.append(self._align_batches_with_dp(arg, rebalance=True)) + else: + # Replicate to all DP heads + split_args.append([arg] * self.parallel_strategy.dp_size) + + split_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, DistributedBatch): + split_kwargs[k] = self._align_batches_with_dp(v, rebalance=True) + else: + split_kwargs[k] = [v] * self.parallel_strategy.dp_size + return split_args, split_kwargs + + async def _call_with_dispatched_inputs( + self, + method: str, + dp_split_args: list[list[Any]], + dp_worker_kwargs: list[dict[str, Any]], + ): + # Call all workers. + # ONLY DP head workers get their data slice. + # Other workers will get data by broadcasting in RPC server. 
+ tasks = [] + dp_idx = 0 + for idx, worker in enumerate(self.workers): + if self.workers_is_dp_head[idx]: + # Get this worker's slice of each argument + worker_args = [splits[dp_idx] for splits in dp_split_args] + worker_kwargs = { + k: splits[dp_idx] for k, splits in dp_worker_kwargs.items() + } + + # Convert DistributedBatch to dict for RPC + # FIXME: pass metadata instead of real tensors + worker_args = [ + arg.get_data() if isinstance(arg, DistributedBatch) else arg + for arg in worker_args + ] + worker_kwargs = { + k: v.get_data() if isinstance(v, DistributedBatch) else v + for k, v in worker_kwargs.items() + } + dp_idx += 1 + else: + worker_args = [] + worker_kwargs = {} + + tasks.append( + self.scheduler.async_call_engine( + worker.id, + method, + *worker_args, + **worker_kwargs, + ) + ) + return await asyncio.gather(*tasks) + + def _merge_results(self, results, method): + """Merge results from DP head workers based on result type. + + - For torch.Tensor: concat results as DistributedBatch + - For others: assume they have been synchronized and return the first + """ + first_result = results[0] + + # FIXME: should use a more general data conversion strategy + if isinstance(first_result, torch.Tensor): + # Assume that tensor shapes are [bs, seqlen, *] + max_length = max(tensor.shape[1] for tensor in results) + n_dim = first_result.ndim + padded_tensors = [] + for tensor in results: + pad_mode = ( + (0,) * (2 * (n_dim - 2)) + + (0, max_length - tensor.shape[1]) + + (0, 0) + ) + padded_tensor = torch.nn.functional.pad(tensor, pad_mode, value=0.0) + padded_tensors.append(padded_tensor) + return torch.cat(padded_tensors, dim=0) + + if isinstance(first_result, dict): + if len(first_result) == 0: + return DistributedBatchMemory.from_dict({}) + + k = next(iter(first_result.keys())) + if isinstance(first_result[k], torch.Tensor): + # Check if this looks like a proper batch (has attention_mask) + # If so, use DistributedBatchMemory.concat which handles padding + if "attention_mask" in first_result: + return DistributedBatchMemory.concat( + [DistributedBatchMemory.from_dict(r) for r in results] + ) + else: + # Simple tensor dict - just concatenate tensors along batch dim + merged = {} + for key in first_result.keys(): + if isinstance(first_result[key], torch.Tensor): + merged[key] = torch.cat([r[key] for r in results], dim=0) + else: + merged[key] = first_result[key] + return DistributedBatchMemory.from_dict(merged) + + # Return first (already synchronized) + return first_result + + def _align_batches_with_dp( + self, input_: DistributedBatch, rebalance=True + ) -> list[DistributedBatch]: + """Split DistributedBatch across DP groups. + + Returns a list of batches, one for each DP head worker. + """ + # Handle empty batch by replicating to all DP groups + if len(input_.get_data()) == 0: + return [input_] * self.alloc_mode.train.dp_size + + # NOTE: group normalization should be done in workflow + if rebalance: + inputs = input_.chunk_by_ffd(1, self.alloc_mode.train.dp_size) + else: + inputs = input_.chunk(self.alloc_mode.train.dp_size) + return inputs + + def connect_engine(self, rollout: RolloutController, meta: WeightUpdateMeta): + if self.rollout is not None and self.rollout != rollout: + self.logger.warning( + f"Connected rollout controller changed from {self.rollout} to {rollout}." 
+ ) + self.rollout = rollout + + if ( + meta.type == current_platform.communication_backend + and not self.weight_update_group_initialized + ): + self._init_weight_update_from_distributed(meta) + self.weight_update_group_initialized = True + + def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> DistributedBatch: + return self.rollout.prepare_batch( + dataloader=dataloader, + workflow_path=workflow_path, + workflow_kwargs=workflow_kwargs, + should_accept_path=should_accept_path, + ) + + def rollout_batch( + self, + data: list[dict[str, Any]], + workflow_path: str, + workflow_kwargs: dict[str, Any], + should_accept_path: str | None = None, + ) -> DistributedBatch: + return self.rollout.rollout_batch( + data=data, + workflow_path=workflow_path, + workflow_kwargs=workflow_kwargs, + should_accept_path=should_accept_path, + ) + + def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta): + raise NotImplementedError() + + def _update_weights_from_distributed(self, meta: WeightUpdateMeta): + raise NotImplementedError() + + def _update_weights_from_disk(self, meta: WeightUpdateMeta): + # Update all LocalInfEngine's local weight + self.save( + SaveLoadMeta( + path=meta.path, + weight_format="hf", + with_optim=False, + tokenizer=None, + processor=None, + ) + ) + update_name = names.update_weights_from_disk( + self.config.experiment_name, + self.config.trial_name, + self.get_version(), + ) + name_resolve.add( + update_name, + str(datetime.now().timestamp()), + keepalive_ttl=120, + replace=True, + ) + + meta.clear_checkpoint = False + asyncio.run(self.rollout.update_weights_from_disk(meta)) + shutil.rmtree(meta.path, ignore_errors=True) + + def _check_rollout_engine_connected(self): + """Validate that rollout engine has been connected via connect_engine().""" + if self.rollout is None: + raise RuntimeError( + "Rollout engine not connected. Call connect_engine()" + " before using rollout/update_weight methods." + ) + + def update_weights(self, meta: WeightUpdateMeta): + self._check_rollout_engine_connected() + if meta.type == current_platform.communication_backend: + assert self.weight_update_group_initialized + self._update_weights_from_distributed(meta) + elif meta.type == "disk": + self._update_weights_from_disk(meta) + else: + raise ValueError(f"Unknown weight update type {meta.type}") + + def export_stats(self): + async def _call_all(): + tasks = [ + self.scheduler.async_call_engine(worker.id, "export_stats") + for worker in self.workers + ] + return await asyncio.gather(*tasks) + + results = self._run_async_task(_call_all()) + # stats have been aggregated and synchronized. + return results[0] + + # ==================== ENGINE RPC WRAPPERS ==================== + def train(self, mode: bool = True): + """Set the engine to training mode. + + Parameters + ---------- + mode : bool, optional + Whether to set the engine to training mode, by default True + + Returns + ------- + TrainController + Returns self for method chaining + """ + self._custom_function_call("train", mode) + return self + + def eval(self): + """Set the engine to evaluation mode. + + This is a convenience method that calls `self.train(False)`. + + Returns + ------- + TrainController + Returns self for method chaining + """ + return self.train(False) + + def set_version(self, version: int): + """Set the current weight version in the training engine. 
+ + Parameters + ---------- + version : int + The weight version number to set + """ + self._custom_function_call("set_version", version) + + def get_version(self) -> int: + """Get the current weight version in the training engine. + + Returns + ------- + int + The current weight version number + """ + return self._custom_function_call("get_version") + + def save(self, meta: SaveLoadMeta): + """Save model weights and optimizer states for later use. + + Parameters + ---------- + meta : SaveLoadMeta + Metadata containing information about where and how to save + """ + self._custom_function_call("save", meta) + + def load(self, meta: SaveLoadMeta): + """Load model weights and optimizer states from a file. + + Parameters + ---------- + meta : SaveLoadMeta + Metadata containing information about where and how to load + """ + self._custom_function_call("load", meta) + + def step_lr_scheduler(self): + """Step the learning rate scheduler. + + Since PPO uses minibatch updates, this method should be called periodically + (e.g., once per PPO step). It is separated from train_batch to allow + for more flexible learning rate scheduling. + """ + self._custom_function_call("step_lr_scheduler") + + def forward( + self, + input_: DistributedBatch, + output_seqlens: list[int] | None = None, + post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None, + aggregate_fn: Callable[[list[Any]], Any] = torch.cat, + ) -> Any | None: + """Run the forward pass or inference on the model. + + Note + ---- + This operation is gradient-free. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for model forward pass. Redundant entries are allowed. + output_seqlens : List[int], optional + The desired output sequence lengths. If None, assumes that the output + has the same lengths as inputs, by default None. + post_hook : Callable[[torch.Tensor, Dict[str, Any]], Any], optional + The post-processing function for micro-batched outputs. Post-processing + the output on-the-fly during micro-batched forward can reduce peak + memory usage, by default None. + aggregate_fn : Callable[[List[Any]], Any], optional + A function to aggregate micro-batched outputs, by default torch.cat. + + Returns + ------- + Any or None + The result produced by `post_hook` and `aggregate_fn`. + """ + return self._custom_function_call( + "forward", + input_=input_, + output_seqlens=output_seqlens, + post_hook=post_hook, + aggregate_fn=aggregate_fn, + ) + + def train_batch( + self, + input_: DistributedBatch, + loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + ) -> dict[str, float]: + """Update the model with a batch of data and a loss function. + + Note + ---- + The loss_fn should process packed 1D inputs, instead of 2D inputs. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for model forward pass and the loss function. + Redundant entries are allowed. + loss_fn : Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor] + The loss function that takes the model's forward output and input_, + and outputs a scalar normalized loss. + loss_weight_fn : Callable[[Dict[str, Any]], torch.Tensor] + A function used to calculate the weight of each micro-batch. Since + loss_fn normalizes the loss for a micro-batch, we need a corresponding + weight for each micro-batch to normalize the loss globally. The weight + is usually the number of response tokens in the batch. 
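+            A typical choice is `lambda x: x["loss_mask"].count_nonzero()`.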
+ + Returns + ------- + Dict[str, float] + Scalar statistics after training, e.g., the current learning rate, + gradient norm, etc. + """ + return self._custom_function_call( + "train_batch", input_, loss_fn, loss_weight_fn + ) + + def eval_batch( + self, + input_: DistributedBatch, + loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], + loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], + ) -> torch.Tensor | None: + """Evaluate the model using the forward pass and loss function. + + Note + ---- + The loss_fn should process packed 1D inputs, instead of 2D inputs. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for model forward pass and the loss function. + Redundant entries are allowed. + loss_fn : Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor] + The loss function that takes the model's forward output and input_, + and outputs a scalar normalized loss. + loss_weight_fn : Callable[[Dict[str, Any]], torch.Tensor] + A function used to calculate the weight of each micro-batch. Since + loss_fn normalizes the loss for a micro-batch, we need a corresponding + weight for each micro-batch to normalize the loss globally. The weight + is usually the number of response tokens in the batch. + + Returns + ------- + torch.Tensor or None + A scalar loss or None. The evaluation statistics should be aggregated + with `stats_tracker`. + """ + return self._custom_function_call("eval_batch", input_, loss_fn, loss_weight_fn) + + # ==================== SFT RPC WRAPPERS ==================== + def train_lm( + self, + input_: DistributedBatch, + *args, + **kwargs, + ) -> dict[str, float]: + """Train language model across workers. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for language model training + *args + Additional positional arguments passed to the engine + **kwargs + Additional keyword arguments passed to the engine + + Returns + ------- + Dict[str, float] + Scalar statistics after training + """ + return self._custom_function_call("train_lm", input_, *args, **kwargs) + + def evaluate_lm( + self, + input_: DistributedBatch, + *args, + **kwargs, + ) -> torch.Tensor | None: + """Evaluate language model across workers. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data for language model evaluation + *args + Additional positional arguments passed to the engine + **kwargs + Additional keyword arguments passed to the engine + + Returns + ------- + torch.Tensor or None + A scalar loss or None + """ + return self._custom_function_call("evaluate_lm", input_, *args, **kwargs) + + # ==================== PPO RPC WRAPPERS ==================== + def compute_logp( + self, + *args, + **kwargs, + ): + """Compute log probabilities across workers. + + Parameters + ---------- + *args + Positional arguments passed to the engine + **kwargs + Keyword arguments passed to the engine + + Returns + ------- + Any + Log probabilities computed by the engine + """ + return self._custom_function_call("compute_logp", *args, **kwargs) + + def compute_advantages( + self, + *args, + **kwargs, + ): + """Compute advantages across workers. 
+ + Parameters + ---------- + *args + Positional arguments passed to the engine + **kwargs + Keyword arguments passed to the engine + + Returns + ------- + Any + Advantages computed by the engine + """ + return self._custom_function_call("compute_advantages", *args, **kwargs) + + def ppo_update( + self, + input_: DistributedBatch, + ) -> dict[str, float]: + """Perform PPO update step with the given batch. + + Parameters + ---------- + input_ : DistributedBatch + The distributed input data containing trajectories for PPO update + + Returns + ------- + Dict[str, float] + Scalar statistics after PPO update + """ + return self._custom_function_call("ppo_update", input_) diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index 59c5da998..bc51f30e7 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -67,7 +67,6 @@ def __init__(self, config: PPOActorConfig, engine: TrainEngine): logger.info(f" eps_clip: {config.eps_clip}") logger.info(f" group_size: {config.group_size}") - @trace_perf("ppo_actor.compute_logp", category="compute") @torch.no_grad() def compute_logp( self, @@ -199,6 +198,7 @@ def ppo_update(self, data: dict[str, Any]) -> None: reward_score = data["rewards"] seqlens = attn_mask.sum(-1) + all_stats = [] ########## Logging code starts ########## result_denominators = { "correct_n_seqs": (reward_score > 0).bool(), @@ -236,6 +236,7 @@ def ppo_update(self, data: dict[str, Any]) -> None: ) stats_tracker.stat(**stats, denominator="n_valid_tokens") + prompt_lens = [] prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1) seq_stats = dict( no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(), @@ -262,6 +263,15 @@ def ppo_update(self, data: dict[str, Any]) -> None: **{k: data[k].float() for k in self.config.log_agent_stats_keys}, denominator="agent", ) + + global_stats = stats_tracker.export( + reduce_group=self.engine.data_parallel_group + ) + for k in global_denominators: + keys = list(global_stats.keys()) + for k2 in keys: + if k2.endswith(k): + global_stats.pop(k2) ########## Logging code ends ########## for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]: @@ -272,24 +282,27 @@ def ppo_update(self, data: dict[str, Any]) -> None: data, mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches), ) - - with stats_tracker.scope("update"): - for mb in mb_inputs.mbs: - train_stat = self.engine.train_batch( - mb, - loss_fn=functools.partial( - grpo_loss_fn, - temperature=self.temperature, - eps_clip=self.config.eps_clip, - eps_clip_higher=self.config.eps_clip_higher, - c_clip=self.config.c_clip, - behav_imp_weight_cap=self.config.behav_imp_weight_cap, - m2_threshold=self.m2_threshold, - importance_sampling_level=self.config.importance_sampling_level, - ), - loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), - ) - stats_tracker.scalar(**train_stat) + for mb in mb_inputs.mbs: + train_stat = self.engine.train_batch( + mb, + loss_fn=functools.partial( + grpo_loss_fn, + temperature=self.temperature, + eps_clip=self.config.eps_clip, + eps_clip_higher=self.config.eps_clip_higher, + c_clip=self.config.c_clip, + behav_imp_weight_cap=self.config.behav_imp_weight_cap, + m2_threshold=self.m2_threshold, + importance_sampling_level=self.config.importance_sampling_level, + ), + loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), + ) + stats_tracker.scalar(**train_stat) + all_stats.append( + stats_tracker.export(reduce_group=self.engine.data_parallel_group) + ) + all_stats[0].update(global_stats) + return all_stats class 
FSDPPPOActor(FSDPEngine): @@ -305,8 +318,8 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor: def compute_advantages(self, *args, **kwargs) -> dict[str, Any]: return self.actor.compute_advantages(*args, **kwargs) - def ppo_update(self, *args, **kwargs) -> None: - self.actor.ppo_update(*args, **kwargs) + def ppo_update(self, *args, **kwargs) -> list[dict[str, float]]: + return self.actor.ppo_update(*args, **kwargs) class MegatronPPOActor(MegatronEngine): @@ -322,8 +335,8 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor: def compute_advantages(self, *args, **kwargs) -> dict[str, Any]: return self.actor.compute_advantages(*args, **kwargs) - def ppo_update(self, *args, **kwargs) -> None: - self.actor.ppo_update(*args, **kwargs) + def ppo_update(self, *args, **kwargs) -> list[dict[str, float]]: + return self.actor.ppo_update(*args, **kwargs) def grpo_loss_fn( diff --git a/areal/experimental/openai/tool_call_parser.py b/areal/experimental/openai/tool_call_parser.py index 75f011a2c..70fba9640 100644 --- a/areal/experimental/openai/tool_call_parser.py +++ b/areal/experimental/openai/tool_call_parser.py @@ -2,11 +2,9 @@ import uuid from typing import Any -from openai.types.chat.chat_completion_message_function_tool_call import ( - ChatCompletionMessageFunctionToolCall, - Function, -) -from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.chat import ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_message_tool_call import Function +from openai.types.responses import ResponseFunctionToolCall from areal.utils import logging @@ -21,7 +19,7 @@ def process_tool_calls( finish_reason: str, use_responses: bool = False, ) -> tuple[ - list[ChatCompletionMessageFunctionToolCall | ResponseFunctionToolCall] | None, + list[ChatCompletionMessageToolCall | ResponseFunctionToolCall] | None, str, str, ]: @@ -69,7 +67,7 @@ def process_tool_calls( ] else: tool_calls = [ - ChatCompletionMessageFunctionToolCall( + ChatCompletionMessageToolCall( type="function", id=f"call_{uuid.uuid4().hex[:24]}", function=Function( diff --git a/areal/scheduler/__init__.py b/areal/scheduler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/areal/scheduler/exceptions.py b/areal/scheduler/exceptions.py new file mode 100644 index 000000000..29f746eff --- /dev/null +++ b/areal/scheduler/exceptions.py @@ -0,0 +1,128 @@ +"""Custom exceptions for the scheduler module.""" + + +class SchedulerError(Exception): + """Base exception for all scheduler-related errors.""" + + pass + + +class WorkerCreationError(SchedulerError): + """Raised when worker creation fails during subprocess spawn or initialization.""" + + def __init__(self, worker_key: str, reason: str, details: str = ""): + self.worker_key = worker_key + self.reason = reason + self.details = details + message = f"Failed to create worker '{worker_key}': {reason}" + if details: + message += f"\nDetails: {details}" + super().__init__(message) + + +class WorkerConfigurationError(SchedulerError): + def __init__(self, worker_key: str, reason: str, details: str = ""): + self.worker_key = worker_key + self.reason = reason + self.details = details + message = f"Failed to configure worker '{worker_key}': {reason}" + if details: + message += f"\nDetails: {details}" + super().__init__(message) + + +class WorkerFailedError(SchedulerError): + """Raised when a worker process fails or exits unexpectedly.""" + + def __init__(self, worker_id: str, exit_code: int, stderr: str = ""): + 
self.worker_id = worker_id + self.exit_code = exit_code + self.stderr = stderr + message = f"Worker '{worker_id}' failed with exit code {exit_code}" + if stderr: + message += f"\nStderr output:\n{stderr}" + super().__init__(message) + + +class WorkerNotFoundError(SchedulerError): + """Raised when attempting to access a worker that doesn't exist.""" + + def __init__(self, worker_id: str): + self.worker_id = worker_id + super().__init__(f"Worker '{worker_id}' not found") + + +class EngineCreationError(SchedulerError): + """Raised when engine creation fails on a worker.""" + + def __init__(self, worker_id: str, reason: str, status_code: int = None): + self.worker_id = worker_id + self.reason = reason + self.status_code = status_code + message = f"Failed to create engine on worker '{worker_id}': {reason}" + if status_code: + message += f" (HTTP {status_code})" + super().__init__(message) + + +class EngineCallError(SchedulerError): + """Raised when calling an engine method fails.""" + + def __init__(self, worker_id: str, method: str, reason: str, attempt: int = 1): + self.worker_id = worker_id + self.method = method + self.reason = reason + self.attempt = attempt + message = f"Failed to call method '{method}' on worker '{worker_id}': {reason}" + if attempt > 1: + message += f" (after {attempt} attempts)" + super().__init__(message) + + +class WorkerTimeoutError(SchedulerError): + """Raised when waiting for a worker exceeds the timeout.""" + + def __init__(self, worker_key: str, timeout: float): + self.worker_key = worker_key + self.timeout = timeout + super().__init__( + f"Timeout waiting for worker '{worker_key}' (waited {timeout}s)" + ) + + +class PortAllocationError(SchedulerError): + """Raised when port allocation fails.""" + + def __init__(self, reason: str): + self.reason = reason + super().__init__(f"Failed to allocate ports: {reason}") + + +class GPUAllocationError(SchedulerError): + """Raised when GPU allocation fails.""" + + def __init__(self, reason: str): + self.reason = reason + super().__init__(f"Failed to allocate GPU resources: {reason}") + + +class RPCConnectionError(SchedulerError): + """Raised when RPC connection to a worker fails.""" + + def __init__(self, worker_id: str, host: str, port: int, reason: str): + self.worker_id = worker_id + self.host = host + self.port = port + self.reason = reason + super().__init__( + f"Failed to connect to worker '{worker_id}' at {host}:{port}: {reason}" + ) + + +class EngineImportError(SchedulerError): + """Raised when importing an engine class fails on the worker.""" + + def __init__(self, import_path: str, reason: str): + self.import_path = import_path + self.reason = reason + super().__init__(f"Failed to import engine '{import_path}': {reason}") diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py new file mode 100644 index 000000000..4cdcb5b55 --- /dev/null +++ b/areal/scheduler/local.py @@ -0,0 +1,1162 @@ +"""Local scheduler for managing worker subprocesses on a single GPU node.""" + +import getpass +import os +import shlex +import subprocess +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import httpx +import orjson +import psutil + +from areal.api.cli_args import BaseExperimentConfig +from areal.api.scheduler_api import Job, Scheduler, SchedulingSpec, Worker +from areal.platforms import current_platform +from areal.scheduler.exceptions import ( + EngineCallError, + EngineCreationError, + EngineImportError, + GPUAllocationError, + PortAllocationError, 
+ RPCConnectionError, + SchedulerError, + WorkerConfigurationError, + WorkerCreationError, + WorkerFailedError, + WorkerNotFoundError, + WorkerTimeoutError, +) +from areal.scheduler.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging +from areal.utils.launcher import ( + get_env_vars, +) +from areal.utils.network import find_free_ports, gethostip + +logger = logging.getLogger("LocalScheduler") + + +@dataclass +class WorkerInfo: + """Internal tracking information for a worker process.""" + + worker: Worker # Public Worker object with id, ip, ports + process: subprocess.Popen # The subprocess handle + role: str # Worker role (e.g., "rollout", "actor", "critic") + gpu_devices: list[int] # Allocated GPU device IDs + created_at: float # Timestamp when worker was created + log_file: str # Path to stderr log file + env_vars: dict[str, str] = field(default_factory=dict) # Environment variables + + +class LocalScheduler(Scheduler): + """ + Local scheduler that manages worker subprocesses on a single GPU node. + + This scheduler spawns worker processes running RPC servers and manages their lifecycle. + It supports different worker types (rollout, actor, critic) through a unified interface. + + Features: + - Dynamic port allocation + - Round-robin GPU assignment + - Process health monitoring + - Comprehensive error handling + - Graceful cleanup + """ + + def __init__( + self, + gpu_devices: list[int] | None = None, + exp_config: BaseExperimentConfig | None = None, + fileroot: str | None = None, + experiment_name: str | None = None, + trial_name: str | None = None, + cluster_name: str | None = None, + log_dir: str | None = None, + startup_timeout: float = 30.0, + health_check_interval: float = 1.0, + ): + """ + Initialize the local scheduler. + + Args: + gpu_devices: List of GPU device IDs to use. If None, uses CUDA_VISIBLE_DEVICES or all GPUs. 
+ log_dir: Directory for worker log files + startup_timeout: Maximum time to wait for worker startup (seconds) + health_check_interval: Interval for health checks (seconds) + """ + self.gpu_devices = gpu_devices or self._detect_gpus() + if log_dir is not None: + self.log_dir = Path(log_dir) + else: + experiment_name = experiment_name or exp_config.experiment_name + trial_name = trial_name or exp_config.trial_name + fileroot = fileroot or exp_config.cluster.fileroot + assert experiment_name is not None + assert trial_name is not None + assert fileroot is not None + self.log_dir = ( + Path(fileroot) + / "logs" + / getpass.getuser() + / experiment_name + / trial_name + ) + self.cluster_name = cluster_name or exp_config.cluster.cluster_name + self.exp_config = exp_config + + self.startup_timeout = startup_timeout + self.health_check_interval = health_check_interval + + # Create log directory + self.log_dir.mkdir(parents=True, exist_ok=True) + + # Track workers by worker_key + self._workers: dict[str, list[WorkerInfo]] = {} + + # GPU allocation counter for round-robin + self._gpu_counter = 0 + + # Track all allocated ports + self._allocated_ports = set() + + # HTTP clients for RPC communication + # FIXME: httpx may encounter "all connection attempts failed error" + self._http_client = httpx.Client(timeout=3600.0) # Sync client - 1 hour timeout + self._async_http_client = httpx.AsyncClient(timeout=3600.0) # Async client + + logger.info( + f"LocalScheduler initialized with GPU devices: {self.gpu_devices}, " + f"log directory: {self.log_dir}" + ) + + def _detect_gpus(self) -> list[int]: + """Detect available GPU devices.""" + cuda_visible = os.environ.get(current_platform.device_control_env_var) + if cuda_visible: + try: + return [int(x) for x in cuda_visible.split(",")] + except ValueError: + logger.warning( + f"Invalid {current_platform.device_control_env_var}: {cuda_visible}, using default [0]" + ) + return [0] + # Default to single GPU + return [0] + + def _allocate_gpus(self, num_gpus: int) -> list[int]: + """ + Allocate GPUs using round-robin strategy. + + Args: + num_gpus: Number of GPUs to allocate + + Returns: + List of GPU device IDs + + Raises: + GPUAllocationError: If not enough GPUs available + """ + if num_gpus > len(self.gpu_devices): + raise GPUAllocationError( + f"Requested {num_gpus} GPUs but only {len(self.gpu_devices)} available" + ) + + allocated = [] + for _ in range(num_gpus): + gpu_id = self.gpu_devices[self._gpu_counter % len(self.gpu_devices)] + allocated.append(gpu_id) + self._gpu_counter += 1 + + return allocated + + def _get_colocated_gpus(self, target_role: str, worker_idx: int) -> list[int]: + """ + Get GPU allocation from another role for colocation. + + Args: + target_role: The role to colocate with + worker_idx: Index of the worker to get GPUs from + + Returns: + List of GPU device IDs used by the target worker + + Raises: + WorkerNotFoundError: If target role doesn't exist + ValueError: If worker index is out of range + """ + if target_role not in self._workers: + raise WorkerNotFoundError( + f"Cannot colocate with role '{target_role}' - role not found" + ) + + target_workers = self._workers[target_role] + if worker_idx >= len(target_workers): + raise ValueError( + f"Cannot colocate with {target_role}/{worker_idx} - only {len(target_workers)} workers exist" + ) + + return target_workers[worker_idx].gpu_devices + + def _allocate_ports(self, count: int) -> list[int]: + """ + Allocate free ports. 
+ + Args: + count: Number of ports to allocate + + Returns: + List of allocated port numbers + + Raises: + PortAllocationError: If port allocation fails + """ + try: + # Pass a copy of allocated_ports to avoid reference issues + ports = find_free_ports(count, exclude_ports=set(self._allocated_ports)) + self._allocated_ports.update(ports) + return ports + except ValueError as e: + raise PortAllocationError(str(e)) from e + + def _prepare_worker_specs( + self, role: str, num_workers: int, schedulings: list[SchedulingSpec] | None + ) -> list[SchedulingSpec]: + """ + Prepare worker specs for a given number of workers. + + Args: + role: Worker role name + num_workers: Number of workers to create + schedulings: Optional list of scheduling specs + + Returns: + List of SchedulingSpec objects (one per worker) + + Raises: + WorkerCreationError: If schedulings configuration is invalid + """ + if not schedulings: + # Default spec: 1 CPU, 1024 MB mem, 1 GPU, 2 ports + return [SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2)] * num_workers + + # If a single spec is provided, use it for all workers + if len(schedulings) == 1: + return [schedulings[0]] * num_workers + + # If per-worker specs, validate length matches + if len(schedulings) == num_workers: + return schedulings + + # Invalid configuration + raise WorkerCreationError( + role, + "Invalid configuration", + f"schedulings length ({len(schedulings)}) must be 1 or equal to replicas ({num_workers})", + ) + + def create_workers(self, job: Job, *args, **kwargs) -> list[str]: + """ + Create worker subprocesses. + + Args: + job: Job configuration with role, replicas, tasks, and scheduling strategy + *args: Additional arguments passed to worker command + **kwargs: Additional keyword arguments + + Returns: + List of worker IDs created (e.g., ["rollout/0", "rollout/1"]) + + Raises: + WorkerCreationError: If worker creation fails + GPUAllocationError: If GPU allocation fails + PortAllocationError: If port allocation fails + """ + role = job.role + if role in self._workers: + raise WorkerCreationError( + role, + "Worker group already exists", + f"Use delete_workers('{role}') first to remove existing workers", + ) + + # Extract configuration + num_workers = job.replicas + if num_workers == 0: + raise WorkerCreationError( + role, "Invalid configuration", "replicas must be greater than 0" + ) + + # Prepare worker specs + schedulings = self._prepare_worker_specs(role, num_workers, job.tasks) + + # Determine scheduling strategy + strategy = job.scheduling_strategy + if strategy is None: + strategy_type = "separation" + colocate_role = None + else: + strategy_type = strategy.type or "separation" + colocate_role = strategy.target if strategy_type == "colocation" else None + + logger.info( + f"Creating {num_workers} workers for role '{role}' " + f"(strategy: {strategy_type}, colocate_with: {colocate_role})" + ) + + workers = [] + worker_ids = [] + try: + for idx in range(num_workers): + worker_id = f"{role}/{idx}" + scheduling = schedulings[idx] + + # Allocate resources based on strategy + try: + # GPU allocation + if strategy_type == "colocation": + if not colocate_role: + raise WorkerCreationError( + role, + "Invalid strategy", + "Colocation strategy requires target role to be specified", + ) + gpu_devices = self._get_colocated_gpus(colocate_role, idx) + logger.debug( + f"Worker {worker_id} colocated with {colocate_role}/{idx} on GPUs {gpu_devices}" + ) + else: # "separation" or default + gpu_devices = self._allocate_gpus(scheduling.gpu) + logger.debug( + 
f"Worker {worker_id} allocated new GPUs {gpu_devices}" + ) + + ports = self._allocate_ports(scheduling.port_count) + except ( + GPUAllocationError, + PortAllocationError, + WorkerNotFoundError, + ValueError, + ) as e: + # Clean up partially created workers + self._cleanup_workers(workers) + raise WorkerCreationError( + role, f"Resource allocation failed for worker {idx}", str(e) + ) from e + + # Prepare environment + env = get_env_vars( + self.cluster_name, + ",".join([f"{k}={v}" for k, v in scheduling.env_vars.items()]), + ) + env[current_platform.device_control_env_var] = ",".join( + map(str, gpu_devices) + ) + + # Merge user-provided environment variables from scheduling + if scheduling.env_vars: + env.update(scheduling.env_vars) + + # Prepare log file + log_file = self.log_dir / f"{worker_id.replace('/', '_')}.log" + + # Build command to start RPC server + if not scheduling.cmd: + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"SchedulingSpec.cmd is required but not set for worker {worker_id}", + "Specify either 'python -m areal.scheduler.rpc.rpc_server' or " + "'python -m areal.scheduler.rpc.rpc_server' in your config.", + ) + + cmd = shlex.split(scheduling.cmd) + # Append --port argument to command + cmd.extend(["--port", str(ports[0])]) + + logger.info(f"Starting worker {worker_id}: {' '.join(cmd)}") + if cmd[0].startswith("python"): + cmd[0] = sys.executable + + cmd = ( + " ".join(str(k) + "=" + str(v) for k, v in env.items()) + + " stdbuf -oL " + + " ".join(cmd) + ) + cmd = f"{cmd} 2>&1 | tee -a {log_file}" + # Spawn subprocess + try: + process = subprocess.Popen( + cmd, + shell=isinstance(cmd, str), + stdout=sys.stdout, + stderr=sys.stdout, + ) + except Exception as e: + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"Failed to spawn subprocess for worker {idx}", + str(e), + ) from e + + # Check if process started successfully + time.sleep(0.1) # Brief delay to catch immediate failures + if process.poll() is not None: + stderr = self._read_log_tail(log_file) + self._cleanup_workers(workers) + raise WorkerCreationError( + role, + f"Worker {worker_id} exited immediately with code {process.returncode}", + stderr, + ) + + # Create worker info + worker = Worker( + id=worker_id, + ip=gethostip(), + worker_ports=[str(p) for p in ports], + engine_ports=[], + ) + + worker_info = WorkerInfo( + worker=worker, + process=process, + role=role, + gpu_devices=gpu_devices, + created_at=time.time(), + log_file=str(log_file), + env_vars=env, + ) + + workers.append(worker_info) + worker_ids.append(worker_id) + logger.info( + f"Worker {worker_id} started (PID: {process.pid}, " + f"GPUs: {gpu_devices}, ports: {ports})" + ) + + # Store workers + self._workers[role] = workers + + logger.info( + f"Successfully created {len(workers)} workers for role '{role}'" + ) + + except Exception as e: + # Clean up any workers created before the failure + self._cleanup_workers(workers) + if isinstance(e, SchedulerError): + raise + raise WorkerCreationError(role, "Unexpected error", str(e)) from e + + # Send HTTP request to configure workers + for worker_rank, worker_info in enumerate(workers): + while not self._is_worker_ready(worker_info): + time.sleep(0.1) + worker_id = worker_info.worker.id + port = int(worker_info.worker.worker_ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/configure" + + try: + response = self._http_client.post( + url, + content=orjson.dumps( + serialize_value( + dict( + config=self.exp_config, + role=worker_info.role, + 
rank=worker_rank, + ) + ) + ), + headers={"Content-Type": "application/json"}, + timeout=300.0, + ) + + if response.status_code == 200: + logger.info(f"Configuration successfully on worker '{worker_id}'") + continue + elif response.status_code == 400: + # Import error or bad request + error_detail = response.json().get("detail", "Unknown error") + raise WorkerConfigurationError(worker_id, error_detail, str(400)) + elif response.status_code == 500: + # Engine initialization failed + error_detail = response.json().get("detail", "Unknown error") + raise WorkerConfigurationError(worker_id, error_detail, str(500)) + else: + raise WorkerConfigurationError( + worker_id, + f"Unexpected status code: {response.status_code}", + str(response.status_code), + ) + + except httpx.ConnectError as e: + # Check if worker died + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, worker_info.process.returncode, stderr + ) from e + raise RPCConnectionError( + worker_id, worker_info.worker.ip, port, str(e) + ) from e + + except httpx.TimeoutException as e: + raise WorkerConfigurationError( + worker_id, f"Request timed out: {e}" + ) from e + + except WorkerConfigurationError: + raise + + except Exception as e: + raise WorkerConfigurationError( + worker_id, f"Unexpected error: {str(e)}" + ) from e + + return worker_ids + + def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: + """ + Get workers and wait for them to be ready. + + Args: + role: Worker role name + timeout: Maximum time to wait for workers to be ready (None = use default) + + Returns: + List of Worker objects + + Raises: + WorkerNotFoundError: If role doesn't exist + WorkerFailedError: If any worker process failed + WorkerTimeoutError: If timeout exceeded waiting for workers + """ + if role not in self._workers: + raise WorkerNotFoundError(role) + + workers = self._workers[role] + timeout = timeout if timeout is not None else self.startup_timeout + + # First check that all processes are still alive + self._check_worker_health(role) + + # Wait for RPC servers to be ready + start_time = time.time() + ready_workers = set() + + while len(ready_workers) < len(workers): + if time.time() - start_time > timeout: + raise WorkerTimeoutError( + role, + timeout, + ) + + for worker_info in workers: + if worker_info.worker.id in ready_workers: + continue + + # Check if process is still alive + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_info.worker.id, + worker_info.process.returncode, + stderr, + ) + + # Check if RPC server is ready + if self._is_worker_ready(worker_info): + ready_workers.add(worker_info.worker.id) + logger.debug(f"Worker {worker_info.worker.id} is ready") + + if len(ready_workers) < len(workers): + time.sleep(self.health_check_interval) + + logger.info(f"All {len(workers)} workers for role '{role}' are ready") + return [w.worker for w in workers] + + def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: + """Check if worker's RPC server is ready via HTTP health check.""" + port = int(worker_info.worker.worker_ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/health" + + try: + response = self._http_client.get(url, timeout=2.0) + return response.status_code == 200 + except Exception: + return False + + def _check_worker_health(self, role: str): + """ + Check health of all workers in a group. 
+
+        Raises:
+            WorkerFailedError: If any worker has failed
+        """
+        if role not in self._workers:
+            return
+
+        for worker_info in self._workers[role]:
+            returncode = worker_info.process.poll()
+            if returncode is not None:
+                stderr = self._read_log_tail(worker_info.log_file)
+                raise WorkerFailedError(
+                    worker_info.worker.id,
+                    returncode,
+                    stderr,
+                )
+
+    def delete_workers(self, role: str | None = None):
+        """
+        Delete workers and clean up resources.
+
+        Args:
+            role: Specific worker role to delete, or None to delete all
+        """
+        if role is None:
+            # Delete all workers
+            roles = list(self._workers.keys())
+            for r in roles:
+                self.delete_workers(r)
+            return
+
+        if role not in self._workers:
+            logger.warning(f"Worker role '{role}' not found, skipping deletion")
+            return
+
+        workers = self._workers[role]
+        logger.info(f"Deleting {len(workers)} workers for role '{role}'")
+
+        self._cleanup_workers(workers)
+
+        # Remove from tracking
+        del self._workers[role]
+
+        logger.info(f"Successfully deleted workers for role '{role}'")
+
+    def _cleanup_workers(self, workers: list[WorkerInfo]):
+        """Clean up worker processes and resources."""
+        for worker_info in workers:
+            try:
+                # Release ports
+                for port_str in worker_info.worker.worker_ports:
+                    self._allocated_ports.discard(int(port_str))
+
+                # Terminate process tree
+                self._terminate_process_tree(worker_info.process.pid)
+
+                logger.debug(f"Cleaned up worker {worker_info.worker.id}")
+            except Exception as e:
+                logger.error(
+                    f"Error cleaning up worker {worker_info.worker.id}: {e}",
+                    exc_info=True,
+                )
+
+    def _terminate_process_tree(self, pid: int):
+        """Terminate a process and all its children."""
+        try:
+            parent = psutil.Process(pid)
+            children = parent.children(recursive=True)
+
+            # Try graceful termination first
+            for child in children:
+                try:
+                    child.terminate()
+                except psutil.NoSuchProcess:
+                    pass
+
+            try:
+                parent.terminate()
+            except psutil.NoSuchProcess:
+                return
+
+            # Wait for graceful termination
+            _, alive = psutil.wait_procs([parent] + children, timeout=3)
+
+            # Force kill remaining processes
+            for proc in alive:
+                try:
+                    proc.kill()
+                except psutil.NoSuchProcess:
+                    pass
+
+        except psutil.NoSuchProcess:
+            # Process already gone
+            pass
+        except Exception:
+            import traceback
+
+            logger.warning(
+                f"Error terminating process tree {pid}: {traceback.format_exc()}"
+            )
+
+    def _read_log_tail(self, log_file: str, lines: int = 50) -> str:
+        """Read the last N lines from a log file."""
+        try:
+            with open(log_file) as f:
+                all_lines = f.readlines()
+            return "".join(all_lines[-lines:])
+        except Exception as e:
+            return f"[Could not read log file: {e}]"
+
+    async def create_engine(
+        self,
+        worker_id: str,
+        engine: str,
+        *args,
+        **kwargs,
+    ) -> Any:
+        """
+        Create an engine instance on a remote worker.
+
+        The engine parameter is a string import path (e.g., "areal.engine.ppo.actor.FSDPPPOActor")
+        that will be dynamically imported and instantiated on the worker.
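+
+        A minimal usage sketch; the worker id and the single ``config``
+        constructor argument are illustrative assumptions, not part of the
+        scheduler contract::
+
+            # assumes workers for role "actor" were started via create_workers()
+            await scheduler.create_engine(
+                "actor/0",
+                "areal.engine.ppo.actor.FSDPPPOActor",
+                config,
+            )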
+ + Args: + worker_id: Worker ID in format "role/index" + engine: Import path to the engine class (e.g., "areal.engine.ppo.actor.FSDPPPOActor") + *args: Initialization arguments + **kwargs: Initialization keyword arguments + + Returns: + Result from engine initialization + + Raises: + WorkerNotFoundError: If worker doesn't exist + WorkerFailedError: If worker process has failed + EngineCreationError: If engine creation fails + """ + # Verify worker exists and is alive + worker_info = self._verify_worker_alive(worker_id) + + # Validate engine is a string import path + if not isinstance(engine, str): + raise EngineCreationError( + worker_id, + f"Engine must be a string import path, got {type(engine)}", + ) + + # Build JSON payload with serialized args and kwargs + payload = { + "engine": engine, + "init_args": serialize_value(list(args)), + "init_kwargs": serialize_value(kwargs), + } + + # Send HTTP request to create engine + port = int(worker_info.worker.worker_ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/create_engine" + + try: + logger.info(f"Creating engine '{engine}' on worker '{worker_id}'") + + response = await self._async_http_client.post( + url, + content=orjson.dumps(payload), + headers={"Content-Type": "application/json"}, + timeout=300.0, + ) + + if response.status_code == 200: + result = response.json() + logger.info(f"Engine created successfully on worker '{worker_id}'") + return result.get("result") + elif response.status_code == 400: + # Import error or bad request + error_detail = response.json().get("detail", "Unknown error") + if "Failed to import" in error_detail: + raise EngineImportError(engine, error_detail) + else: + raise EngineCreationError(worker_id, error_detail, 400) + elif response.status_code == 500: + # Engine initialization failed + error_detail = response.json().get("detail", "Unknown error") + raise EngineCreationError(worker_id, error_detail, 500) + else: + raise EngineCreationError( + worker_id, + f"Unexpected status code: {response.status_code}", + response.status_code, + ) + + except httpx.ConnectError as e: + # Check if worker died + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, worker_info.process.returncode, stderr + ) from e + raise RPCConnectionError( + worker_id, worker_info.worker.ip, port, str(e) + ) from e + + except httpx.TimeoutException as e: + raise EngineCreationError(worker_id, f"Request timed out: {e}") from e + + except (EngineCreationError, EngineImportError, RPCConnectionError): + raise + + except Exception as e: + raise EngineCreationError(worker_id, f"Unexpected error: {str(e)}") from e + + def call_engine( + self, + worker_id: str, + method: str, + *args, + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> Any: + """ + Call a method on an engine. 
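+
+        A minimal usage sketch; the worker id is illustrative, and the method
+        names mirror the TrainEngine RPC wrappers exposed by the controller::
+
+            version = scheduler.call_engine("actor/0", "get_version")
+            scheduler.call_engine("actor/0", "set_version", version + 1)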
+ + Args: + worker_id: Worker ID in format "role/index" + method: Method name to call + *args: Method arguments + max_retries: Maximum number of retry attempts + retry_delay: Initial delay between retries (exponential backoff) + **kwargs: Method keyword arguments + + Returns: + Result from method call + + Raises: + WorkerNotFoundError: If worker doesn't exist + WorkerFailedError: If worker process has failed + EngineCallError: If method call fails + """ + # Get worker info (initial verification) + worker_info = self._find_worker_by_id(worker_id) + if worker_info is None: + raise WorkerNotFoundError(worker_id) + + # Serialize args and kwargs (convert tensors to SerializedTensor dicts) + serialized_args = serialize_value(list(args)) + serialized_kwargs = serialize_value(kwargs) + + # Build JSON payload + payload = { + "method": method, + "args": serialized_args, + "kwargs": serialized_kwargs, + } + + # Retry logic with exponential backoff + port = int(worker_info.worker.worker_ports[0]) + url = f"http://{worker_info.worker.ip}:{port}/call" + last_error = None + + for attempt in range(1, max_retries + 1): + # Check worker health before each attempt + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, + worker_info.process.returncode, + stderr, + ) + + try: + logger.debug( + f"Calling method '{method}' on worker '{worker_id}' (attempt {attempt})" + ) + + response = self._http_client.post( + url, + content=orjson.dumps(payload), + headers={"Content-Type": "application/json"}, + timeout=7200.0, # 2 hours for long-running operations + ) + + result, should_retry, error_msg = self._handle_call_response( + response, worker_id, method, attempt + ) + if not should_retry: + if attempt > 1: + logger.info( + f"Method '{method}' succeeded on worker '{worker_id}' " + f"after {attempt} attempts" + ) + return result + last_error = error_msg + + except Exception as e: + last_error = self._handle_call_exception(e, worker_info, worker_id) + + # Retry with exponential backoff + if attempt < max_retries: + delay = retry_delay * (2 ** (attempt - 1)) + logger.warning( + f"Method '{method}' failed on worker '{worker_id}' " + f"(attempt {attempt}/{max_retries}): {last_error}. " + f"Retrying in {delay:.1f}s..." + ) + time.sleep(delay) + + # All retries exhausted + raise EngineCallError( + worker_id, + method, + last_error or "Max retries exceeded", + attempt=max_retries, + ) + + async def async_call_engine( + self, + worker_id: str, + method: str, + *args, + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> Any: + """ + Async version of call_engine for calling engine methods asynchronously. 
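+
+        A minimal async sketch; the worker id is illustrative, and
+        "export_stats" is one of the specially routed methods::
+
+            stats = await scheduler.async_call_engine("actor/0", "export_stats")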
+ + Args: + worker_id: Worker ID in format "role/index" + method: Method name to call + *args: Method arguments + max_retries: Maximum number of retry attempts + retry_delay: Initial delay between retries (exponential backoff) + **kwargs: Method keyword arguments + + Returns: + Result from method call + + Raises: + WorkerNotFoundError: If worker doesn't exist + WorkerFailedError: If worker process has failed + EngineCallError: If method call fails + """ + # Get worker info (initial verification) + worker_info = self._find_worker_by_id(worker_id) + if worker_info is None: + raise WorkerNotFoundError(worker_id) + + # Route to different endpoint based on method + port = int(worker_info.worker.worker_ports[0]) + if method == "run_workflow": + # Special routing for workflow execution + url = f"http://{worker_info.worker.ip}:{port}/run_workflow" + # Serialize kwargs for workflow execution + payload = serialize_value(kwargs) + elif method == "export_stats": + url = f"http://{worker_info.worker.ip}:{port}/export_stats" + payload = None + else: + # Standard engine method call + url = f"http://{worker_info.worker.ip}:{port}/call" + # Serialize args and kwargs + serialized_args = serialize_value(list(args)) + serialized_kwargs = serialize_value(kwargs) + payload = { + "method": method, + "args": serialized_args, + "kwargs": serialized_kwargs, + } + + last_error = None + + for attempt in range(1, max_retries + 1): + # Check worker health before each attempt + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, + worker_info.process.returncode, + stderr, + ) + + try: + logger.info( + f"Async calling method '{method}' on worker '{worker_id}' (attempt {attempt})" + ) + + response = await self._async_http_client.post( + url, + content=orjson.dumps(payload), + headers={"Content-Type": "application/json"}, + timeout=7200.0, # 2 hours for long-running operations + ) + + result, should_retry, error_msg = self._handle_call_response( + response, worker_id, method, attempt + ) + if not should_retry: + if attempt > 1: + logger.info( + f"Method '{method}' succeeded on worker '{worker_id}' " + f"after {attempt} attempts" + ) + return result + last_error = error_msg + + except Exception as e: + last_error = self._handle_call_exception(e, worker_info, worker_id) + + # Retry with exponential backoff + if attempt < max_retries: + delay = retry_delay * (2 ** (attempt - 1)) + logger.warning( + f"Method '{method}' failed on worker '{worker_id}' " + f"(attempt {attempt}/{max_retries}): {last_error}. " + f"Retrying in {delay:.1f}s..." + ) + import asyncio + + await asyncio.sleep(delay) + + # All retries exhausted + raise EngineCallError( + worker_id, + method, + last_error or "Max retries exceeded", + attempt=max_retries, + ) + + def _find_worker_by_id(self, worker_id: str) -> WorkerInfo | None: + """Find a worker by its ID.""" + for workers in self._workers.values(): + for worker_info in workers: + if worker_info.worker.id == worker_id: + return worker_info + return None + + def _verify_worker_alive(self, worker_id: str) -> WorkerInfo: + """ + Verify a worker exists and is alive. 
+ + Args: + worker_id: Worker ID to verify + + Returns: + WorkerInfo object + + Raises: + WorkerNotFoundError: If worker doesn't exist + WorkerFailedError: If worker process has failed + """ + worker_info = self._find_worker_by_id(worker_id) + if worker_info is None: + raise WorkerNotFoundError(worker_id) + + # Check if process has exited + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, + worker_info.process.returncode, + stderr, + ) + + return worker_info + + def _handle_call_response( + self, response, worker_id: str, method: str, attempt: int + ): + """ + Handle HTTP response from engine call. + + Args: + response: HTTP response object + worker_id: Worker ID + method: Method name being called + attempt: Current retry attempt number + + Returns: + Tuple of (result, should_retry, error_message) + - result: The result from the call if successful, None otherwise + - should_retry: Whether to retry the request + - error_message: Error message if failed, None if successful + """ + if response.status_code == 200: + result = response.json().get("result") + # Deserialize result (convert SerializedTensor dicts back to tensors) + deserialized_result = deserialize_value(result) + return deserialized_result, False, None + elif response.status_code == 400: + # Bad request (e.g., method doesn't exist) - don't retry + error_detail = response.json().get("detail", "Unknown error") + raise EngineCallError(worker_id, method, error_detail, attempt) + elif response.status_code == 500: + # Engine method failed - don't retry + error_detail = response.json().get("detail", "Unknown error") + raise EngineCallError(worker_id, method, error_detail, attempt) + elif response.status_code == 503: + # Service unavailable - retry + return None, True, "Service unavailable" + else: + # Other errors - retry + return None, True, f"HTTP {response.status_code}: {response.text}" + + def _handle_call_exception( + self, e: Exception, worker_info: WorkerInfo, worker_id: str + ) -> str: + """ + Handle exceptions during engine calls and return error message. 
+ + Args: + e: The exception that occurred + worker_info: Worker information + worker_id: Worker ID + + Returns: + Error message string + + Raises: + WorkerFailedError: If worker has died + EngineCallError: If non-retryable error + """ + if isinstance(e, httpx.ConnectError): + # Check if worker died + if worker_info.process.poll() is not None: + stderr = self._read_log_tail(worker_info.log_file) + raise WorkerFailedError( + worker_id, + worker_info.process.returncode, + stderr, + ) from e + return f"Connection error: {e}" + elif isinstance(e, httpx.TimeoutException): + return f"Timeout: {e}" + elif isinstance(e, EngineCallError): + raise + else: + return f"Unexpected error: {e}" + + def __del__(self): + """Cleanup on deletion.""" + try: + self.delete_workers() + except Exception: + pass + try: + self._http_client.close() + except Exception: + pass + try: + import asyncio + + # Close async client if event loop is available + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(self._async_http_client.aclose()) + else: + loop.run_until_complete(self._async_http_client.aclose()) + except RuntimeError: + # No event loop, sync close + pass + except Exception: + pass diff --git a/areal/scheduler/rpc/rpc_client.py b/areal/scheduler/rpc/rpc_client.py deleted file mode 100644 index 28f4b8082..000000000 --- a/areal/scheduler/rpc/rpc_client.py +++ /dev/null @@ -1,137 +0,0 @@ -import gzip -import time -from http import HTTPStatus -from typing import Any, Union - -import cloudpickle -import requests - -from areal.api.cli_args import InferenceEngineConfig, TrainEngineConfig -from areal.api.engine_api import InferenceEngine, TrainEngine -from areal.utils import logging -from areal.utils.http import response_ok, response_retryable - -logger = logging.getLogger("RPCClient") - - -class RPCClient: - def __init__(self): - self._addrs = {} - - def register(self, worker_id: str, ip: str, port: int) -> None: - self._addrs[worker_id] = (ip, port) - logger.info(f"Registered worker {worker_id} at {ip}:{port}") - - def create_engine( - self, - worker_id: str, - engine_obj: Union[InferenceEngine, TrainEngine], - init_config: Union[InferenceEngineConfig, TrainEngineConfig], - ) -> None: - ip, port = self._addrs[worker_id] - url = f"http://{ip}:{port}/create_engine" - logger.info(f"send create_engine to {worker_id} ({ip}:{port})") - payload = (engine_obj, init_config) - serialized_data = cloudpickle.dumps(payload) - serialized_obj = gzip.compress(serialized_data) - resp = requests.post(url, data=serialized_obj) - logger.info( - f"send create_engine to {worker_id} ({ip}:{port}), status={resp.status_code}" - ) - if resp.status_code == HTTPStatus.OK: - logger.info(f"create engine success.") - return cloudpickle.loads(resp.content) - else: - logger.error(f"Failed to create engine, {resp.status_code}, {resp.content}") - raise RuntimeError( - f"Failed to create engine, {resp.status_code}, {resp.content}" - ) - - def call_engine( - self, worker_id: str, method: str, max_retries: int = 3, *args, **kwargs - ) -> Any: - """ - call the rpc server with method name and args, retry on failure - - Parameters - ---------- - worker_id: str - the id of the worker to call - method: str - the method name to call - max_retries: int - max retries on failure - *args: - args to pass to the method - **kwargs: - kwargs to pass to the method - - Returns - ------- - the deserialized result from the rpc server - """ - req = (method, args, kwargs) - serialized_data = cloudpickle.dumps(req) - - return 
self._call_engine_with_serialized_data( - worker_id, serialized_data, max_retries - ) - - def _call_engine_with_serialized_data( - self, worker_id: str, serialized_data: bytes, max_retries=3 - ) -> Any: - """ - call the rpc server with serialized data, retry on failure - - Parameters - ---------- - worker_id: str - the id of the worker to call - serialized_data: bytes - the serialized data to send - max_retries: int - max retries on failure - - Returns - ------- - the deserialized result from the rpc server - """ - ip, port = self._addrs[worker_id] - url = f"http://{ip}:{port}/call" - last_exception = None - - for attempt in range(max_retries): - try: - resp = requests.post(url, data=serialized_data, timeout=7200) - logger.info( - f"Sent call to {worker_id} ({ip}:{port}), status={resp.status_code}, attempt {attempt + 1}/{max_retries}" - ) - - if response_ok(resp.status_code): - return cloudpickle.loads(resp.content) - elif response_retryable(resp.status_code): - last_exception = RuntimeError( - f"Retryable HTTP status {resp.status_code}: {resp.content}" - ) - else: - raise RuntimeError( - f"Non-retryable HTTP error: {resp.status_code} - {resp.content}" - ) - - except (RuntimeError, TimeoutError) as e: - logger.error(f"stop retrying, error on attempt {attempt + 1}: {e}") - raise e - except Exception as e: - last_exception = e - logger.error(f"error on attempt {attempt + 1}: {e}") - - if last_exception is not None: - if attempt < max_retries - 1: - logger.warning( - f"Retrying in 1 second... ({attempt + 1}/{max_retries})" - ) - time.sleep(1) - continue - else: - logger.error(f"Max retries exceeded for {url}") - raise last_exception diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index dd05f23e9..ad54b4f99 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -1,142 +1,366 @@ import argparse -import gzip -import os +import importlib import traceback -from http import HTTPStatus -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from typing import AnyStr +from concurrent.futures import Future -import cloudpickle +from flask import Flask, jsonify, request -from areal.api.controller_api import DistributedBatch -from areal.controller.batch import DistributedBatchMemory -from areal.utils import logging +from areal.api.cli_args import BaseExperimentConfig +from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.platforms import current_platform +from areal.scheduler.rpc.serialization import deserialize_value, serialize_value +from areal.utils import logging, name_resolve, seeding, stats_tracker +from areal.utils.data import ( + broadcast_tensor_container, + tensor_container_to, +) -logger = logging.getLogger("RPCServer") +logger = logging.getLogger("SyncRPCServer") +# Global engine instance - must be TrainEngine or InferenceEngine +_engine: TrainEngine | InferenceEngine | None = None -def process_input_to_distributed_batch(*args, **kwargs): - for i in range(len(args)): - if isinstance(args[i], DistributedBatch): - args = list(args) - args[i] = args[i].get_data() - args = tuple(args) +# Create Flask app +app = Flask(__name__) - for k in list(kwargs.keys()): - if isinstance(kwargs[k], DistributedBatch): - kwargs[k] = kwargs[k].get_data() - return args, kwargs +@app.route("/health", methods=["GET"]) +def health_check(): + """Health check endpoint to verify server is alive.""" + global _engine + return jsonify({"status": "healthy", "engine_initialized": _engine is not None}) -def 
process_output_to_distributed_batch(result): - if isinstance(result, dict): - return DistributedBatchMemory.from_dict(result) - elif isinstance(result, (list, tuple)): - return DistributedBatchMemory.from_list(list(result)) - else: - return result +@app.route("/configure", methods=["POST"]) +def configure(): + """Configure worker with experiment config.""" + try: + data = request.get_json() + if data is None: + return jsonify({"detail": "Invalid JSON in request body"}), 400 + config = data.get("config") + if config is None: + return jsonify({"detail": "Missing 'config' field in request"}), 400 -class EngineRPCServer(BaseHTTPRequestHandler): - engine = None + role = data.get("role") + if role is None: + return jsonify({"detail": "Missing 'role' field in request"}), 400 - def _read_body(self, timeout=120.0) -> AnyStr: - old_timeout = None + rank = data.get("rank") + if rank is None: + return jsonify({"detail": "Missing 'rank' field in request"}), 400 + + config = deserialize_value(config) + config: BaseExperimentConfig + + name_resolve.reconfigure(config.cluster.name_resolve) + seeding.set_random_seed(config.seed, key=f"{role}{rank}") + + return jsonify( + { + "status": "success", + "message": "Worker configured successful.", + "result": None, + } + ) + except Exception as e: + logger.error(f"Unexpected error in configure: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +@app.route("/create_engine", methods=["POST"]) +def create_engine(): + """ + Create and initialize a TrainEngine or InferenceEngine instance on this worker. + + Expected JSON payload: + { + "engine": "areal.engine.ppo.actor.FSDPPPOActor", # Import path + "init_args": [...], # Positional arguments + "init_kwargs": {...} # Keyword arguments + } + """ + global _engine + + try: + data = request.get_json() + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + engine_path = data.get("engine") + # Deserialize init_args and init_kwargs (may contain tensors or dataclasses) + init_args = deserialize_value(data.get("init_args", [])) + init_kwargs = deserialize_value(data.get("init_kwargs", {})) + + if not engine_path: + return jsonify({"error": "Missing 'engine' field in request"}), 400 + + # Dynamic import try: - length = int(self.headers["Content-Length"]) - old_timeout = self.request.gettimeout() - logger.info(f"Receive rpc call, path: {self.path}, timeout: {old_timeout}") - # set max read timeout = 120s here, if read hang raise exception - self.request.settimeout(timeout) - return self.rfile.read(length) - except Exception as e: - raise e - finally: - self.request.settimeout(old_timeout) + module_path, class_name = engine_path.rsplit(".", 1) + module = importlib.import_module(module_path) + engine_class = getattr(module, class_name) - def do_POST(self): - data = None + # Validate that the class is a TrainEngine or InferenceEngine + if not issubclass(engine_class, TrainEngine) and not issubclass( + engine_class, InferenceEngine + ): + raise TypeError( + f"Engine class must be a subclass of TrainEngine or InferenceEngine, " + f"got {engine_class}.." 
+ ) + except (ValueError, ImportError, AttributeError) as e: + logger.error(f"Failed to import engine '{engine_path}': {e}") + return ( + jsonify( + {"error": f"Failed to import engine '{engine_path}': {str(e)}"} + ), + 400, + ) + except TypeError as e: + logger.error(f"Invalid engine type: {e}") + return jsonify({"error": str(e)}), 400 + + # Instantiate engine try: - data = self._read_body() + _engine = engine_class(*init_args, **init_kwargs) + logger.info(f"Engine '{engine_path}' instantiated successfully") + return jsonify( + { + "status": "success", + "message": f"Engine '{engine_path}' created and initialized", + "result": None, + } + ) except Exception as e: - self.send_response( - HTTPStatus.REQUEST_TIMEOUT - ) # 408 means read request timeout - self.end_headers() - self.wfile.write(f"Exception: {e}\n{traceback.format_exc()}".encode()) - logger.error(f"Exception in do_POST: {e}\n{traceback.format_exc()}") - return + logger.error(f"Failed to instantiate engine: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Failed to instantiate engine: {str(e)}"}), 500 + + except Exception as e: + logger.error( + f"Unexpected error in create_engine: {e}\n{traceback.format_exc()}" + ) + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + + +@app.route("/call", methods=["POST"]) +def call_engine_method(): + """ + Call a method on the engine instance. + + Expected JSON payload: + { + "method": "train_batch", + "args": [...], + "kwargs": {...} + } + """ + global _engine + + if _engine is None: + return ( + jsonify({"error": "Engine not initialized. Call /create_engine first."}), + 503, + ) + + try: + data = request.get_json() + if data is None: + return jsonify({"error": "Invalid JSON in request body"}), 400 + + method_name = data.get("method") + args = data.get("args", []) + kwargs = data.get("kwargs", {}) + + if not method_name: + return jsonify({"error": "Missing 'method' field in request"}), 400 + + # Deserialize args and kwargs (convert SerializedTensor dicts to tensors) + args = deserialize_value(args) + kwargs = deserialize_value(kwargs) try: - if self.path == "/create_engine": - decompressed_data = gzip.decompress(data) - engine_obj, init_args = cloudpickle.loads(decompressed_data) - EngineRPCServer.engine = engine_obj - result = EngineRPCServer.engine.initialize(init_args) - logger.info(f"Engine created and initialized on RPC server: {result}") - self.send_response(HTTPStatus.OK) - self.end_headers() - self.wfile.write(cloudpickle.dumps(result)) - elif self.path == "/call": - if EngineRPCServer.engine is None: - self.send_response(HTTPStatus.INTERNAL_SERVER_ERROR) - self.end_headers() - self.wfile.write(b"Engine is none") - logger.error("Call received but engine is none.") - return - action, args, kwargs = cloudpickle.loads(data) - method = getattr(EngineRPCServer.engine, action) - # NOTE: DO NOT print args here, args may be a very huge tensor - logger.info(f"RPC server calling engine method: {action}") - args, kwargs = process_input_to_distributed_batch(*args, **kwargs) - result = method(*args, **kwargs) - result = process_output_to_distributed_batch(result) - self.send_response(HTTPStatus.OK) - self.end_headers() - self.wfile.write(cloudpickle.dumps(result)) - else: - self.send_response(HTTPStatus.NOT_FOUND) - self.end_headers() + should_bcast = kwargs.pop("_should_bcast", True) + if should_bcast and isinstance(_engine, TrainEngine): + logger.info(f"Broadcasting data for TrainEngine method: {method_name}") + + args = tensor_container_to(args, 
current_platform.current_device()) + args = broadcast_tensor_container( + args, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + kwargs = tensor_container_to(kwargs, current_platform.current_device()) + kwargs = broadcast_tensor_container( + kwargs, + src_rank=_engine.current_data_parallel_head(), + group=_engine.context_and_model_parallel_group, + ) + logger.info("Broadcasting data done.") except Exception as e: - self.send_response(HTTPStatus.INTERNAL_SERVER_ERROR) - self.end_headers() - self.wfile.write(f"Exception: {e}\n{traceback.format_exc()}".encode()) - logger.error(f"Exception in do_POST: {e}\n{traceback.format_exc()}") - - -def start_rpc_server(port): - server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer) - server.serve_forever() - - -def get_serve_port(args): - port = args.port - port_str = os.environ.get("PORT_LIST", "").strip() - - # Check if PORT_LIST is set - if port_str: - # Split by comma and strip whitespace - ports = [p.strip() for p in port_str.split(",")] - # Use the first valid port from the list - if ports and ports[0]: - try: - return int(ports[0]) - except ValueError: - logger.warning( - f"Invalid port '{ports[0]}' in PORT_LIST. Falling back to --port argument." + logger.error( + f"Broadcasting data for method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + return ( + jsonify({"error": f"Data broadcast '{method_name}' failed: {str(e)}"}), + 500, + ) + + # Special case for `submit` on inference engines + try: + if method_name == "submit" and isinstance(_engine, InferenceEngine): + workflow_path = kwargs["workflow_path"] + workflow_kwargs = kwargs["workflow_kwargs"] + episode_data = kwargs["data"] + should_accept_path = kwargs["should_accept_path"] + + # Deserialize episode_data (may contain tensors) + episode_data = deserialize_value(episode_data) + + # Dynamic import workflow + module_path, class_name = workflow_path.rsplit(".", 1) + module = importlib.import_module(module_path) + workflow_class = getattr(module, class_name) + logger.info(f"Imported workflow class: {workflow_path}") + + # Instantiate workflow + workflow_kwargs = deserialize_value(workflow_kwargs) + workflow = workflow_class(**workflow_kwargs) + logger.info(f"Workflow '{workflow_path}' instantiated successfully") + + should_accept = None + if should_accept_path is not None: + # Dynamic import filtering function + module_path, fn_name = should_accept_path.rsplit(".", 1) + module = importlib.import_module(module_path) + should_accept = getattr(module, fn_name) + logger.info(f"Imported filtering function: {should_accept_path}") + + args = [] + kwargs = dict( + data=episode_data, + workflow=workflow, + should_accept=should_accept, ) - return port + except Exception as e: + logger.error( + f"Workflow data conversion failed: {e}\n{traceback.format_exc()}" + ) + return ( + jsonify({"error": f"workflow data conversion failed: {str(e)}"}), + 500, + ) + # Call method directly + logger.info(f"Calling engine method: {method_name}") + try: + # Get the method - will raise AttributeError if it doesn't exist + method = getattr(_engine, method_name) + result = method(*args, **kwargs) + + # HACK: handle update weights future + if isinstance(result, Future): + result = result.result() + + # Serialize result (convert tensors to SerializedTensor dicts) + serialized_result = serialize_value(result) + return jsonify({"status": "success", "result": serialized_result}) + + except AttributeError as e: + logger.error(f"Method '{method_name}' not 
found on engine: {e}") + return ( + jsonify({"error": f"Engine does not have method '{method_name}'"}), + 400, + ) + except Exception as e: + logger.error( + f"Engine method '{method_name}' failed: {e}\n{traceback.format_exc()}" + ) + return ( + jsonify({"error": f"Engine method '{method_name}' failed: {str(e)}"}), + 500, + ) + + except Exception as e: + logger.error(f"Unexpected error in call: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, required=False) +@app.route("/export_stats", methods=["POST"]) +def export_stats(): + """Export training statistics from stats_tracker.""" + try: + global _engine + if _engine is None: + return jsonify({"error": "Engine not initialized"}), 503 - args, unknown = parser.parse_known_args() - port = get_serve_port(args) + # TrainEngine: reduce stats across data_parallel_group + assert isinstance(_engine, TrainEngine) + result = stats_tracker.export(reduce_group=_engine.data_parallel_group) + return jsonify({"status": "success", "result": result}) - logger.info(f"About to start RPC server on {port}") + except Exception as e: + logger.error(f"Unexpected error in export_stats: {e}\n{traceback.format_exc()}") + return jsonify({"error": f"Internal server error: {str(e)}"}), 500 - start_rpc_server(port) + +def cleanup_engine(): + """Clean up engine on shutdown.""" + global _engine + if _engine is not None: + try: + _engine.destroy() + logger.info("Engine destroyed successfully") + except Exception as e: + logger.error(f"Error destroying engine: {e}") + _engine = None + + +def main(): + """Main entry point for the sync RPC server.""" + parser = argparse.ArgumentParser( + description="AReaL Sync RPC Server for TrainEngine/InferenceEngine" + ) + parser.add_argument("--port", type=int, required=True, help="Port to serve on") + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" + ) + parser.add_argument( + "--werkzeug-log-level", + type=str, + default="WARNING", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Log level for Werkzeug (Flask's WSGI server). Default: WARNING", + ) + + args, _ = parser.parse_known_args() + + # Configure Werkzeug logging + import logging as stdlib_logging + + werkzeug_logger = stdlib_logging.getLogger("werkzeug") + werkzeug_logger.setLevel(getattr(stdlib_logging, args.werkzeug_log_level)) + + logger.info(f"Starting sync RPC server on {args.host}:{args.port}") + logger.info(f"Werkzeug log level: {args.werkzeug_log_level}") + + # Run Flask app with single-threaded synchronous mode + # threaded=False ensures NCCL compatibility + try: + app.run( + host=args.host, + port=args.port, + threaded=False, # Single-threaded synchronous execution + processes=1, # Single process + debug=False, + use_reloader=False, + ) + except KeyboardInterrupt: + logger.info("Shutting down sync RPC server") + finally: + cleanup_engine() + + +if __name__ == "__main__": + main() diff --git a/areal/scheduler/rpc/serialization.py b/areal/scheduler/rpc/serialization.py new file mode 100644 index 000000000..ba13ba5cd --- /dev/null +++ b/areal/scheduler/rpc/serialization.py @@ -0,0 +1,317 @@ +"""Tensor and dataclass serialization utilities for RPC communication. + +This module provides utilities to serialize and deserialize PyTorch tensors +and dataclass instances for transmission over HTTP/JSON. 
Tensors are encoded +as base64 strings and dataclasses preserve their type information with metadata +stored in Pydantic models. + +Assumptions: +- All tensors are on CPU +- Gradient tracking (requires_grad) is not preserved +- Dataclasses are reconstructed with their original types +""" + +import base64 +import importlib +from dataclasses import is_dataclass +from typing import Any, Literal + +import torch +from pydantic import BaseModel, Field + + +class SerializedTensor(BaseModel): + """Pydantic model for serialized tensor with metadata. + + Attributes + ---------- + type : str + Type marker, always "tensor" + data : str + Base64-encoded tensor data + shape : List[int] + Tensor shape + dtype : str + String representation of dtype (e.g., "torch.float32") + """ + + type: Literal["tensor"] = Field(default="tensor") + data: str + shape: list[int] + dtype: str + + @classmethod + def from_tensor(cls, tensor: torch.Tensor) -> "SerializedTensor": + """Create SerializedTensor from a PyTorch tensor. + + Assumes tensor is on CPU or will be moved to CPU for serialization. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor to serialize + + Returns + ------- + SerializedTensor + Serialized tensor with metadata + """ + # Move to CPU for serialization (detach to avoid gradient tracking) + cpu_tensor = tensor.detach().cpu() + + # Convert to bytes and encode as base64 + buffer = cpu_tensor.numpy().tobytes() + data_b64 = base64.b64encode(buffer).decode("utf-8") + + return cls( + data=data_b64, + shape=list(tensor.shape), + dtype=str(tensor.dtype), + ) + + def to_tensor(self) -> torch.Tensor: + """Reconstruct PyTorch tensor from serialized data. + + Returns CPU tensor without gradient tracking. + + Returns + ------- + torch.Tensor + Reconstructed CPU tensor + """ + # Decode base64 to bytes + buffer = base64.b64decode(self.data.encode("utf-8")) + + # Parse dtype string (e.g., "torch.float32" -> torch.float32) + dtype_str = self.dtype.replace("torch.", "") + dtype = getattr(torch, dtype_str) + + # Reconstruct tensor from bytes + import numpy as np + + np_array = np.frombuffer(buffer, dtype=self._torch_dtype_to_numpy(dtype)) + # Copy the array to make it writable before converting to tensor + np_array = np_array.copy() + tensor = torch.from_numpy(np_array).reshape(self.shape) + + # Cast to correct dtype (numpy might have different dtype) + tensor = tensor.to(dtype) + + return tensor + + @staticmethod + def _torch_dtype_to_numpy(torch_dtype: torch.dtype): + """Convert torch dtype to numpy dtype for buffer reading. + + Parameters + ---------- + torch_dtype : torch.dtype + PyTorch data type + + Returns + ------- + numpy.dtype + Corresponding NumPy data type + """ + import numpy as np + + dtype_map = { + torch.float32: np.float32, + torch.float64: np.float64, + torch.float16: np.float16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.int16: np.int16, + torch.int8: np.int8, + torch.uint8: np.uint8, + torch.bool: np.bool_, + } + return dtype_map.get(torch_dtype, np.float32) + + +class SerializedDataclass(BaseModel): + """Pydantic model for serialized dataclass with metadata. 
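+
+    Illustrative sketch (``ExampleConfig`` is a hypothetical dataclass used
+    only for this example, not part of AReaL)::
+
+        from dataclasses import dataclass
+
+        @dataclass
+        class ExampleConfig:
+            lr: float = 1e-4
+
+        sd = SerializedDataclass.from_dataclass(ExampleConfig())
+        # sd.class_path -> "<defining module>.ExampleConfig"
+        # sd.data       -> {"lr": 0.0001}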
+
+    Attributes
+    ----------
+    type : str
+        Type marker, always "dataclass"
+    class_path : str
+        Full import path to the dataclass (e.g., "areal.api.cli_args.InferenceEngineConfig")
+    data : dict
+        Dataclass fields as dictionary (recursively serialized)
+    """
+
+    type: Literal["dataclass"] = Field(default="dataclass")
+    class_path: str
+    data: dict[str, Any]
+
+    @classmethod
+    def from_dataclass(cls, dataclass_instance: Any) -> "SerializedDataclass":
+        """Create SerializedDataclass from a dataclass instance.
+
+        Parameters
+        ----------
+        dataclass_instance : Any
+            Dataclass instance to serialize
+
+        Returns
+        -------
+        SerializedDataclass
+            Serialized dataclass with metadata
+        """
+        class_path = (
+            f"{dataclass_instance.__class__.__module__}."
+            f"{dataclass_instance.__class__.__name__}"
+        )
+        # Get fields without recursive conversion to preserve nested dataclass instances
+        # We'll handle recursive serialization in serialize_value()
+        from dataclasses import fields
+
+        data = {}
+        for field in fields(dataclass_instance):
+            data[field.name] = getattr(dataclass_instance, field.name)
+
+        return cls(class_path=class_path, data=data)
+
+    def to_dataclass(self) -> Any:
+        """Resolve the serialized dataclass type and return it with the raw field data.
+
+        Returns
+        -------
+        tuple
+            A ``(dataclass_type, data)`` pair; the caller deserializes the field
+            values and instantiates ``dataclass_type(**data)``.
+
+        Raises
+        ------
+        ImportError
+            If the dataclass module cannot be imported
+        AttributeError
+            If the dataclass class is not found in the module
+        """
+        # Dynamically import the dataclass type
+        module_path, class_name = self.class_path.rsplit(".", 1)
+        module = importlib.import_module(module_path)
+        dataclass_type = getattr(module, class_name)
+
+        # Return the dataclass type and data for caller to deserialize fields
+        return dataclass_type, self.data
+
+
+def serialize_value(value: Any) -> Any:
+    """Recursively serialize a value, converting tensors and dataclasses to serialized dicts.
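+
+    Minimal sketch of the expected output shape (illustrative only)::
+
+        import torch
+
+        payload = serialize_value({"step": 3, "ids": torch.zeros(2, dtype=torch.int64)})
+        # payload["step"] -> 3 (primitives pass through unchanged)
+        # payload["ids"]  -> {"type": "tensor", "data": "<base64>", "shape": [2],
+        #                     "dtype": "torch.int64"}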
+ + This function transparently handles: + - torch.Tensor -> SerializedTensor dict (CPU only, no gradient tracking) + - dataclass instances -> SerializedDataclass dict (preserves type information) + - dict -> recursively serialize values + - list/tuple -> recursively serialize elements + - primitives (int, float, str, bool, None) -> unchanged + + Parameters + ---------- + value : Any + Value to serialize (can be nested structure) + + Returns + ------- + Any + Serialized value (JSON-compatible with SerializedTensor and SerializedDataclass dicts) + """ + # Handle None + if value is None: + return None + + # Handle torch.Tensor + if isinstance(value, torch.Tensor): + return SerializedTensor.from_tensor(value).model_dump() + + # Handle dataclass instances (check before dict, as dataclasses can be dict-like) + # Note: is_dataclass returns True for both classes and instances, so check it's not a type + if is_dataclass(value) and not isinstance(value, type): + serialized_dc = SerializedDataclass.from_dataclass(value) + # Recursively serialize the data fields + serialized_data = { + key: serialize_value(val) for key, val in serialized_dc.data.items() + } + return { + "type": "dataclass", + "class_path": serialized_dc.class_path, + "data": serialized_data, + } + + # Handle dict - recursively serialize values + if isinstance(value, dict): + return {key: serialize_value(val) for key, val in value.items()} + + # Handle list - recursively serialize elements + if isinstance(value, list): + return [serialize_value(item) for item in value] + + # Handle tuple - convert to list and recursively serialize + if isinstance(value, tuple): + return [serialize_value(item) for item in value] + + # Primitives (int, float, str, bool) pass through unchanged + return value + + +def deserialize_value(value: Any) -> Any: + """Recursively deserialize a value, converting SerializedTensor and SerializedDataclass dicts back. 
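+
+    Round-trip sketch (illustrative only; tensors come back on CPU without
+    gradient tracking)::
+
+        import torch
+
+        original = {"loss_mask": torch.ones(4)}
+        restored = deserialize_value(serialize_value(original))
+        assert torch.equal(restored["loss_mask"], original["loss_mask"])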
+ + This function transparently handles: + - SerializedTensor dict -> torch.Tensor (CPU, no gradient tracking) + - SerializedDataclass dict -> dataclass instance (reconstructed with original type) + - dict -> recursively deserialize values + - list -> recursively deserialize elements + - primitives -> unchanged + + Parameters + ---------- + value : Any + Value to deserialize (potentially containing SerializedTensor and SerializedDataclass dicts) + + Returns + ------- + Any + Deserialized value with torch.Tensor and dataclass objects restored + """ + # Handle None + if value is None: + return None + + # Handle dict - check if it's a SerializedDataclass or SerializedTensor + if isinstance(value, dict): + # Check for SerializedDataclass marker (check before tensor) + if value.get("type") == "dataclass": + try: + serialized_dc = SerializedDataclass.model_validate(value) + dataclass_type, data = serialized_dc.to_dataclass() + # Recursively deserialize the fields + deserialized_data = { + key: deserialize_value(val) for key, val in data.items() + } + # Reconstruct the dataclass instance + return dataclass_type(**deserialized_data) + except Exception: + # If parsing fails, treat as regular dict + pass + + # Check for SerializedTensor marker + if value.get("type") == "tensor": + try: + serialized_tensor = SerializedTensor.model_validate(value) + return serialized_tensor.to_tensor() + except Exception: + # If parsing fails, treat as regular dict + pass + + # Regular dict - recursively deserialize values + return {key: deserialize_value(val) for key, val in value.items()} + + # Handle list - recursively deserialize elements + if isinstance(value, list): + return [deserialize_value(item) for item in value] + + # Primitives pass through unchanged + return value diff --git a/areal/tests/test_local_scheduler.py b/areal/tests/test_local_scheduler.py new file mode 100644 index 000000000..b5eba66bc --- /dev/null +++ b/areal/tests/test_local_scheduler.py @@ -0,0 +1,1696 @@ +import asyncio +import os +import time +from unittest.mock import AsyncMock, Mock, call, patch + +import httpx +import psutil +import pytest + +from areal.api.scheduler_api import ( + Job, + SchedulingSpec, + SchedulingStrategy, + Worker, +) +from areal.scheduler.exceptions import ( + EngineCallError, + EngineCreationError, + EngineImportError, + GPUAllocationError, + PortAllocationError, + RPCConnectionError, + WorkerCreationError, + WorkerFailedError, + WorkerNotFoundError, + WorkerTimeoutError, +) +from areal.scheduler.local import LocalScheduler, WorkerInfo + +# ============================================================================ +# Fixtures and Helper Functions +# ============================================================================ + + +@pytest.fixture +def scheduler(tmp_path): + """Create a LocalScheduler instance with default configuration.""" + return LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + +@pytest.fixture +def multi_gpu_scheduler(tmp_path): + """Create a LocalScheduler instance with multiple GPUs.""" + return LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) + + +def create_mock_process(pid=1234, is_alive=True, exit_code=None): + """Create a mock subprocess.Popen process. 
+ + Args: + pid: Process ID + is_alive: Whether process is still running + exit_code: Exit code if process has terminated + + Returns: + Mock process object + """ + mock_proc = Mock() + mock_proc.pid = pid + mock_proc.poll.return_value = None if is_alive else exit_code + if not is_alive: + mock_proc.returncode = exit_code + return mock_proc + + +def create_worker_info( + worker_id="test/0", + role="test", + ip="127.0.0.1", + ports=None, + gpu_devices=None, + log_file="/tmp/test.log", + process=None, +): + """Create a WorkerInfo instance with sensible defaults. + + Args: + worker_id: Worker identifier + role: Worker role name + ip: IP address + ports: List of port strings + gpu_devices: List of GPU device IDs + log_file: Path to log file + process: Mock process object (created if not provided) + + Returns: + WorkerInfo instance + """ + if ports is None: + ports = ["8000"] + if gpu_devices is None: + gpu_devices = [0] + if process is None: + process = create_mock_process() + + return WorkerInfo( + worker=Worker(id=worker_id, ip=ip, worker_ports=ports, engine_ports=[]), + process=process, + role=role, + gpu_devices=gpu_devices, + created_at=time.time(), + log_file=log_file, + ) + + +def create_mock_http_response(status_code=200, json_data=None): + """Create a mock HTTP response. + + Args: + status_code: HTTP status code + json_data: Dictionary to return from response.json() + + Returns: + Mock response object + """ + mock_response = Mock() + mock_response.status_code = status_code + if json_data is not None: + mock_response.json.return_value = json_data + return mock_response + + +class TestLocalSchedulerInitialization: + """Test LocalScheduler initialization and GPU detection.""" + + def test_init_with_explicit_gpu_devices(self, tmp_path): + """Should initialize with explicitly provided GPU devices.""" + scheduler = LocalScheduler( + gpu_devices=[0, 1, 2], + log_dir=str(tmp_path), + startup_timeout=60.0, + health_check_interval=2.0, + ) + + assert scheduler.gpu_devices == [0, 1, 2] + assert scheduler.log_dir == tmp_path + assert scheduler.startup_timeout == 60.0 + assert scheduler.health_check_interval == 2.0 + assert scheduler._gpu_counter == 0 + assert len(scheduler._allocated_ports) == 0 + assert len(scheduler._workers) == 0 + assert tmp_path.exists() + + def test_init_without_gpu_devices_uses_cuda_visible_devices(self, tmp_path): + """Should detect GPUs from CUDA_VISIBLE_DEVICES environment variable.""" + with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1,3"}): + scheduler = LocalScheduler(log_dir=str(tmp_path)) + assert scheduler.gpu_devices == [0, 1, 3] + + def test_init_with_invalid_cuda_visible_devices(self, tmp_path): + """Should fall back to default [0] when CUDA_VISIBLE_DEVICES is invalid.""" + with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "invalid,gpu,ids"}): + scheduler = LocalScheduler(log_dir=str(tmp_path)) + assert scheduler.gpu_devices == [0] + + def test_init_without_cuda_visible_devices(self, tmp_path): + """Should default to [0] when CUDA_VISIBLE_DEVICES is not set.""" + with patch.dict(os.environ, {}, clear=True): + if "CUDA_VISIBLE_DEVICES" in os.environ: + del os.environ["CUDA_VISIBLE_DEVICES"] + scheduler = LocalScheduler(log_dir=str(tmp_path)) + assert scheduler.gpu_devices == [0] + + def test_init_creates_log_directory(self, tmp_path): + """Should create log directory if it doesn't exist.""" + log_dir = tmp_path / "nested" / "log" / "dir" + assert not log_dir.exists() + + scheduler = LocalScheduler(log_dir=str(log_dir)) + + assert log_dir.exists() + 
assert scheduler.log_dir == log_dir + + def test_init_creates_http_clients(self, tmp_path): + """Should initialize both sync and async HTTP clients.""" + scheduler = LocalScheduler(log_dir=str(tmp_path)) + + assert isinstance(scheduler._http_client, httpx.Client) + assert isinstance(scheduler._async_http_client, httpx.AsyncClient) + + +class TestGPUAllocation: + """Test GPU allocation strategies.""" + + def test_allocate_gpus_round_robin(self, tmp_path): + """Should allocate GPUs in round-robin fashion.""" + scheduler = LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) + + # First allocation + gpus1 = scheduler._allocate_gpus(2) + assert gpus1 == [0, 1] + + # Second allocation (wraps around) + gpus2 = scheduler._allocate_gpus(3) + assert gpus2 == [2, 0, 1] + + # Third allocation + gpus3 = scheduler._allocate_gpus(1) + assert gpus3 == [2] + + def test_allocate_gpus_exceeds_available(self, tmp_path): + """Should raise GPUAllocationError when requesting more GPUs than available.""" + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + with pytest.raises(GPUAllocationError) as exc_info: + scheduler._allocate_gpus(3) + + assert "Requested 3 GPUs but only 2 available" in str(exc_info.value) + + def test_allocate_gpus_single_gpu_multiple_times(self, scheduler): + """Should allow multiple workers to share a single GPU via round-robin.""" + # Multiple allocations should all get GPU 0 + for _ in range(5): + gpus = scheduler._allocate_gpus(1) + assert gpus == [0] + + def test_get_colocated_gpus_success(self, multi_gpu_scheduler): + """Should return GPU devices from target worker for colocation.""" + # Create mock workers for target role + worker1 = create_worker_info( + worker_id="actor/0", role="actor", ports=["8000"], gpu_devices=[0, 1] + ) + worker2 = create_worker_info( + worker_id="actor/1", role="actor", ports=["8001"], gpu_devices=[2] + ) + multi_gpu_scheduler._workers["actor"] = [worker1, worker2] + + # Get colocated GPUs + gpus = multi_gpu_scheduler._get_colocated_gpus("actor", 0) + assert gpus == [0, 1] + + gpus = multi_gpu_scheduler._get_colocated_gpus("actor", 1) + assert gpus == [2] + + def test_get_colocated_gpus_role_not_found(self, scheduler): + """Should raise WorkerNotFoundError when target role doesn't exist.""" + with pytest.raises(WorkerNotFoundError) as exc_info: + scheduler._get_colocated_gpus("nonexistent", 0) + + assert "Cannot colocate with role 'nonexistent' - role not found" in str( + exc_info.value + ) + + def test_get_colocated_gpus_worker_index_out_of_range(self, scheduler): + """Should raise ValueError when worker index is out of range.""" + # Create only one worker for target role + worker = create_worker_info(worker_id="actor/0", role="actor", gpu_devices=[0]) + scheduler._workers["actor"] = [worker] + + with pytest.raises(ValueError) as exc_info: + scheduler._get_colocated_gpus("actor", 5) + + assert "only 1 workers exist" in str(exc_info.value) + + +class TestPortAllocation: + """Test port allocation and tracking.""" + + def test_allocate_ports_success(self, tmp_path): + """Should allocate requested number of free ports.""" + with patch("areal.scheduler.local.find_free_ports") as mock_find_ports: + mock_find_ports.return_value = [8000, 8001, 8002] + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + ports = scheduler._allocate_ports(3) + + assert ports == [8000, 8001, 8002] + assert scheduler._allocated_ports == {8000, 8001, 8002} + mock_find_ports.assert_called_once_with(3, exclude_ports=set()) + + def 
test_allocate_ports_excludes_already_allocated(self, tmp_path): + """Should exclude already allocated ports from search.""" + with patch("areal.scheduler.local.find_free_ports") as mock_find_ports: + mock_find_ports.side_effect = [ + [8000, 8001], + [8002, 8003], + ] + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + # First allocation + ports1 = scheduler._allocate_ports(2) + assert ports1 == [8000, 8001] + + # Second allocation should exclude previously allocated ports + ports2 = scheduler._allocate_ports(2) + assert ports2 == [8002, 8003] + assert scheduler._allocated_ports == {8000, 8001, 8002, 8003} + + # Verify excluded ports were passed + calls = mock_find_ports.call_args_list + assert calls[0] == call(2, exclude_ports=set()) + assert calls[1] == call(2, exclude_ports={8000, 8001}) + + def test_allocate_ports_failure(self, tmp_path): + """Should raise PortAllocationError when port allocation fails.""" + with patch("areal.scheduler.local.find_free_ports") as mock_find_ports: + mock_find_ports.side_effect = ValueError("No free ports available") + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + with pytest.raises(PortAllocationError) as exc_info: + scheduler._allocate_ports(5) + + assert "No free ports available" in str(exc_info.value) + + +class TestWorkerCreation: + """Test worker creation with various configurations.""" + + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_default_spec( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should create workers with default spec (1 GPU, 2 ports) when no specs provided.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [[8000, 8001], [8002, 8003]] + + # Mock process + mock_process1 = Mock() + mock_process1.pid = 1234 + mock_process1.poll.return_value = None + mock_process2 = Mock() + mock_process2.pid = 1235 + mock_process2.poll.return_value = None + mock_popen.side_effect = [mock_process1, mock_process2] + + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + job = Job(replicas=2, role="rollout") + worker_ids = scheduler.create_workers(job) + + assert worker_ids == ["rollout/0", "rollout/1"] + assert "rollout" in scheduler._workers + assert len(scheduler._workers["rollout"]) == 2 + + # Verify default spec was used + assert mock_popen.call_count == 2 + + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_single_spec_for_all( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should use single spec for all workers when specs length is 1.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [[8000, 8001, 8002]] * 3 + + # Mock processes + mock_processes = [] + for i in range(3): + mock_proc = Mock() + mock_proc.pid = 1000 + i + mock_proc.poll.return_value = None + mock_processes.append(mock_proc) + mock_popen.side_effect = mock_processes + + scheduler = LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) + + job = Job( + replicas=3, + role="actor", + tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=2, port_count=3)], + ) + worker_ids = scheduler.create_workers(job) + + assert len(worker_ids) == 3 + assert mock_popen.call_count == 3 + + # All workers should use the same spec + for worker_info in 
scheduler._workers["actor"]: + assert len(worker_info.worker.worker_ports) == 3 + + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_per_worker_specs( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should use individual specs when specs length equals replicas.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [[8000], [8001, 8002]] + + # Mock processes + mock_proc1 = Mock() + mock_proc1.pid = 1000 + mock_proc1.poll.return_value = None + mock_proc2 = Mock() + mock_proc2.pid = 1001 + mock_proc2.poll.return_value = None + mock_popen.side_effect = [mock_proc1, mock_proc2] + + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + job = Job( + replicas=2, + role="critic", + tasks=[ + SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=1), + SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2), + ], + ) + worker_ids = scheduler.create_workers(job) + + assert len(worker_ids) == 2 + assert len(scheduler._workers["critic"][0].worker.worker_ports) == 1 + assert len(scheduler._workers["critic"][1].worker.worker_ports) == 2 + + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_custom_command( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should use custom command from spec when provided.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + mock_proc = Mock() + mock_proc.pid = 1234 + mock_proc.poll.return_value = None + mock_popen.return_value = mock_proc + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + job = Job( + replicas=1, + role="custom", + tasks=[ + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + port_count=2, + cmd="python my_custom_server.py --port 8000", + ) + ], + ) + worker_ids = scheduler.create_workers(job) + + assert len(worker_ids) == 1 + + # Verify custom command was used + popen_call = mock_popen.call_args + cmd_args = popen_call[0][0] + assert cmd_args == ["python", "my_custom_server.py", "--port", "8000"] + + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_environment_variables( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should merge environment variables from spec into worker environment.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + mock_proc = Mock() + mock_proc.pid = 1234 + mock_proc.poll.return_value = None + mock_popen.return_value = mock_proc + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + job = Job( + replicas=1, + role="envtest", + tasks=[ + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + port_count=2, + env_vars={"CUSTOM_VAR": "custom_value", "ANOTHER_VAR": "123"}, + ) + ], + ) + worker_ids = scheduler.create_workers(job) + + assert len(worker_ids) == 1 + + # Verify environment variables were passed + popen_call = mock_popen.call_args + env = popen_call[1]["env"] + assert env["CUSTOM_VAR"] == "custom_value" + assert env["ANOTHER_VAR"] == "123" + assert env["CUDA_VISIBLE_DEVICES"] == "0" + assert env["WORKER_ID"] == "envtest/0" + + @patch("areal.scheduler.local.gethostip") + 
@patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_with_colocate_strategy( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should colocate workers on same GPUs as target role when colocate strategy is used.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + mock_processes = [] + for i in range(4): + mock_proc = Mock() + mock_proc.pid = 1000 + i + mock_proc.poll.return_value = None + mock_processes.append(mock_proc) + mock_popen.side_effect = mock_processes + + scheduler = LocalScheduler(gpu_devices=[0, 1, 2, 3], log_dir=str(tmp_path)) + + # Create target workers (actors) + actor_job = Job( + replicas=2, + role="actor", + tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=2, port_count=2)], + ) + scheduler.create_workers(actor_job) + + # Get GPU allocations for actors + actor_gpus_0 = scheduler._workers["actor"][0].gpu_devices + actor_gpus_1 = scheduler._workers["actor"][1].gpu_devices + + # Reset mock + mock_find_ports.reset_mock() + mock_find_ports.return_value = [8010, 8011] + + # Create colocated workers (critics) + critic_job = Job( + replicas=2, + role="critic", + tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=2, port_count=2)], + scheduling_strategy=SchedulingStrategy(type="colocation", target="actor"), + ) + critic_ids = scheduler.create_workers(critic_job) + + assert len(critic_ids) == 2 + + # Verify critics are colocated with actors + critic_gpus_0 = scheduler._workers["critic"][0].gpu_devices + critic_gpus_1 = scheduler._workers["critic"][1].gpu_devices + + assert critic_gpus_0 == actor_gpus_0 + assert critic_gpus_1 == actor_gpus_1 + + def test_create_workers_duplicate_role_error(self, tmp_path): + """Should raise WorkerCreationError when attempting to create workers for existing role.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + with ( + patch("areal.scheduler.local.subprocess.Popen") as mock_popen, + patch("areal.scheduler.local.find_free_ports") as mock_find_ports, + patch("areal.scheduler.local.gethostip") as mock_gethostip, + ): + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + mock_proc = Mock() + mock_proc.pid = 1234 + mock_proc.poll.return_value = None + mock_popen.return_value = mock_proc + + job = Job(replicas=1, role="test") + scheduler.create_workers(job) + + # Try to create again + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers(job) + + assert "Worker group already exists" in str(exc_info.value) + assert exc_info.value.worker_key == "test" + + def test_create_workers_zero_replicas_error(self, tmp_path): + """Should raise WorkerCreationError when replicas is 0.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + job = Job(replicas=0, role="test") + + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers(job) + + assert "replicas must be greater than 0" in str(exc_info.value) + + def test_create_workers_invalid_specs_length(self, tmp_path): + """Should raise WorkerCreationError when tasks length is invalid.""" + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + job = Job( + replicas=3, + role="test", + tasks=[ + SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2), + SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2), + ], # 2 tasks for 3 replicas + ) + + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers(job) + + assert 
"schedulings length (2) must be 1 or equal to replicas (3)" in str( + exc_info.value + ) + + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_subprocess_fails_immediately( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should raise WorkerCreationError when subprocess exits immediately.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + # Mock process that exits immediately + mock_proc = Mock() + mock_proc.pid = 1234 + mock_proc.poll.return_value = 1 # Exit code 1 + mock_proc.returncode = 1 + mock_popen.return_value = mock_proc + + # Create log file with error message + log_file = tmp_path / "test_0.log" + log_file.write_text("Error: Failed to start server\n") + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + job = Job(replicas=1, role="test") + + with patch.object( + scheduler, "_read_log_tail", return_value="Error: Failed to start server" + ): + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers(job) + + assert "exited immediately with code 1" in str(exc_info.value) + + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_create_workers_cleanup_on_partial_failure( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should clean up successfully created workers when a later worker fails.""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.side_effect = [ + [8000, 8001], # First worker succeeds + ValueError("No free ports"), # Second worker fails + ] + + # First process succeeds + mock_proc1 = Mock() + mock_proc1.pid = 1234 + mock_proc1.poll.return_value = None + mock_popen.return_value = mock_proc1 + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + job = Job(replicas=2, role="test") + + with patch.object(scheduler, "_cleanup_workers") as mock_cleanup: + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers(job) + + # Verify cleanup was called + assert mock_cleanup.called + assert "Resource allocation failed" in str(exc_info.value) + + def test_create_workers_colocate_strategy_missing_target(self, tmp_path): + """Should raise WorkerCreationError when colocation strategy is missing target role.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + job = Job( + replicas=1, + role="test", + tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=1, port_count=2)], + scheduling_strategy=SchedulingStrategy( + type="colocation", target="" + ), # Missing target + ) + + with pytest.raises(WorkerCreationError) as exc_info: + scheduler.create_workers(job) + + assert "Colocation strategy requires target" in str(exc_info.value) + + +class TestGetWorkers: + """Test getting workers and waiting for readiness.""" + + def test_get_workers_role_not_found(self, scheduler): + """Should raise WorkerNotFoundError when role doesn't exist.""" + with pytest.raises(WorkerNotFoundError) as exc_info: + scheduler.get_workers("nonexistent") + + assert exc_info.value.worker_id == "nonexistent" + + @patch("areal.scheduler.local.time.sleep") + def test_get_workers_success(self, mock_sleep, scheduler, tmp_path): + """Should return workers when all are ready.""" + # Create mock workers + worker1 = create_worker_info( + worker_id="test/0", ports=["8000"], log_file=str(tmp_path 
/ "test_0.log") + ) + worker2 = create_worker_info( + worker_id="test/1", ports=["8001"], log_file=str(tmp_path / "test_1.log") + ) + + scheduler._workers["test"] = [worker1, worker2] + + with patch.object(scheduler, "_is_worker_ready", return_value=True): + workers = scheduler.get_workers("test", timeout=10.0) + + assert len(workers) == 2 + assert workers[0].id == "test/0" + assert workers[1].id == "test/1" + + @patch("areal.scheduler.local.time.time") + @patch("areal.scheduler.local.time.sleep") + def test_get_workers_timeout(self, mock_sleep, mock_time, scheduler, tmp_path): + """Should raise WorkerTimeoutError when timeout is exceeded.""" + # Mock time progression - provide enough values + mock_time.side_effect = [0.0] + [i for i in range(1, 20)] + + worker = create_worker_info(log_file=str(tmp_path / "test_0.log")) + worker.created_at = 0.0 + + scheduler._workers["test"] = [worker] + + # Worker never becomes ready + with patch.object(scheduler, "_is_worker_ready", return_value=False): + with pytest.raises(WorkerTimeoutError) as exc_info: + scheduler.get_workers("test", timeout=5.0) + + assert exc_info.value.worker_key == "test" + assert exc_info.value.timeout == 5.0 + + def test_get_workers_process_died(self, scheduler, tmp_path): + """Should raise WorkerFailedError when worker process dies during readiness check.""" + log_file = tmp_path / "test_0.log" + log_file.write_text("Error: Connection refused\n") + + # Process dies after first check + mock_proc = create_mock_process() + mock_proc.poll.side_effect = [None, 1] # None (alive), then 1 (dead) + mock_proc.returncode = 1 + + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + scheduler._workers["test"] = [worker] + + with patch.object(scheduler, "_is_worker_ready", return_value=False): + with pytest.raises(WorkerFailedError) as exc_info: + scheduler.get_workers("test", timeout=10.0) + + assert exc_info.value.worker_id == "test/0" + assert exc_info.value.exit_code == 1 + + @patch("areal.scheduler.local.time.sleep") + def test_get_workers_gradual_readiness(self, mock_sleep, scheduler, tmp_path): + """Should wait for all workers to become ready gradually.""" + worker1 = create_worker_info( + worker_id="test/0", ports=["8000"], log_file=str(tmp_path / "test_0.log") + ) + worker2 = create_worker_info( + worker_id="test/1", ports=["8001"], log_file=str(tmp_path / "test_1.log") + ) + + scheduler._workers["test"] = [worker1, worker2] + + # Worker 1 ready immediately, worker 2 ready on second check + ready_calls = [True, False, True, True] + with patch.object(scheduler, "_is_worker_ready", side_effect=ready_calls): + workers = scheduler.get_workers("test", timeout=10.0) + + assert len(workers) == 2 + + +class TestWorkerHealthCheck: + """Test worker health checking functionality.""" + + @pytest.mark.parametrize( + "status_code,expected", + [ + (200, True), # Success + (503, False), # Service unavailable + (500, False), # Internal server error + ], + ) + def test_is_worker_ready_http_status( + self, scheduler, tmp_path, status_code, expected + ): + """Should return appropriate result based on HTTP status code.""" + worker_info = create_worker_info(log_file=str(tmp_path / "test.log")) + mock_response = create_mock_http_response(status_code=status_code) + + with patch.object(scheduler._http_client, "get", return_value=mock_response): + assert scheduler._is_worker_ready(worker_info) is expected + + def test_is_worker_ready_connection_error(self, scheduler, tmp_path): + """Should return False when connection to worker 
fails.""" + worker_info = create_worker_info(log_file=str(tmp_path / "test.log")) + + with patch.object( + scheduler._http_client, + "get", + side_effect=httpx.ConnectError("Connection refused"), + ): + assert scheduler._is_worker_ready(worker_info) is False + + def test_check_worker_health_all_healthy(self, scheduler, tmp_path): + """Should pass when all workers are healthy.""" + worker1 = create_worker_info( + worker_id="test/0", ports=["8000"], log_file=str(tmp_path / "test_0.log") + ) + worker2 = create_worker_info( + worker_id="test/1", ports=["8001"], log_file=str(tmp_path / "test_1.log") + ) + + scheduler._workers["test"] = [worker1, worker2] + + # Should not raise + scheduler._check_worker_health("test") + + def test_check_worker_health_worker_failed(self, scheduler, tmp_path): + """Should raise WorkerFailedError when a worker has failed.""" + log_file = tmp_path / "test_0.log" + log_file.write_text("Killed by signal\n") + + mock_proc = create_mock_process(is_alive=False, exit_code=137) + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + + scheduler._workers["test"] = [worker] + + with pytest.raises(WorkerFailedError) as exc_info: + scheduler._check_worker_health("test") + + assert exc_info.value.worker_id == "test/0" + assert exc_info.value.exit_code == 137 + + def test_check_worker_health_nonexistent_role(self, scheduler): + """Should silently pass when role doesn't exist.""" + # Should not raise + scheduler._check_worker_health("nonexistent") + + +class TestDeleteWorkers: + """Test worker deletion and cleanup.""" + + def test_delete_workers_specific_role(self, scheduler, tmp_path): + """Should delete workers for specific role.""" + # Create mock workers for multiple roles + worker1 = create_worker_info( + worker_id="role1/0", + role="role1", + ports=["8000"], + log_file=str(tmp_path / "role1_0.log"), + ) + worker2 = create_worker_info( + worker_id="role2/0", + role="role2", + ports=["8001"], + log_file=str(tmp_path / "role2_0.log"), + ) + + scheduler._workers["role1"] = [worker1] + scheduler._workers["role2"] = [worker2] + scheduler._allocated_ports = {8000, 8001} + + with patch.object(scheduler, "_terminate_process_tree"): + scheduler.delete_workers("role1") + + # role1 should be deleted, role2 should remain + assert "role1" not in scheduler._workers + assert "role2" in scheduler._workers + assert 8000 not in scheduler._allocated_ports + assert 8001 in scheduler._allocated_ports + + def test_delete_workers_all_roles(self, scheduler, tmp_path): + """Should delete all workers when role is None.""" + worker1 = create_worker_info( + worker_id="role1/0", + role="role1", + ports=["8000"], + log_file=str(tmp_path / "role1_0.log"), + ) + worker2 = create_worker_info( + worker_id="role2/0", + role="role2", + ports=["8001"], + log_file=str(tmp_path / "role2_0.log"), + ) + + scheduler._workers["role1"] = [worker1] + scheduler._workers["role2"] = [worker2] + scheduler._allocated_ports = {8000, 8001} + + with patch.object(scheduler, "_terminate_process_tree"): + scheduler.delete_workers(None) + + # All workers should be deleted + assert len(scheduler._workers) == 0 + assert len(scheduler._allocated_ports) == 0 + + def test_delete_workers_nonexistent_role(self, scheduler): + """Should log warning and return when role doesn't exist.""" + # Should not raise + scheduler.delete_workers("nonexistent") + + def test_cleanup_workers_releases_ports(self, scheduler, tmp_path): + """Should release allocated ports when cleaning up workers.""" + worker = create_worker_info( + 
ports=["8000", "8001"], log_file=str(tmp_path / "test.log") + ) + scheduler._allocated_ports = {8000, 8001, 8002} + + with patch.object(scheduler, "_terminate_process_tree"): + scheduler._cleanup_workers([worker]) + + # Ports 8000 and 8001 should be released + assert scheduler._allocated_ports == {8002} + + def test_cleanup_workers_handles_errors(self, scheduler, tmp_path): + """Should continue cleanup even if terminating a process fails.""" + worker1 = create_worker_info( + worker_id="test/0", ports=["8000"], log_file=str(tmp_path / "test_0.log") + ) + worker2 = create_worker_info( + worker_id="test/1", ports=["8001"], log_file=str(tmp_path / "test_1.log") + ) + + # First termination fails, second succeeds + with patch.object( + scheduler, + "_terminate_process_tree", + side_effect=[Exception("Failed to terminate"), None], + ): + # Should not raise, just log error + scheduler._cleanup_workers([worker1, worker2]) + + +class TestProcessTermination: + """Test process termination functionality.""" + + @patch("areal.scheduler.local.psutil.Process") + @patch("areal.scheduler.local.psutil.wait_procs") + def test_terminate_process_tree_graceful( + self, mock_wait_procs, mock_process_class, tmp_path + ): + """Should gracefully terminate process tree.""" + # Mock parent process + mock_parent = Mock() + mock_child1 = Mock() + mock_child2 = Mock() + + mock_parent.children.return_value = [mock_child1, mock_child2] + mock_process_class.return_value = mock_parent + + # All processes terminate gracefully + mock_wait_procs.return_value = ([], []) # (gone, alive) + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + scheduler._terminate_process_tree(1234) + + # Verify termination sequence + mock_child1.terminate.assert_called_once() + mock_child2.terminate.assert_called_once() + mock_parent.terminate.assert_called_once() + + # Should not call kill since all terminated gracefully + mock_child1.kill.assert_not_called() + mock_child2.kill.assert_not_called() + mock_parent.kill.assert_not_called() + + @patch("areal.scheduler.local.psutil.Process") + @patch("areal.scheduler.local.psutil.wait_procs") + def test_terminate_process_tree_force_kill( + self, mock_wait_procs, mock_process_class, tmp_path + ): + """Should force kill processes that don't terminate gracefully.""" + mock_parent = Mock() + mock_child = Mock() + + mock_parent.children.return_value = [mock_child] + + # Return mock_parent only when called with pid=1234, otherwise raise NoSuchProcess + # This prevents interference from __del__ cleanup of previous test's schedulers + def process_side_effect(pid): + if pid == 1234: + return mock_parent + raise psutil.NoSuchProcess(pid) + + mock_process_class.side_effect = process_side_effect + + # Child doesn't terminate gracefully + mock_wait_procs.return_value = ([], [mock_child]) # (gone, alive) + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + scheduler._terminate_process_tree(1234) + + # Verify force kill was called + mock_child.terminate.assert_called_once() + mock_child.kill.assert_called_once() + + @patch("areal.scheduler.local.psutil.Process") + def test_terminate_process_tree_no_such_process(self, mock_process_class, tmp_path): + """Should handle gracefully when process doesn't exist.""" + mock_process_class.side_effect = psutil.NoSuchProcess(1234) + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + # Should not raise + scheduler._terminate_process_tree(1234) + + @patch("areal.scheduler.local.psutil.Process") + def 
test_terminate_process_tree_handles_child_no_such_process( + self, mock_process_class, tmp_path + ): + """Should handle when child process disappears during termination.""" + mock_parent = Mock() + mock_child = Mock() + mock_child.terminate.side_effect = psutil.NoSuchProcess(1235) + + mock_parent.children.return_value = [mock_child] + mock_process_class.return_value = mock_parent + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + # Should not raise + scheduler._terminate_process_tree(1234) + + +class TestLogFileHandling: + """Test log file reading and handling.""" + + def test_read_log_tail_success(self, tmp_path): + """Should read last N lines from log file.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + log_file = tmp_path / "test.log" + log_lines = [f"Line {i}\n" for i in range(100)] + log_file.write_text("".join(log_lines)) + + tail = scheduler._read_log_tail(str(log_file), lines=10) + + # Should contain last 10 lines + assert "Line 90" in tail + assert "Line 99" in tail + assert "Line 89" not in tail + + def test_read_log_tail_file_not_found(self, tmp_path): + """Should return error message when log file doesn't exist.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + tail = scheduler._read_log_tail("/nonexistent/file.log") + + assert "Could not read log file" in tail + + def test_read_log_tail_fewer_lines_than_requested(self, tmp_path): + """Should return all lines when file has fewer lines than requested.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + log_file = tmp_path / "test.log" + log_file.write_text("Line 1\nLine 2\nLine 3\n") + + tail = scheduler._read_log_tail(str(log_file), lines=50) + + assert "Line 1" in tail + assert "Line 2" in tail + assert "Line 3" in tail + + +class TestEngineCreation: + """Test engine creation on workers.""" + + def test_create_engine_success(self, scheduler, tmp_path): + """Should successfully create engine on worker.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=200, + json_data={"result": {"status": "initialized", "name": "TestEngine"}}, + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + result = asyncio.run( + scheduler.create_engine( + "test/0", "test_engines.DummyEngine", name="TestEngine", param=123 + ) + ) + + assert result == {"status": "initialized", "name": "TestEngine"} + + def test_create_engine_worker_not_found(self, scheduler): + """Should raise WorkerNotFoundError when worker doesn't exist.""" + with pytest.raises(WorkerNotFoundError) as exc_info: + asyncio.run( + scheduler.create_engine("nonexistent/0", "test_engines.DummyEngine") + ) + + assert exc_info.value.worker_id == "nonexistent/0" + + def test_create_engine_worker_died(self, scheduler, tmp_path): + """Should raise WorkerFailedError when worker process has died.""" + log_file = tmp_path / "test.log" + log_file.write_text("Worker crashed\n") + + mock_proc = create_mock_process(is_alive=False, exit_code=1) + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + scheduler._workers["test"] = [worker] + + with pytest.raises(WorkerFailedError) as exc_info: + asyncio.run(scheduler.create_engine("test/0", "test_engines.DummyEngine")) + + assert exc_info.value.worker_id == "test/0" + assert exc_info.value.exit_code == 1 + + def test_create_engine_invalid_engine_type(self, scheduler, tmp_path): + 
"""Should raise EngineCreationError when engine is not a string.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + with pytest.raises(EngineCreationError) as exc_info: + asyncio.run(scheduler.create_engine("test/0", 123)) # Invalid type + + assert "Engine must be a string import path" in str(exc_info.value) + + def test_create_engine_import_error(self, scheduler, tmp_path): + """Should raise EngineImportError when engine import fails.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=400, + json_data={"detail": "Failed to import 'nonexistent.Engine'"}, + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + with pytest.raises(EngineImportError) as exc_info: + asyncio.run(scheduler.create_engine("test/0", "nonexistent.Engine")) + + assert "nonexistent.Engine" in str(exc_info.value) + + def test_create_engine_initialization_error(self, scheduler, tmp_path): + """Should raise EngineCreationError when engine initialization fails.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=500, + json_data={"detail": "Engine initialization failed: out of memory"}, + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + with pytest.raises(EngineCreationError) as exc_info: + asyncio.run( + scheduler.create_engine("test/0", "test_engines.DummyEngine") + ) + + assert "out of memory" in str(exc_info.value) + assert exc_info.value.status_code == 500 + + def test_create_engine_connection_error_worker_died(self, scheduler, tmp_path): + """Should raise WorkerFailedError when connection fails and worker is dead.""" + log_file = tmp_path / "test.log" + log_file.write_text("Worker crashed during engine creation\n") + + # First call returns None (alive), second call returns exit code (dead) + mock_proc = create_mock_process() + mock_proc.poll.side_effect = [None, 1] + mock_proc.returncode = 1 + + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + scheduler._workers["test"] = [worker] + + with patch.object( + scheduler._http_client, + "post", + side_effect=httpx.ConnectError("Connection refused"), + ): + with pytest.raises(WorkerFailedError) as exc_info: + asyncio.run( + scheduler.create_engine("test/0", "test_engines.DummyEngine") + ) + + assert exc_info.value.worker_id == "test/0" + + def test_create_engine_connection_error_worker_alive(self, scheduler, tmp_path): + """Should raise RPCConnectionError when connection fails but worker is alive.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + with patch.object( + scheduler._http_client, + "post", + side_effect=httpx.ConnectError("Connection refused"), + ): + with pytest.raises(RPCConnectionError) as exc_info: + asyncio.run( + scheduler.create_engine("test/0", "test_engines.DummyEngine") + ) + + assert exc_info.value.worker_id == "test/0" + assert exc_info.value.host == "127.0.0.1" + assert exc_info.value.port == 8000 + + def test_create_engine_timeout(self, scheduler, tmp_path): + """Should raise EngineCreationError when request times out.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + with patch.object( + scheduler._http_client, + "post", + 
side_effect=httpx.TimeoutException("Request timeout"), + ): + with pytest.raises(EngineCreationError) as exc_info: + asyncio.run( + scheduler.create_engine("test/0", "test_engines.DummyEngine") + ) + + assert "Request timed out" in str(exc_info.value) + + +class TestEngineMethodCalls: + """Test calling methods on engines (sync and async).""" + + def test_call_engine_success(self, scheduler, tmp_path): + """Should successfully call engine method synchronously.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=200, json_data={"result": 42} + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + result = scheduler.call_engine("test/0", "compute", arg1=10, arg2=20) + + assert result == 42 + + def test_call_engine_worker_not_found(self, scheduler): + """Should raise WorkerNotFoundError when worker doesn't exist.""" + with pytest.raises(WorkerNotFoundError): + scheduler.call_engine("nonexistent/0", "method") + + def test_call_engine_worker_died(self, scheduler, tmp_path): + """Should raise WorkerFailedError when worker dies before call.""" + log_file = tmp_path / "test.log" + log_file.write_text("Worker crashed\n") + + mock_proc = create_mock_process(is_alive=False, exit_code=1) + worker = create_worker_info(process=mock_proc, log_file=str(log_file)) + scheduler._workers["test"] = [worker] + + with pytest.raises(WorkerFailedError): + scheduler.call_engine("test/0", "method") + + def test_call_engine_method_error(self, scheduler, tmp_path): + """Should raise EngineCallError when method call returns 400/500.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response( + status_code=400, json_data={"detail": "Method 'nonexistent' not found"} + ) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + with pytest.raises(EngineCallError) as exc_info: + scheduler.call_engine("test/0", "nonexistent") + + assert "Method 'nonexistent' not found" in str(exc_info.value) + + @patch("areal.scheduler.local.time.sleep") + def test_call_engine_retry_on_503(self, mock_sleep, scheduler, tmp_path): + """Should retry on 503 Service Unavailable.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + # First call returns 503, second call succeeds + mock_response_503 = create_mock_http_response(status_code=503) + mock_response_200 = create_mock_http_response( + status_code=200, json_data={"result": "success"} + ) + + with patch.object( + scheduler._http_client, + "post", + side_effect=[mock_response_503, mock_response_200], + ): + result = scheduler.call_engine("test/0", "method", max_retries=3) + + assert result == "success" + assert mock_sleep.called + + @patch("areal.scheduler.local.time.sleep") + def test_call_engine_max_retries_exhausted(self, mock_sleep, scheduler, tmp_path): + """Should raise EngineCallError after max retries.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response(status_code=503) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + with pytest.raises(EngineCallError) as exc_info: + scheduler.call_engine("test/0", "method", max_retries=3) + + assert "Max retries exceeded" in str( + exc_info.value + ) or "Service unavailable" in 
str(exc_info.value) + assert exc_info.value.attempt == 3 + + @patch("areal.scheduler.local.time.sleep") + def test_call_engine_exponential_backoff(self, mock_sleep, scheduler, tmp_path): + """Should use exponential backoff for retries.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + mock_response = create_mock_http_response(status_code=503) + + with patch.object(scheduler._http_client, "post", return_value=mock_response): + try: + scheduler.call_engine( + "test/0", "method", max_retries=3, retry_delay=1.0 + ) + except EngineCallError: + pass + + # Verify exponential backoff: 1.0, 2.0 + sleep_calls = [call_args[0][0] for call_args in mock_sleep.call_args_list] + assert sleep_calls[0] == 1.0 # First retry + assert sleep_calls[1] == 2.0 # Second retry + + def test_async_call_engine_success(self, scheduler, tmp_path): + """Should successfully call engine method asynchronously.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + # Use Mock instead of AsyncMock for the response object + mock_response = create_mock_http_response( + status_code=200, json_data={"result": 42} + ) + + # But AsyncMock for post() since it's an async method + async_mock_post = AsyncMock(return_value=mock_response) + with patch.object(scheduler._async_http_client, "post", async_mock_post): + result = asyncio.run( + scheduler.async_call_engine("test/0", "compute", arg1=10, arg2=20) + ) + + assert result == 42 + + def test_async_call_engine_worker_not_found(self, scheduler): + """Should raise WorkerNotFoundError when worker doesn't exist (async).""" + with pytest.raises(WorkerNotFoundError): + asyncio.run(scheduler.async_call_engine("nonexistent/0", "method")) + + def test_async_call_engine_retry_with_backoff(self, scheduler, tmp_path): + """Should retry with exponential backoff in async mode.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + # First call returns 503, second call succeeds + # Use Mock (not AsyncMock) for response objects since response.json() is synchronous + mock_response_503 = create_mock_http_response(status_code=503) + mock_response_200 = create_mock_http_response( + status_code=200, json_data={"result": "success"} + ) + + # AsyncMock for post() since it's an async method + async_mock_post = AsyncMock(side_effect=[mock_response_503, mock_response_200]) + with patch.object(scheduler._async_http_client, "post", async_mock_post): + with patch("asyncio.sleep") as mock_async_sleep: + result = asyncio.run( + scheduler.async_call_engine("test/0", "method", max_retries=3) + ) + + assert result == "success" + assert mock_async_sleep.called + + +class TestFindWorkerById: + """Test finding workers by ID.""" + + def test_find_worker_by_id_success(self, scheduler, tmp_path): + """Should find worker by ID.""" + worker1 = create_worker_info( + worker_id="role1/0", + role="role1", + ports=["8000"], + log_file=str(tmp_path / "role1_0.log"), + ) + worker2 = create_worker_info( + worker_id="role2/0", + role="role2", + ports=["8001"], + log_file=str(tmp_path / "role2_0.log"), + ) + + scheduler._workers["role1"] = [worker1] + scheduler._workers["role2"] = [worker2] + + found = scheduler._find_worker_by_id("role2/0") + + assert found is worker2 + assert found.worker.id == "role2/0" + + def test_find_worker_by_id_not_found(self, scheduler, tmp_path): + """Should return None when worker ID is not found.""" + worker = create_worker_info( 
+ worker_id="role1/0", role="role1", log_file=str(tmp_path / "role1_0.log") + ) + scheduler._workers["role1"] = [worker] + + found = scheduler._find_worker_by_id("nonexistent/99") + + assert found is None + + +class TestSchedulerCleanup: + """Test scheduler cleanup and destructor.""" + + def test_destructor_deletes_all_workers(self, scheduler, tmp_path): + """Should delete all workers when scheduler is destroyed.""" + worker = create_worker_info(log_file=str(tmp_path / "test.log")) + scheduler._workers["test"] = [worker] + + with patch.object(scheduler, "delete_workers") as mock_delete: + scheduler.__del__() + + mock_delete.assert_called_once() + + def test_destructor_closes_http_clients(self, scheduler): + """Should close HTTP clients when scheduler is destroyed.""" + with patch.object(scheduler._http_client, "close") as mock_close: + scheduler.__del__() + + mock_close.assert_called_once() + + def test_destructor_handles_errors_gracefully(self, scheduler): + """Should handle errors gracefully in destructor.""" + with patch.object(scheduler, "delete_workers", side_effect=Exception("Error")): + # Should not raise + scheduler.__del__() + + +class TestEdgeCases: + """Test various edge cases and corner scenarios.""" + + def test_gpu_counter_wraps_correctly(self, tmp_path): + """Should correctly wrap GPU counter for round-robin allocation.""" + scheduler = LocalScheduler(gpu_devices=[0, 1], log_dir=str(tmp_path)) + + # Allocate many times to ensure wrapping + for i in range(10): + gpus = scheduler._allocate_gpus(1) + expected_gpu = i % 2 + assert gpus == [expected_gpu] + + def test_port_allocation_accumulates_correctly(self, tmp_path): + """Should correctly accumulate allocated ports over multiple allocations.""" + with patch("areal.scheduler.local.find_free_ports") as mock_find_ports: + mock_find_ports.side_effect = [ + [8000, 8001], + [8002, 8003], + [8004, 8005, 8006], + ] + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + scheduler._allocate_ports(2) + scheduler._allocate_ports(2) + scheduler._allocate_ports(3) + + assert scheduler._allocated_ports == { + 8000, + 8001, + 8002, + 8003, + 8004, + 8005, + 8006, + } + + @patch("areal.scheduler.local.gethostip") + @patch("areal.scheduler.local.subprocess.Popen") + @patch("areal.scheduler.local_scheduler.find_free_ports") + def test_worker_id_format( + self, mock_find_ports, mock_popen, mock_gethostip, tmp_path + ): + """Should create worker IDs in correct format (role/index).""" + mock_gethostip.return_value = "127.0.0.1" + mock_find_ports.return_value = [8000, 8001] + + mock_processes = [] + for i in range(5): + mock_proc = Mock() + mock_proc.pid = 1000 + i + mock_proc.poll.return_value = None + mock_processes.append(mock_proc) + mock_popen.side_effect = mock_processes + + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + job = Job(replicas=5, role="worker") + worker_ids = scheduler.create_workers(job) + + assert worker_ids == [ + "worker/0", + "worker/1", + "worker/2", + "worker/3", + "worker/4", + ] + + def test_empty_workers_dict_operations(self, tmp_path): + """Should handle operations on empty workers dictionary gracefully.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + # Delete all workers when none exist + scheduler.delete_workers(None) + + # Check health of non-existent role + scheduler._check_worker_health("nonexistent") + + # Find worker by ID when no workers exist + assert scheduler._find_worker_by_id("any/0") is None + + def 
test_concurrent_gpu_allocations(self, tmp_path): + """Should handle concurrent GPU allocations correctly.""" + scheduler = LocalScheduler(gpu_devices=[0, 1, 2], log_dir=str(tmp_path)) + + # Simulate multiple workers requesting GPUs simultaneously + results = [] + for _ in range(6): + gpus = scheduler._allocate_gpus(1) + results.append(gpus[0]) + + # Should cycle through GPUs in order + assert results == [0, 1, 2, 0, 1, 2] + + def test_log_directory_with_special_characters(self, tmp_path): + """Should handle log directory paths with special characters.""" + log_dir = tmp_path / "logs with spaces" / "special-chars_123" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(log_dir)) + + assert log_dir.exists() + assert scheduler.log_dir == log_dir + + +class TestRPCWorkflowIntegration: + """Integration tests for RPC workflow endpoint execution.""" + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_workflow_endpoint_basic(self, tmp_path): + """Should execute workflow via RPC endpoint with real worker processes.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + try: + job = Job(replicas=1, role="test") + worker_ids = scheduler.create_workers(job) + assert len(worker_ids) == 1 + + workers = scheduler.get_workers(role="test", timeout=30.0) + worker_id = workers[0].id + + result = await scheduler.async_call_engine( + worker_id=worker_id, + method="run_workflow", + workflow="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + data={"test_id": 123, "value": "test_data"}, + ) + + assert result is not None + assert "input_ids" in result + assert "attention_mask" in result + assert "loss_mask" in result + assert "rewards" in result + + from areal.tests.utils import TestWorkflow + + workflow = TestWorkflow() + ref_result = await workflow.arun_episode(None, None) + assert result.keys() == ref_result.keys() + for key in result: + assert type(result[key]) is type(ref_result[key]) + + finally: + scheduler.delete_workers(role="test") + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_workflow_endpoint_multiple_calls(self, tmp_path): + """Should handle multiple sequential workflow calls via RPC endpoint.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + try: + job = Job(replicas=1, role="test") + scheduler.create_workers(job) + + workers = scheduler.get_workers(role="test", timeout=30.0) + worker_id = workers[0].id + + from areal.tests.utils import TestWorkflow + + workflow = TestWorkflow() + ref_result = await workflow.arun_episode(None, None) + + for i in range(3): + result = await scheduler.async_call_engine( + worker_id=worker_id, + method="run_workflow", + workflow="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + data={"test_id": i, "value": f"test_{i}"}, + ) + + assert result is not None + assert result.keys() == ref_result.keys() + for key in result: + assert type(result[key]) is type(ref_result[key]) + + finally: + scheduler.delete_workers(role="test") + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_workflow_serialization(self, tmp_path): + """Should correctly serialize and deserialize tensors through RPC.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + try: + job = Job(replicas=1, role="test") + scheduler.create_workers(job) + + workers = scheduler.get_workers(role="test", timeout=30.0) + worker_id = workers[0].id + + result = await scheduler.async_call_engine( + worker_id=worker_id, + method="run_workflow", + 
workflow="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + data={"test_id": 456, "value": "serialization_test"}, + ) + + assert result is not None + assert "input_ids" in result + assert "attention_mask" in result + assert "loss_mask" in result + assert "rewards" in result + + import torch + + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["loss_mask"], torch.Tensor) + assert isinstance(result["rewards"], torch.Tensor) + + assert result["input_ids"].dtype == torch.long + assert result["attention_mask"].dtype == torch.bool + assert result["loss_mask"].dtype == torch.bool + + finally: + scheduler.delete_workers(role="test") + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_workflow_with_kwargs(self, tmp_path): + """Should support workflow instantiation with custom kwargs.""" + scheduler = LocalScheduler(gpu_devices=[0], log_dir=str(tmp_path)) + + try: + job = Job(replicas=1, role="test") + scheduler.create_workers(job) + + workers = scheduler.get_workers(role="test", timeout=30.0) + worker_id = workers[0].id + + result = await scheduler.async_call_engine( + worker_id=worker_id, + method="run_workflow", + workflow="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + data={"test_id": 789, "custom_param": "value"}, + ) + + assert result is not None + assert "input_ids" in result + assert "attention_mask" in result + assert "loss_mask" in result + assert "rewards" in result + + finally: + scheduler.delete_workers(role="test") diff --git a/areal/tests/test_rollout_controller.py b/areal/tests/test_rollout_controller.py new file mode 100644 index 000000000..afc028fd3 --- /dev/null +++ b/areal/tests/test_rollout_controller.py @@ -0,0 +1,1000 @@ +import asyncio +import queue +from concurrent.futures import Future +from unittest.mock import Mock, patch + +import pytest +import torch + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import InferenceEngineConfig +from areal.api.io_struct import ModelRequest, ParamSpec, WeightUpdateMeta +from areal.api.scheduler_api import Worker +from areal.controller.batch import DistributedBatchMemory +from areal.controller.rollout_controller import RolloutController +from areal.core.async_task_runner import TaskQueueFullError + + +class MockScheduler: + def __init__(self): + self.workers = [] + self.call_count = 0 + self.engine_calls = [] + + def create_workers(self, role, scheduler_config, *args, **kwargs): + worker_ids = [f"{role}/{i}" for i in range(scheduler_config.replicas)] + self.workers = [ + Worker( + id=wid, + ip="127.0.0.1", + worker_ports=["8000", "8001"], + engine_ports=["9000", "9001"], + ) + for wid in worker_ids + ] + return worker_ids + + def get_workers(self, role, timeout=None): + return self.workers + + async def create_engine(self, worker_id, engine, config): + pass + + async def async_call_engine(self, worker_id, method, *args, **kwargs): + self.engine_calls.append((worker_id, method, args, kwargs)) + self.call_count += 1 + + if method == "run_workflow": + await asyncio.sleep(0.01) + return { + "input_ids": torch.randint(0, 100, (1, 10)), + "attention_mask": torch.ones(1, 10, dtype=torch.bool), + "loss_mask": torch.tensor( + [0] * 5 + [1] * 5, dtype=torch.bool + ).unsqueeze(0), + "rewards": torch.randn(1), + } + elif method == "agenerate": + return Mock() + return None + + def call_engine(self, worker_id, method, *args, **kwargs): + self.engine_calls.append((worker_id, method, args, kwargs)) + 
+ # For weight update methods that await call_engine, return a coroutine + if method in [ + "update_weights_from_distributed", + "update_weights_from_disk", + "init_weights_update_group", + ]: + return self._async_call_engine_internal(worker_id, method, *args, **kwargs) + + return None + + async def _async_call_engine_internal(self, worker_id, method, *args, **kwargs): + await asyncio.sleep(0.001) + return None + + def delete_workers(self, role): + self.workers.clear() + + +class MockInferenceEngine: + @classmethod + def __module__(cls): + return "areal.tests.test_rollout_controller" + + @classmethod + def __name__(cls): + return "MockInferenceEngine" + + +class TestRolloutControllerInitialization: + def test_constructor(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + assert controller.config == config + assert controller.scheduler == scheduler + assert controller.workers == [] + assert controller._current_worker_idx == 0 + assert controller._version == 0 + assert controller.runner is None + assert controller.executor is None + assert controller.staleness_manager is None + + def test_initialize_creates_workers(self): + config = InferenceEngineConfig( + consumer_batch_size=16, + max_head_offpolicyness=2, + enable_rollout_tracing=False, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + assert len(controller.workers) == 2 + assert controller.runner is not None + assert controller.executor is not None + assert controller.staleness_manager is not None + + controller.destroy() + + def test_initialize_creates_staleness_manager(self): + config = InferenceEngineConfig( + consumer_batch_size=32, + max_head_offpolicyness=5, + max_concurrent_rollouts=100, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + assert controller.staleness_manager.max_concurrent_rollouts == 100 + assert controller.staleness_manager.consumer_batch_size == 32 + assert controller.staleness_manager.max_staleness == 5 + + controller.destroy() + + def test_initialize_uses_consumer_batch_size_as_fallback(self): + config = InferenceEngineConfig( + consumer_batch_size=64, + max_head_offpolicyness=3, + max_concurrent_rollouts=None, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + assert controller.staleness_manager.max_concurrent_rollouts == 64 + + controller.destroy() + + def test_initialize_with_tracing_enabled(self): + config = InferenceEngineConfig( + consumer_batch_size=16, + max_head_offpolicyness=2, + enable_rollout_tracing=True, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + assert controller.runner.enable_tracing is True + + controller.destroy() + + +class 
TestRolloutControllerDestroy: + def test_destroy_cleans_up_resources(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + assert controller.runner is not None + assert controller.executor is not None + assert len(controller.workers) > 0 + + controller.destroy() + + assert controller.runner is None + assert controller.executor is None + assert len(controller.workers) == 0 + + def test_destroy_deletes_workers_via_scheduler(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + assert len(scheduler.workers) == 2 + + controller.destroy() + + assert len(scheduler.workers) == 0 + + def test_destroy_handles_scheduler_error(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + scheduler.delete_workers = Mock(side_effect=Exception("Test error")) + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.destroy() + + +class TestRolloutControllerCapacity: + def test_get_capacity_initial_state(self): + config = InferenceEngineConfig( + consumer_batch_size=16, + max_concurrent_rollouts=32, + max_head_offpolicyness=2, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + capacity = controller.get_capacity() + assert capacity == 32 + + controller.destroy() + + def test_get_capacity_uses_version(self): + config = InferenceEngineConfig( + consumer_batch_size=8, + max_concurrent_rollouts=1000, + max_head_offpolicyness=2, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + capacity_v0 = controller.get_capacity() + + controller.set_version(5) + capacity_v5 = controller.get_capacity() + + assert capacity_v5 > capacity_v0 + + controller.destroy() + + +class TestRolloutControllerWorkerSelection: + def test_choose_worker_round_robin(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d3") + controller.initialize(alloc_mode=alloc_mode) + + worker_ids = [] + for _ in range(6): + worker = controller._choose_worker() + worker_ids.append(worker.id) + + assert worker_ids[0] == "rollout/0" + assert worker_ids[1] == "rollout/1" + assert worker_ids[2] == "rollout/2" + assert worker_ids[3] == "rollout/0" + assert worker_ids[4] == "rollout/1" + assert worker_ids[5] == "rollout/2" + + controller.destroy() + + +class TestRolloutControllerSubmitAndWait: + def test_submit_adds_to_pending_queue(self): + config = 
InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + data = {"test": "data"} + controller.submit( + data, workflow_path="areal.tests.utils.TestWorkflow", workflow_kwargs={} + ) + + assert len(controller._pending_inputs) == 1 + assert controller._pending_inputs[0].data == data + + controller.destroy() + + def test_submit_multiple_requests(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + for i in range(5): + controller.submit( + {"id": i}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + assert len(controller._pending_inputs) == 5 + + controller.destroy() + + def test_wait_returns_distributed_batch(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=50 + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + for i in range(3): + controller.submit( + {"id": i}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + batch = controller.wait(count=3, timeout=5.0) + + assert isinstance(batch, DistributedBatchMemory) + assert len(batch) == 3 + + controller.destroy() + + def test_wait_timeout_when_insufficient_results(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=10 + ) + scheduler = MockScheduler() + + async def slow_workflow(*args, **kwargs): + await asyncio.sleep(10.0) + return None + + scheduler.async_call_engine = slow_workflow + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.submit( + {"id": 0}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + with pytest.raises(TimeoutError, match="Timed out waiting for"): + controller.wait(count=1, timeout=0.2) + + controller.destroy() + + def test_wait_handles_rejected_rollouts(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=20 + ) + scheduler = MockScheduler() + + call_count = 0 + + async def mixed_results(*args, **kwargs): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) + if call_count % 2 == 0: + return None + return { + "input_ids": torch.randint(0, 100, (1, 10)), + "attention_mask": torch.ones(1, 10, dtype=torch.bool), + "loss_mask": torch.ones(1, 10, dtype=torch.bool), + "rewards": torch.randn(1), + } + + scheduler.async_call_engine = mixed_results + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + for i in range(6): + controller.submit( + {"id": i}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + batch = controller.wait(count=3, timeout=2.0) + assert len(batch) == 
3 + + controller.destroy() + + +class TestRolloutControllerBatchOperations: + def test_rollout_batch_submits_all_data(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=50 + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + batch_data = [{"id": i, "value": f"item_{i}"} for i in range(4)] + batch = controller.rollout_batch( + batch_data, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + assert isinstance(batch, DistributedBatchMemory) + assert len(batch) == 4 + + controller.destroy() + + def test_rollout_batch_waits_for_all_results(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=100 + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + batch_data = [{"id": i} for i in range(10)] + batch = controller.rollout_batch( + batch_data, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + assert len(batch) == 10 + + controller.destroy() + + +class TestRolloutControllerVersionManagement: + def test_get_version_initial(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + assert controller.get_version() == 0 + + def test_set_version_updates_controller_version(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + controller.set_version(42) + assert controller.get_version() == 42 + + controller.destroy() + + def test_set_version_calls_workers(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + controller.set_version(10) + + version_calls = [ + call for call in scheduler.engine_calls if call[1] == "set_version" + ] + assert len(version_calls) == 2 + + controller.destroy() + + def test_set_version_handles_worker_error(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + + def failing_call(*args, **kwargs): + raise Exception("Worker error") + + scheduler.call_engine = failing_call + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.set_version(5) + + +class TestRolloutControllerWeightUpdates: + def test_init_weights_update_group_returns_future(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + 
controller.initialize(alloc_mode=alloc_mode) + + meta = WeightUpdateMeta(type="disk", path="/tmp/test") + future = controller.init_weights_update_group(meta) + + assert isinstance(future, Future) + future.result(timeout=5.0) + + controller.destroy() + + def test_update_weights_from_distributed_returns_future(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + meta = WeightUpdateMeta(type="disk", path="/tmp/test") + param_specs = [ParamSpec(name="test", shape=(10, 10), dtype="float32")] + future = controller.update_weights_from_distributed(meta, param_specs) + + assert isinstance(future, Future) + future.result(timeout=5.0) + + controller.destroy() + + def test_update_weights_from_disk_returns_future(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + meta = WeightUpdateMeta(type="disk", path="/tmp/test") + future = controller.update_weights_from_disk(meta) + + assert isinstance(future, Future) + + controller.destroy() + + +class TestRolloutControllerLifecycle: + def test_pause_calls_all_workers(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d3") + controller.initialize(alloc_mode=alloc_mode) + + controller.pause() + + pause_calls = [call for call in scheduler.engine_calls if call[1] == "pause"] + assert len(pause_calls) == 3 + + controller.destroy() + + def test_resume_calls_all_workers(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d3") + controller.initialize(alloc_mode=alloc_mode) + + controller.resume() + + resume_calls = [call for call in scheduler.engine_calls if call[1] == "resume"] + assert len(resume_calls) == 3 + + controller.destroy() + + def test_pause_handles_worker_error(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + + def failing_call(*args, **kwargs): + raise Exception("Worker error") + + scheduler.call_engine = failing_call + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.pause() + + def test_resume_handles_worker_error(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + + def failing_call(*args, **kwargs): + raise Exception("Worker error") + + scheduler.call_engine = failing_call + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + controller.resume() + + +class TestRolloutControllerAgenerate: + def 
test_agenerate_chooses_worker(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + req = ModelRequest(input_ids=[1, 2, 3, 4, 5]) + + async def test_agenerate(): + result = await controller.agenerate(req) + return result + + asyncio.run(test_agenerate()) + + agenerate_calls = [ + call for call in scheduler.engine_calls if call[1] == "agenerate" + ] + assert len(agenerate_calls) == 1 + assert agenerate_calls[0][3]["req"] == req + + controller.destroy() + + def test_agenerate_round_robin(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d3") + controller.initialize(alloc_mode=alloc_mode) + + async def test_multiple_agenerate(): + for _ in range(6): + req = ModelRequest(input_ids=[1, 2, 3]) + await controller.agenerate(req) + + asyncio.run(test_multiple_agenerate()) + + agenerate_calls = [ + call for call in scheduler.engine_calls if call[1] == "agenerate" + ] + worker_ids = [call[0] for call in agenerate_calls] + + assert worker_ids[0] == "rollout/0" + assert worker_ids[1] == "rollout/1" + assert worker_ids[2] == "rollout/2" + assert worker_ids[3] == "rollout/0" + + controller.destroy() + + +class TestRolloutControllerErrorHandling: + def test_commit_raises_on_queue_full(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + with patch.object( + controller.runner, "submit", side_effect=TaskQueueFullError("Queue full") + ): + controller.submit( + {"id": 0}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + with pytest.raises(queue.Full): + controller._commit_one_to_runner() + + controller.destroy() + + def test_wait_returns_empty_batch_on_no_results(self): + config = InferenceEngineConfig( + consumer_batch_size=16, max_concurrent_rollouts=50 + ) + scheduler = MockScheduler() + + async def reject_all(*args, **kwargs): + await asyncio.sleep(0.01) + return None + + scheduler.async_call_engine = reject_all + + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + with pytest.raises(TimeoutError): + controller.wait(count=1, timeout=0.5) + + controller.destroy() + + +class TestRolloutControllerIntegration: + def test_end_to_end_workflow(self): + config = InferenceEngineConfig( + consumer_batch_size=8, + max_concurrent_rollouts=20, + max_head_offpolicyness=2, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d2") + controller.initialize(alloc_mode=alloc_mode) + + capacity = controller.get_capacity() + assert capacity == 20 + + for i in range(5): + controller.submit( + {"id": i}, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + + batch = 
controller.wait(count=5, timeout=5.0) + assert len(batch) == 5 + + controller.set_version(1) + assert controller.get_version() == 1 + + controller.destroy() + + def test_multiple_batch_cycles(self): + config = InferenceEngineConfig( + consumer_batch_size=4, + max_concurrent_rollouts=50, + max_head_offpolicyness=5, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + for cycle in range(3): + batch_data = [{"id": i, "cycle": cycle} for i in range(4)] + batch = controller.rollout_batch( + batch_data, + workflow_path="areal.tests.utils.TestWorkflow", + workflow_kwargs={}, + ) + assert len(batch) == 4 + + controller.destroy() + + +class TestRolloutControllerNotImplemented: + def test_register_callback_not_implemented(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + with pytest.raises(NotImplementedError): + controller.register_callback_to_all_worker("test", lambda: None) + + def test_abort_all_requests_not_implemented(self): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + with pytest.raises(NotImplementedError): + controller.abort_all_requests() + + +@pytest.mark.parametrize("num_workers", [1, 2, 4]) +def test_parametrized_worker_count(num_workers): + config = InferenceEngineConfig(consumer_batch_size=16) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str(f"sglang.d{num_workers}") + controller.initialize(alloc_mode=alloc_mode) + + assert len(controller.workers) == num_workers + + controller.destroy() + + +@pytest.mark.parametrize( + "consumer_batch_size,max_concurrent_rollouts,expected_capacity", + [(16, 32, 32), (32, 64, 64), (8, 100, 24)], +) +def test_parametrized_capacity_settings( + consumer_batch_size, max_concurrent_rollouts, expected_capacity +): + config = InferenceEngineConfig( + consumer_batch_size=consumer_batch_size, + max_concurrent_rollouts=max_concurrent_rollouts, + max_head_offpolicyness=2, + ) + scheduler = MockScheduler() + controller = RolloutController( + inf_engine=MockInferenceEngine, + config=config, + scheduler=scheduler, + ) + + alloc_mode = AllocationMode.from_str("sglang.d1") + controller.initialize(alloc_mode=alloc_mode) + + capacity = controller.get_capacity() + assert capacity == expected_capacity + + controller.destroy() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/areal/tests/test_serialization.py b/areal/tests/test_serialization.py new file mode 100644 index 000000000..76b7836cb --- /dev/null +++ b/areal/tests/test_serialization.py @@ -0,0 +1,648 @@ +"""Pytest test suite for tensor and dataclass serialization utilities.""" + +from dataclasses import dataclass + +import pytest +import torch + +from areal.scheduler.rpc.serialization import ( + SerializedDataclass, + SerializedTensor, + deserialize_value, + serialize_value, +) + + +# Test dataclasses +@dataclass +class SimpleConfig: + """Simple test dataclass.""" + + batch_size: int + learning_rate: float + name: str + + +@dataclass +class 
ConfigWithTensor: + """Dataclass containing a tensor field.""" + + data: torch.Tensor + label: str + + +@dataclass +class NestedConfig: + """Dataclass containing another dataclass.""" + + inner: SimpleConfig + outer_value: int + + +class TestSerializedTensor: + """Test suite for SerializedTensor Pydantic model.""" + + def test_from_tensor_float32(self): + """Test serialization of float32 tensor.""" + tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + serialized = SerializedTensor.from_tensor(tensor) + + assert serialized.type == "tensor" + assert serialized.shape == [3] + assert serialized.dtype == "torch.float32" + + def test_from_tensor_various_dtypes(self): + """Test serialization of tensors with various dtypes.""" + dtypes = [ + torch.float32, + torch.float64, + torch.int32, + torch.int64, + torch.bool, + torch.uint8, + ] + + for dtype in dtypes: + tensor = torch.tensor([1, 2, 3], dtype=dtype) + serialized = SerializedTensor.from_tensor(tensor) + assert serialized.dtype == str(dtype) + + def test_from_tensor_various_shapes(self): + """Test serialization of tensors with various shapes.""" + shapes = [ + (), # scalar + (5,), # 1D + (3, 4), # 2D + (2, 3, 4), # 3D + (2, 3, 4, 5), # 4D + ] + + for shape in shapes: + tensor = torch.randn(shape) + serialized = SerializedTensor.from_tensor(tensor) + assert serialized.shape == list(shape) + + def test_from_tensor_with_requires_grad(self): + """Test serialization ignores requires_grad flag.""" + tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + serialized = SerializedTensor.from_tensor(tensor) + # Serialization should work but requires_grad is not preserved + assert serialized.type == "tensor" + + def test_roundtrip_float32(self): + """Test serialize-deserialize roundtrip for float32 tensor.""" + original = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + serialized = SerializedTensor.from_tensor(original) + reconstructed = serialized.to_tensor() + + assert torch.allclose(original, reconstructed) + assert reconstructed.dtype == original.dtype + assert reconstructed.shape == original.shape + + def test_roundtrip_various_dtypes(self): + """Test roundtrip for various dtypes.""" + test_cases = [ + (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), torch.float32), + (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64), torch.float64), + (torch.tensor([1, 2, 3], dtype=torch.int32), torch.int32), + (torch.tensor([1, 2, 3], dtype=torch.int64), torch.int64), + (torch.tensor([True, False, True], dtype=torch.bool), torch.bool), + ] + + for original, expected_dtype in test_cases: + serialized = SerializedTensor.from_tensor(original) + reconstructed = serialized.to_tensor() + + assert reconstructed.dtype == expected_dtype + if expected_dtype == torch.bool: + assert torch.equal(original, reconstructed) + else: + assert torch.allclose(original.float(), reconstructed.float()) + + def test_roundtrip_ignores_requires_grad(self): + """Test roundtrip does not preserve requires_grad.""" + original = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + serialized = SerializedTensor.from_tensor(original) + reconstructed = serialized.to_tensor() + + # requires_grad is not preserved + assert reconstructed.requires_grad is False + assert torch.allclose(original.detach(), reconstructed) + + def test_empty_tensor(self): + """Test serialization of empty tensor.""" + tensor = torch.tensor([]) + serialized = SerializedTensor.from_tensor(tensor) + reconstructed = serialized.to_tensor() + + assert reconstructed.shape == torch.Size([0]) + 
assert torch.equal(tensor, reconstructed) + + def test_large_tensor(self): + """Test serialization of large tensor.""" + tensor = torch.randn(100, 100) + serialized = SerializedTensor.from_tensor(tensor) + reconstructed = serialized.to_tensor() + + assert torch.allclose(tensor, reconstructed) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_cuda_tensor(self): + """Test serialization of CUDA tensor (moves to CPU).""" + tensor = torch.tensor([1.0, 2.0, 3.0]).cuda() + serialized = SerializedTensor.from_tensor(tensor) + + # Serialization works with CUDA tensors + assert serialized.type == "tensor" + + # Reconstructed tensor is always on CPU + reconstructed = serialized.to_tensor() + assert reconstructed.device.type == "cpu" + assert torch.allclose(tensor.cpu(), reconstructed) + + +class TestSerializeValue: + """Test suite for serialize_value function.""" + + def test_serialize_none(self): + """Test serialization of None.""" + assert serialize_value(None) is None + + def test_serialize_primitives(self): + """Test serialization of primitive types.""" + assert serialize_value(42) == 42 + assert serialize_value(3.14) == 3.14 + assert serialize_value("hello") == "hello" + assert serialize_value(True) is True + + def test_serialize_tensor(self): + """Test serialization of torch tensor.""" + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = serialize_value(tensor) + + assert isinstance(result, dict) + assert result["type"] == "tensor" + assert "data" in result + assert "shape" in result + assert "dtype" in result + + def test_serialize_list_of_primitives(self): + """Test serialization of list of primitives.""" + original = [1, 2, 3, "hello", True] + result = serialize_value(original) + + assert result == original + + def test_serialize_list_of_tensors(self): + """Test serialization of list of tensors.""" + tensors = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] + result = serialize_value(tensors) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(item["type"] == "tensor" for item in result) + + def test_serialize_dict_of_primitives(self): + """Test serialization of dict with primitives.""" + original = {"a": 1, "b": "hello", "c": 3.14} + result = serialize_value(original) + + assert result == original + + def test_serialize_dict_of_tensors(self): + """Test serialization of dict with tensors.""" + tensors = { + "input_ids": torch.tensor([1, 2, 3]), + "attention_mask": torch.tensor([1, 1, 1]), + } + result = serialize_value(tensors) + + assert isinstance(result, dict) + assert all(result[key]["type"] == "tensor" for key in result) + + def test_serialize_nested_dict(self): + """Test serialization of nested dictionary.""" + nested = { + "level1": {"level2": {"tensor": torch.tensor([1.0, 2.0]), "value": 42}} + } + result = serialize_value(nested) + + assert isinstance(result, dict) + assert isinstance(result["level1"], dict) + assert isinstance(result["level1"]["level2"], dict) + assert result["level1"]["level2"]["tensor"]["type"] == "tensor" + assert result["level1"]["level2"]["value"] == 42 + + def test_serialize_mixed_structure(self): + """Test serialization of complex mixed structure.""" + mixed = { + "tensors": [torch.tensor([1.0]), torch.tensor([2.0])], + "metadata": {"batch_size": 32, "device": "cpu"}, + "mask": torch.tensor([True, False, True]), + } + result = serialize_value(mixed) + + assert isinstance(result["tensors"], list) + assert result["tensors"][0]["type"] == "tensor" + assert result["metadata"]["batch_size"] == 
32 + assert result["mask"]["type"] == "tensor" + + def test_serialize_tuple(self): + """Test serialization of tuple (converts to list).""" + original = (1, 2, torch.tensor([3.0])) + result = serialize_value(original) + + assert isinstance(result, list) + assert result[0] == 1 + assert result[1] == 2 + assert result[2]["type"] == "tensor" + + +class TestDeserializeValue: + """Test suite for deserialize_value function.""" + + def test_deserialize_none(self): + """Test deserialization of None.""" + assert deserialize_value(None) is None + + def test_deserialize_primitives(self): + """Test deserialization of primitive types.""" + assert deserialize_value(42) == 42 + assert deserialize_value(3.14) == 3.14 + assert deserialize_value("hello") == "hello" + assert deserialize_value(True) is True + + def test_deserialize_tensor(self): + """Test deserialization of serialized tensor.""" + tensor = torch.tensor([1.0, 2.0, 3.0]) + serialized = serialize_value(tensor) + result = deserialize_value(serialized) + + assert isinstance(result, torch.Tensor) + assert torch.allclose(tensor, result) + + def test_deserialize_list_of_primitives(self): + """Test deserialization of list of primitives.""" + original = [1, 2, 3, "hello", True] + result = deserialize_value(original) + + assert result == original + + def test_deserialize_list_of_tensors(self): + """Test deserialization of list of tensors.""" + tensors = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] + serialized = serialize_value(tensors) + result = deserialize_value(serialized) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, torch.Tensor) for item in result) + assert torch.allclose(tensors[0], result[0]) + assert torch.allclose(tensors[1], result[1]) + + def test_deserialize_dict_of_tensors(self): + """Test deserialization of dict with tensors.""" + original = { + "input_ids": torch.tensor([1, 2, 3]), + "attention_mask": torch.tensor([1, 1, 1]), + } + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result, dict) + assert all(isinstance(result[key], torch.Tensor) for key in result) + assert torch.equal(original["input_ids"], result["input_ids"]) + assert torch.equal(original["attention_mask"], result["attention_mask"]) + + def test_deserialize_nested_structure(self): + """Test deserialization of nested structure.""" + original = { + "level1": { + "tensor": torch.tensor([1.0, 2.0]), + "value": 42, + } + } + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result["level1"]["tensor"], torch.Tensor) + assert result["level1"]["value"] == 42 + assert torch.allclose(original["level1"]["tensor"], result["level1"]["tensor"]) + + def test_deserialize_invalid_tensor_dict(self): + """Test deserialization handles invalid tensor dict gracefully.""" + # Dict with type="tensor" but missing required fields + invalid = {"type": "tensor", "invalid_field": "value"} + result = deserialize_value(invalid) + + # Should treat as regular dict if parsing fails + assert isinstance(result, dict) + assert result["type"] == "tensor" + + +class TestRoundtrip: + """Test suite for full serialize-deserialize roundtrips.""" + + def test_roundtrip_simple_tensor(self): + """Test roundtrip for simple tensor.""" + original = torch.tensor([1.0, 2.0, 3.0]) + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert torch.allclose(original, result) + + def test_roundtrip_trajectory_dict(self): + """Test 
roundtrip for typical trajectory dictionary.""" + trajectory = { + "input_ids": torch.tensor([101, 102, 103, 104]), + "attention_mask": torch.tensor([1, 1, 1, 1]), + "rewards": torch.tensor([0.1, 0.2, 0.3, 0.4]), + "logprobs": torch.tensor([-1.0, -2.0, -3.0, -4.0]), + } + + serialized = serialize_value(trajectory) + result = deserialize_value(serialized) + + assert isinstance(result, dict) + assert set(result.keys()) == set(trajectory.keys()) + for key in trajectory: + assert torch.allclose(trajectory[key].float(), result[key].float()) + + def test_roundtrip_mixed_types(self): + """Test roundtrip for mixed type structure.""" + original = { + "tensors": [torch.tensor([1.0]), torch.tensor([2.0])], + "metadata": {"count": 2, "name": "test"}, + "value": 42, + "flag": True, + } + + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert len(result["tensors"]) == 2 + assert torch.allclose(original["tensors"][0], result["tensors"][0]) + assert result["metadata"] == original["metadata"] + assert result["value"] == original["value"] + assert result["flag"] == original["flag"] + + def test_roundtrip_with_none_values(self): + """Test roundtrip with None values in structure.""" + original = { + "tensor": torch.tensor([1.0, 2.0]), + "optional": None, + "nested": {"value": 42, "empty": None}, + } + + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert torch.allclose(original["tensor"], result["tensor"]) + assert result["optional"] is None + assert result["nested"]["value"] == 42 + assert result["nested"]["empty"] is None + + def test_roundtrip_empty_structures(self): + """Test roundtrip for empty structures.""" + test_cases = [ + {}, # Empty dict + [], # Empty list + {"empty_list": [], "empty_dict": {}}, # Nested empty + ] + + for original in test_cases: + serialized = serialize_value(original) + result = deserialize_value(serialized) + assert result == original + + +class TestSerializedDataclass: + """Test suite for SerializedDataclass Pydantic model.""" + + def test_from_dataclass_simple(self): + """Test serialization of simple dataclass.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + serialized = SerializedDataclass.from_dataclass(config) + + assert serialized.type == "dataclass" + assert "SimpleConfig" in serialized.class_path + assert serialized.data["batch_size"] == 32 + assert serialized.data["learning_rate"] == 0.001 + assert serialized.data["name"] == "test" + + def test_to_dataclass_simple(self): + """Test deserialization of simple dataclass.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + serialized = SerializedDataclass.from_dataclass(config) + dataclass_type, data = serialized.to_dataclass() + + reconstructed = dataclass_type(**data) + assert isinstance(reconstructed, SimpleConfig) + assert reconstructed.batch_size == 32 + assert reconstructed.learning_rate == 0.001 + assert reconstructed.name == "test" + + def test_roundtrip_simple_dataclass(self): + """Test serialize-deserialize roundtrip for simple dataclass.""" + original = SimpleConfig(batch_size=64, learning_rate=0.01, name="experiment") + serialized = SerializedDataclass.from_dataclass(original) + dataclass_type, data = serialized.to_dataclass() + reconstructed = dataclass_type(**data) + + assert reconstructed.batch_size == original.batch_size + assert reconstructed.learning_rate == original.learning_rate + assert reconstructed.name == original.name + + +class TestSerializeValueDataclass: + """Test 
suite for serialize_value with dataclasses.""" + + def test_serialize_simple_dataclass(self): + """Test serialization of simple dataclass.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + result = serialize_value(config) + + assert isinstance(result, dict) + assert result["type"] == "dataclass" + assert "SimpleConfig" in result["class_path"] + assert result["data"]["batch_size"] == 32 + + def test_serialize_dataclass_with_tensor(self): + """Test serialization of dataclass containing tensor.""" + config = ConfigWithTensor(data=torch.tensor([1.0, 2.0, 3.0]), label="example") + result = serialize_value(config) + + assert result["type"] == "dataclass" + assert result["data"]["label"] == "example" + # Tensor should be serialized within dataclass + assert result["data"]["data"]["type"] == "tensor" + + def test_serialize_nested_dataclass(self): + """Test serialization of nested dataclass.""" + inner = SimpleConfig(batch_size=16, learning_rate=0.01, name="inner") + outer = NestedConfig(inner=inner, outer_value=42) + result = serialize_value(outer) + + assert result["type"] == "dataclass" + assert "NestedConfig" in result["class_path"] + # Inner dataclass should also be serialized + assert result["data"]["inner"]["type"] == "dataclass" + assert result["data"]["inner"]["data"]["batch_size"] == 16 + assert result["data"]["outer_value"] == 42 + + def test_serialize_list_of_dataclasses(self): + """Test serialization of list containing dataclasses.""" + configs = [ + SimpleConfig(batch_size=32, learning_rate=0.001, name="config1"), + SimpleConfig(batch_size=64, learning_rate=0.002, name="config2"), + ] + result = serialize_value(configs) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(item["type"] == "dataclass" for item in result) + assert result[0]["data"]["batch_size"] == 32 + assert result[1]["data"]["batch_size"] == 64 + + def test_serialize_dict_with_dataclass_values(self): + """Test serialization of dict with dataclass values.""" + data = { + "config": SimpleConfig(batch_size=32, learning_rate=0.001, name="test"), + "value": 42, + } + result = serialize_value(data) + + assert result["config"]["type"] == "dataclass" + assert result["config"]["data"]["batch_size"] == 32 + assert result["value"] == 42 + + +class TestDeserializeValueDataclass: + """Test suite for deserialize_value with dataclasses.""" + + def test_deserialize_simple_dataclass(self): + """Test deserialization of simple dataclass.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + serialized = serialize_value(config) + result = deserialize_value(serialized) + + assert isinstance(result, SimpleConfig) + assert result.batch_size == 32 + assert result.learning_rate == 0.001 + assert result.name == "test" + + def test_deserialize_dataclass_with_tensor(self): + """Test deserialization of dataclass containing tensor.""" + original_tensor = torch.tensor([1.0, 2.0, 3.0]) + config = ConfigWithTensor(data=original_tensor, label="example") + serialized = serialize_value(config) + result = deserialize_value(serialized) + + assert isinstance(result, ConfigWithTensor) + assert result.label == "example" + assert isinstance(result.data, torch.Tensor) + assert torch.allclose(original_tensor, result.data) + + def test_deserialize_nested_dataclass(self): + """Test deserialization of nested dataclass.""" + inner = SimpleConfig(batch_size=16, learning_rate=0.01, name="inner") + outer = NestedConfig(inner=inner, outer_value=42) + serialized = serialize_value(outer) + result = 
deserialize_value(serialized) + + assert isinstance(result, NestedConfig) + assert isinstance(result.inner, SimpleConfig) + assert result.inner.batch_size == 16 + assert result.inner.learning_rate == 0.01 + assert result.outer_value == 42 + + def test_deserialize_list_of_dataclasses(self): + """Test deserialization of list containing dataclasses.""" + configs = [ + SimpleConfig(batch_size=32, learning_rate=0.001, name="config1"), + SimpleConfig(batch_size=64, learning_rate=0.002, name="config2"), + ] + serialized = serialize_value(configs) + result = deserialize_value(serialized) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, SimpleConfig) for item in result) + assert result[0].batch_size == 32 + assert result[1].batch_size == 64 + + +class TestRoundtripDataclass: + """Test suite for full serialize-deserialize roundtrips with dataclasses.""" + + def test_roundtrip_simple_dataclass(self): + """Test roundtrip for simple dataclass.""" + original = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert result.batch_size == original.batch_size + assert result.learning_rate == original.learning_rate + assert result.name == original.name + + def test_roundtrip_dataclass_with_tensor(self): + """Test roundtrip for dataclass with tensor field.""" + original = ConfigWithTensor(data=torch.tensor([1.0, 2.0, 3.0]), label="test") + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result, ConfigWithTensor) + assert result.label == original.label + assert torch.allclose(original.data, result.data) + + def test_roundtrip_nested_dataclass(self): + """Test roundtrip for nested dataclass.""" + inner = SimpleConfig(batch_size=16, learning_rate=0.01, name="inner") + original = NestedConfig(inner=inner, outer_value=42) + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result, NestedConfig) + assert isinstance(result.inner, SimpleConfig) + assert result.inner.batch_size == original.inner.batch_size + assert result.outer_value == original.outer_value + + def test_roundtrip_mixed_dataclass_and_tensor(self): + """Test roundtrip for structure with both dataclasses and tensors.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + original = { + "config": config, + "tensor": torch.tensor([1.0, 2.0, 3.0]), + "metadata": {"count": 3, "type": "experiment"}, + } + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result["config"], SimpleConfig) + assert result["config"].batch_size == 32 + assert isinstance(result["tensor"], torch.Tensor) + assert torch.allclose(original["tensor"], result["tensor"]) + assert result["metadata"] == original["metadata"] + + def test_roundtrip_list_of_mixed_types(self): + """Test roundtrip for list containing dataclasses, tensors, and primitives.""" + config = SimpleConfig(batch_size=32, learning_rate=0.001, name="test") + original = [ + config, + torch.tensor([1.0, 2.0]), + 42, + "string", + ] + serialized = serialize_value(original) + result = deserialize_value(serialized) + + assert isinstance(result[0], SimpleConfig) + assert result[0].batch_size == 32 + assert isinstance(result[1], torch.Tensor) + assert torch.allclose(original[1], result[1]) + assert result[2] == 42 + assert result[3] == "string" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff 
--git a/areal/tests/test_train_controller.py b/areal/tests/test_train_controller.py new file mode 100644 index 000000000..5ca2fa7f2 --- /dev/null +++ b/areal/tests/test_train_controller.py @@ -0,0 +1,881 @@ +"""Unit tests for TrainController. + +Tests cover initialization, worker management, batch operations, +RPC wrappers, PPO/SFT methods, weight management, and error handling. +""" + +import asyncio +from unittest.mock import Mock, patch + +import pytest +import torch + +from areal.api.alloc_mode import ParallelStrategy +from areal.api.cli_args import SchedulingSpec, SchedulingStrategy, TrainEngineConfig +from areal.api.engine_api import TrainEngine +from areal.api.io_struct import ( + AllocationMode, + FinetuneSpec, + SaveLoadMeta, + WeightUpdateMeta, +) +from areal.api.scheduler_api import Worker +from areal.controller.batch import DistributedBatchMemory +from areal.controller.train_controller import TrainController + + +class MockTrainEngine(TrainEngine): + """Mock TrainEngine for testing.""" + + @classmethod + def __module__(cls): + return "areal.tests.test_train_controller" + + @classmethod + def __name__(cls): + return "MockTrainEngine" + + +class MockScheduler: + """Mock Scheduler for testing TrainController.""" + + def __init__(self): + self.workers = [] + self.call_count = 0 + self.engine_calls = [] + self.deleted_roles = [] + + def create_workers(self, job): + """Create mock workers based on job configuration.""" + worker_ids = [f"{job.role}/{i}" for i in range(job.replicas)] + self.workers = [ + Worker( + id=wid, + ip="127.0.0.1", + worker_ports=["8000", "8001"], + engine_ports=["9000", "9001"], + ) + for wid in worker_ids + ] + return worker_ids + + def get_workers(self, role, timeout=None): + """Return list of workers for the given role.""" + return self.workers + + async def create_engine(self, worker_id, engine, config): + """Mock engine creation.""" + await asyncio.sleep(0.001) + return None + + async def async_call_engine(self, worker_id, method, *args, **kwargs): + """Mock async engine call.""" + self.engine_calls.append((worker_id, method, args, kwargs)) + self.call_count += 1 + + # Return appropriate mock results based on method + if method == "is_data_parallel_head": + # First worker in each DP group is the head + worker_idx = int(worker_id.split("/")[-1]) + return worker_idx % 2 == 0 # Every other worker is a DP head + + elif method == "get_version": + return 1 + + elif method == "train_batch": + return {"loss": 0.5, "lr": 0.001, "grad_norm": 1.0} + + elif method == "eval_batch": + return torch.tensor(0.3) + + elif method == "forward": + return {"logits": torch.randn(4, 10, 50257)} + + elif method == "compute_logp": + return {"log_probs": torch.randn(4, 10)} + + elif method == "compute_advantages": + return {"advantages": torch.randn(4)} + + elif method == "ppo_update": + return {"ppo_loss": 0.2, "kl_div": 0.01} + + elif method == "train_lm": + return {"lm_loss": 0.4, "perplexity": 1.5} + + elif method == "evaluate_lm": + return torch.tensor(0.35) + + await asyncio.sleep(0.001) + return None + + def delete_workers(self, role): + """Mock worker deletion.""" + self.deleted_roles.append(role) + self.workers.clear() + + +# ==================== FIXTURES ==================== + + +@pytest.fixture +def mock_scheduler(): + """Provide a MockScheduler instance.""" + return MockScheduler() + + +@pytest.fixture +def train_config(): + """Provide a TrainEngineConfig for testing.""" + return TrainEngineConfig( + scheduling_spec=SchedulingSpec(cpu=4, gpu=1, mem=16000, port_count=2) 
+ ) + + +@pytest.fixture +def alloc_mode(): + """Provide an AllocationMode for testing.""" + mode = AllocationMode.from_str("d4t2p1") + return mode + + +@pytest.fixture +def parallel_strategy(): + """Provide a ParallelStrategy for testing.""" + return ParallelStrategy( + data_parallel_size=4, tensor_parallel_size=2, pipeline_parallel_size=1 + ) + + +@pytest.fixture +def ft_spec(): + """Provide a FinetuneSpec for testing.""" + return FinetuneSpec(total_train_epochs=10, dataset_size=1000, train_batch_size=32) + + +@pytest.fixture +def scheduling_strategy(): + """Provide a SchedulingStrategy for testing.""" + return SchedulingStrategy(type="separation", target="") + + +@pytest.fixture +def train_controller(mock_scheduler, train_config): + """Provide a TrainController instance.""" + return TrainController( + train_engine=MockTrainEngine, config=train_config, scheduler=mock_scheduler + ) + + +def create_mock_distributed_batch(size=4, seq_len=10): + """Create a mock DistributedBatch for testing.""" + data = { + "input_ids": torch.randint(0, 100, (size, seq_len)), + "attention_mask": torch.ones(size, seq_len, dtype=torch.bool), + "loss_mask": torch.ones(size, seq_len, dtype=torch.bool), + } + return DistributedBatchMemory.from_dict(data) + + +# ==================== TEST CLASSES ==================== + + +class TestTrainControllerInitialization: + """Tests for TrainController initialization and setup.""" + + def test_constructor(self, mock_scheduler, train_config): + """Test TrainController constructor.""" + controller = TrainController( + train_engine=MockTrainEngine, config=train_config, scheduler=mock_scheduler + ) + + assert controller.train_engine == MockTrainEngine + assert controller.config == train_config + assert controller.scheduler == mock_scheduler + assert controller.workers == [] + assert controller.worker_is_dp_head == [] + assert controller.logger is None + + def test_initialize( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test initialize method creates workers and engines.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + # Verify workers were created + assert len(train_controller.workers) == alloc_mode.train.world_size + assert train_controller._worker_role == "train_worker" + assert train_controller.alloc_mode == alloc_mode + + # Verify DP heads were identified + assert len(train_controller.worker_is_dp_head) == len(train_controller.workers) + + # Verify scheduler was called + assert train_controller.scheduler.call_count > 0 + + def test_create_process_group_sets_parallel_strategy( + self, train_controller, parallel_strategy + ): + """Test that create_process_group correctly assigns parallel_strategy. + + This is a regression test for the bug at line 79 where parallel_strategy + was being assigned to itself instead of the parameter. 
+ """ + # Setup: Add mock workers + train_controller.workers = [Mock(), Mock()] + train_controller.parallel_strategy = parallel_strategy + + # Call create_process_group with a different strategy + new_strategy = ParallelStrategy( + data_parallel_size=8, tensor_parallel_size=1, pipeline_parallel_size=1 + ) + + with patch.object(train_controller, "custom_function_call") as mock_custom_call: + train_controller.create_process_group(new_strategy) + + # Verify the parallel_strategy should be updated to new_strategy + # NOTE: This test currently fails due to the bug at line 79 + # After fixing the bug, this assertion should pass + assert train_controller.parallel_strategy == new_strategy + assert train_controller.parallel_strategy.data_parallel_size == 8 + + # Verify custom_function_call was invoked + mock_custom_call.assert_called_once_with( + "create_process_group", new_strategy + ) + + def test_identify_dp_heads( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test _identify_dp_heads correctly identifies DP head workers.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + # MockScheduler returns True for even-indexed workers + for idx, is_head in enumerate(train_controller.worker_is_dp_head): + assert is_head == (idx % 2 == 0) + + +class TestTrainControllerDestroy: + """Tests for TrainController cleanup and destruction.""" + + def test_destroy(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): + """Test destroy method cleans up resources.""" + # Initialize first + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + initial_worker_count = len(train_controller.workers) + assert initial_worker_count > 0 + + # Call destroy + train_controller.destroy() + + # Verify cleanup + assert len(train_controller.workers) == 0 + assert len(train_controller.worker_is_dp_head) == 0 + assert "train_worker" in train_controller.scheduler.deleted_roles + + def test_destroy_handles_errors( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test destroy handles errors gracefully.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + # Make delete_workers raise an exception + def raise_error(role): + raise RuntimeError("Simulated error") + + train_controller.scheduler.delete_workers = raise_error + + # Should not raise, just log the error + train_controller.destroy() + + # Workers should still be cleared + assert len(train_controller.workers) == 0 + + +class TestTrainControllerBatchOperations: + """Tests for batch splitting and alignment operations.""" + + def test_align_batches_with_dp_rebalance( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test _align_batches_with_dp with rebalance=True.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=16) + chunks = train_controller._align_batches_with_dp(batch, rebalance=True) + + # Should split into dp_size chunks + assert len(chunks) == alloc_mode.train.dp_size + + # Each chunk should be a DistributedBatch + for chunk in chunks: + assert isinstance(chunk, DistributedBatchMemory) + + def test_align_batches_with_dp_no_rebalance( + self, 
train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test _align_batches_with_dp with rebalance=False.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=16) + chunks = train_controller._align_batches_with_dp(batch, rebalance=False) + + # Should split into dp_size chunks + assert len(chunks) == alloc_mode.train.dp_size + + # Each chunk should be a DistributedBatch + for chunk in chunks: + assert isinstance(chunk, DistributedBatchMemory) + + +class TestTrainControllerMergeResults: + """Tests for result merging from workers.""" + + def test_merge_results_with_tensor_dict(self, train_controller): + """Test _merge_results with dictionary of tensors.""" + results = [ + {"loss": torch.tensor([0.5, 0.6])}, + {"loss": torch.tensor([0.3, 0.4])}, + ] + + merged = train_controller._merge_results(results, "train_batch") + + # Should concatenate into DistributedBatch + assert isinstance(merged, DistributedBatchMemory) + assert "loss" in merged.get_data() + + def test_merge_results_with_empty_dict(self, train_controller): + """Test _merge_results with empty dictionaries.""" + results = [{}, {}] + + merged = train_controller._merge_results(results, "some_method") + + # Should return empty DistributedBatch + assert isinstance(merged, DistributedBatchMemory) + assert len(merged.get_data()) == 0 + + def test_merge_results_with_non_tensor(self, train_controller): + """Test _merge_results with non-tensor results.""" + results = [{"status": "ok"}, {"status": "ok"}] + + merged = train_controller._merge_results(results, "some_method") + + # Should return first result (already synchronized) + assert merged == {"status": "ok"} + + def test_merge_results_accepts_method_parameter(self, train_controller): + """Test that _merge_results accepts method parameter. + + This is a regression test for the bug at line 279 where the method + parameter was missing from the signature. 
+ """ + results = [torch.tensor(0.5), torch.tensor(0.3)] + + # This should work without TypeError + try: + result = train_controller._merge_results(results, "train_batch") + # Test passes if no exception + assert result is not None + except TypeError as e: + if "missing" in str(e) and "required positional argument" in str(e): + pytest.fail(f"_merge_results missing required parameter: {e}") + + +class TestTrainControllerRPCWrappers: + """Tests for RPC wrapper methods.""" + + def test_train_mode( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test train() method sets training mode.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + result = train_controller.train(mode=True) + + # Should return self for chaining + assert result is train_controller + + # Verify custom_function_call was invoked + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "train" in engine_calls + + def test_eval_mode( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test eval() method sets evaluation mode.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + result = train_controller.eval() + + # Should return self for chaining + assert result is train_controller + + # Verify train(False) was called + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "train" in engine_calls + + def test_forward(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): + """Test forward() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + result = train_controller.forward(batch) + + # Should return merged results from DP heads + assert result is not None + + # Verify forward was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "forward" in engine_calls + + def test_train_batch( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test train_batch() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + + def loss_fn(output, batch_data): + return torch.tensor(0.5) + + def loss_weight_fn(batch_data): + return torch.tensor(1.0) + + result = train_controller.train_batch(batch, loss_fn, loss_weight_fn) + + # Should return stats dictionary + assert isinstance(result, dict) + + # Verify train_batch was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "train_batch" in engine_calls + + def test_eval_batch( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test eval_batch() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + + def loss_fn(output, batch_data): + return torch.tensor(0.3) + + def loss_weight_fn(batch_data): + return torch.tensor(1.0) + + result = train_controller.eval_batch(batch, loss_fn, loss_weight_fn) + + # Should return loss tensor or merged results + assert result is not None 
+ + # Verify eval_batch was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "eval_batch" in engine_calls + + def test_step_lr_scheduler( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test step_lr_scheduler() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + train_controller.step_lr_scheduler() + + # Verify step_lr_scheduler was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "step_lr_scheduler" in engine_calls + + +class TestTrainControllerPPOMethods: + """Tests for PPO-specific methods.""" + + def test_compute_logp( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test compute_logp() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + result = train_controller.compute_logp() + + # Should return merged results + assert result is not None + + # Verify compute_logp was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "compute_logp" in engine_calls + + def test_compute_advantages( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test compute_advantages() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + result = train_controller.compute_advantages() + + # Should return merged results + assert result is not None + + # Verify compute_advantages was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "compute_advantages" in engine_calls + + def test_ppo_update( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test ppo_update() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + result = train_controller.ppo_update(batch) + + # Should return stats dictionary + assert isinstance(result, dict) + + # Verify ppo_update was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "ppo_update" in engine_calls + + +class TestTrainControllerSFTMethods: + """Tests for SFT-specific methods.""" + + def test_train_lm(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): + """Test train_lm() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + result = train_controller.train_lm(batch) + + # Should return stats dictionary + assert isinstance(result, dict) + + # Verify train_lm was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "train_lm" in engine_calls + + def test_evaluate_lm( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test evaluate_lm() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + result = train_controller.evaluate_lm(batch) + + # 
Should return loss tensor or merged results + assert result is not None + + # Verify evaluate_lm was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "evaluate_lm" in engine_calls + + +class TestTrainControllerWeightManagement: + """Tests for weight management operations.""" + + def test_set_version( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test set_version() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + train_controller.set_version(42) + + # Verify set_version was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "set_version" in engine_calls + + def test_get_version( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test get_version() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + version = train_controller.get_version() + + # Should return version number + assert isinstance(version, int) + + # Verify get_version was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "get_version" in engine_calls + + def test_update_weights( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test update_weights() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + meta = WeightUpdateMeta(type="disk", path="/tmp/weights") + train_controller.update_weights(meta) + + # Verify update_weights was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "update_weights" in engine_calls + + def test_save(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): + """Test save() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + meta = SaveLoadMeta( + path="/tmp/checkpoint", weight_format="safetensors", with_optim=True + ) + train_controller.save(meta) + + # Verify save was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "save" in engine_calls + + def test_load(self, train_controller, alloc_mode, ft_spec, scheduling_strategy): + """Test load() method.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + meta = SaveLoadMeta( + path="/tmp/checkpoint", weight_format="safetensors", with_optim=True + ) + train_controller.load(meta) + + # Verify load was called on engines + engine_calls = [call[1] for call in train_controller.scheduler.engine_calls] + assert "load" in engine_calls + + def test_connect_engine_raises_not_implemented(self, train_controller): + """Test connect_engine() raises NotImplementedError.""" + from areal.api.engine_api import InferenceEngine + + mock_engine = Mock(spec=InferenceEngine) + meta = WeightUpdateMeta(type="nccl", alloc_mode=AllocationMode.from_str("d4")) + + with pytest.raises(NotImplementedError) as exc_info: + train_controller.connect_engine(mock_engine, meta) + + assert "not implemented for TrainController" in str(exc_info.value) + assert "RolloutController" in str(exc_info.value) + + 
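+# ---------------------------------------------------------------------------
+# Reading aid (hedged sketch, not the actual TrainController code): the tests
+# in this module assume that custom_function_call dispatches a DistributedBatch
+# roughly as follows: split it into dp_size chunks, call every worker (DP heads
+# receive a chunk, the others an empty payload), keep only the DP-head results,
+# and merge them via _merge_results. Names mirror the mocks defined above; the
+# asyncio plumbing is elided for brevity.
+#
+#   chunks = controller._align_batches_with_dp(batch, rebalance=True)
+#   head_chunks = iter(chunks)
+#   results = []
+#   for worker, is_head in zip(controller.workers, controller.worker_is_dp_head):
+#       payload = next(head_chunks) if is_head else None
+#       out = await controller.scheduler.async_call_engine(worker.id, method, input_=payload)
+#       if is_head:
+#           results.append(out)
+#   merged = controller._merge_results(results, method)
+# ---------------------------------------------------------------------------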
+class TestTrainControllerCustomFunctionCall: + """Tests for custom_function_call orchestration.""" + + def test_custom_function_call_with_distributed_batch( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test custom_function_call with DistributedBatch argument.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + # Clear previous calls from initialization + train_controller.scheduler.engine_calls = [] + + batch = create_mock_distributed_batch(size=16) + result = train_controller.custom_function_call("forward", input_=batch) + + # Should split batch across DP groups and call only DP heads + assert result is not None + + # Count how many workers were called + worker_calls = len(train_controller.scheduler.engine_calls) + + # Should call all workers (DP heads get data, others get empty) + assert worker_calls == len(train_controller.workers) + + def test_custom_function_call_with_regular_args( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test custom_function_call with non-DistributedBatch arguments.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + # Clear previous calls + train_controller.scheduler.engine_calls = [] + + result = train_controller.custom_function_call("set_version", 5) + + # set_version returns None, which is expected - just verify it doesn't crash + # The key test is that all workers were called + assert result is None + + # Verify all workers were called + assert len(train_controller.scheduler.engine_calls) == len( + train_controller.workers + ) + + def test_custom_function_call_filters_dp_heads( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test custom_function_call only returns results from DP heads.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + batch = create_mock_distributed_batch(size=8) + train_controller.custom_function_call("train_batch", input_=batch) + + # Results should only come from DP head workers + # (verified by _merge_results receiving filtered results) + + +class TestTrainControllerEdgeCases: + """Tests for edge cases and error handling.""" + + def test_empty_distributed_batch( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test handling of empty DistributedBatch.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + empty_batch = DistributedBatchMemory.from_dict({}) + result = train_controller.forward(empty_batch) + + # Should handle empty batch gracefully + assert result is not None + + def test_create_process_group_requires_workers( + self, train_controller, parallel_strategy + ): + """Test create_process_group asserts workers exist.""" + # Don't initialize, so workers list is empty + with pytest.raises(AssertionError, match="Workers are not created"): + train_controller.create_process_group(parallel_strategy) + + def test_method_chaining( + self, train_controller, alloc_mode, ft_spec, scheduling_strategy + ): + """Test that train() and eval() support method chaining.""" + train_controller.initialize( + role="train_worker", + alloc_mode=alloc_mode, + ft_spec=ft_spec, + scheduling_strategy=scheduling_strategy, + ) + + # Should be able to 
chain calls + result = train_controller.train().eval().train() + assert result is train_controller diff --git a/areal/tests/utils.py b/areal/tests/utils.py index 1666eff49..3d61b45ed 100644 --- a/areal/tests/utils.py +++ b/areal/tests/utils.py @@ -1,5 +1,13 @@ +import asyncio import os +import random +from typing import Any +import torch + +from areal.api.engine_api import InferenceEngine +from areal.api.workflow_api import RolloutWorkflow +from areal.experimental.openai.types import InteractionWithTokenLogpReward from areal.utils import logging logger = logging.getLogger("areal.tests.utils") @@ -26,3 +34,28 @@ def get_bool_env_var(name: str, default: str = "false") -> bool: def is_in_ci(): return get_bool_env_var("AREAL_IS_IN_CI") + + +class TestWorkflow(RolloutWorkflow): + async def arun_episode( + self, engine: InferenceEngine, data: dict[str, Any] + ) -> dict[str, Any] | None | dict[str, InteractionWithTokenLogpReward]: + await asyncio.sleep(0.1) + prompt_len = random.randint(2, 8) + gen_len = random.randint(2, 8) + seqlen = prompt_len + gen_len + return dict( + input_ids=torch.randint( + 0, + 100, + ( + 1, + seqlen, + ), + ), + attention_mask=torch.ones(1, seqlen, dtype=torch.bool), + loss_mask=torch.tensor( + [0] * prompt_len + [1] * gen_len, dtype=torch.bool + ).unsqueeze(0), + rewards=torch.randn(1), + ) diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index b4cc1d4c8..51e9031f7 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -36,13 +36,13 @@ def init(self): if dist.is_initialized() and dist.get_rank() != 0: return + self.start_time = time.perf_counter() + # wandb init, connect to remote wandb host if self.config.wandb.wandb_base_url: os.environ["WANDB_API_KEY"] = self.config.wandb.wandb_api_key if self.config.wandb.wandb_api_key: os.environ["WANDB_BASE_URL"] = self.config.wandb.wandb_base_url - self.start_time = time.perf_counter() - # wandb init, connect to remote wandb host if self.config.wandb.mode != "disabled": wandb.login() diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index dc090d39a..084c6a343 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -1,4 +1,5 @@ import asyncio +import importlib import os import uuid from collections.abc import Callable @@ -47,7 +48,7 @@ def __init__( self, reward_fn: Callable[..., Any], gconfig: GenerationHyperparameters, - tokenizer: PreTrainedTokenizerFast, + tokenizer: PreTrainedTokenizerFast | str, enable_thinking: bool = False, rollout_stat_scope: str = "rollout", dump_dir: str | None = None, @@ -138,10 +139,26 @@ async def _collect_samples( async def arun_episode( self, engine: InferenceEngine, data: dict[str, Any] ) -> dict[str, torch.Tensor]: + # NOTE: tokenizer and reward_fn are not jsonifiable for remote execution, + # so we need to load it eagerly during execution. 
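+ # Example of the string form (see examples/single-controller/gsm8k_grpo.py
+ # added in this PR): the controller submits the workflow by dotted path with
+ # plain-string specs, e.g.
+ #     actor.rollout_batch(
+ #         data,
+ #         workflow_path="areal.workflow.rlvr.RLVRWorkflow",
+ #         workflow_kwargs=dict(
+ #             reward_fn="areal.reward.gsm8k.gsm8k_reward_fn",
+ #             tokenizer=config.tokenizer_path,  # HF path string, not a tokenizer object
+ #             gconfig=config.gconfig,
+ #             enable_thinking=False,
+ #         ),
+ #     )
+ # so both fields may reach this method as strings and are resolved lazily below.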
+ if isinstance(self.tokenizer, str): + from areal.utils.hf_utils import load_hf_tokenizer + + tokenizer = load_hf_tokenizer(self.tokenizer) + if tokenizer.pad_token_id not in self.gconfig.stop_token_ids: + self.gconfig.stop_token_ids.append(tokenizer.pad_token_id) + if tokenizer.eos_token_id not in self.gconfig.stop_token_ids: + self.gconfig.stop_token_ids.append(tokenizer.eos_token_id) + self.tokenizer = tokenizer + + if isinstance(self.reward_fn, str): + module_path, fname = self.reward_fn.rsplit(".", 1) + module = importlib.import_module(module_path) + self.reward_fn = getattr(module, fname) + self.async_reward_fn = AsyncRewardWrapper(self.reward_fn) + input_ids = self.get_input_ids_fn( - self.data_extract_prompt_fn(data), - self.tokenizer, - self.enable_thinking, + self.data_extract_prompt_fn(data), self.tokenizer, self.enable_thinking ) n_samples = self.gconfig.n_samples req = ModelRequest( diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 8de0add4a..1dbcb2600 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -75,6 +75,8 @@ For detailed examples, see the experiment configurations in the `examples/` dire - [MegatronEngine Configuration](section-megatron-engine) - [PerfTracer Configuration](section-perf-tracer) - [Scheduler Configuration](section-scheduler) +- [Scheduling Specification](section-scheduling) +- [SchedulingStrategy](section-scheduling-strategy) - [SessionTracer Configuration](section-session-tracer) ______________________________________________________________________ @@ -316,57 +318,59 @@ Configuration for model optimization during training. Configuration for PPO actor model, a subclass of a TrainEngine. -| Parameter | Type | Default | Description | -| --------------------------- | ------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. 
| -| `weight_update_mode` | string | `"disk"` | - | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `group_size` | integer | `1` | Number of sequences in each group | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `temperature` | float | `1.0` | Temperature during generation. | -| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | -| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | -| `dynamic_sampling` | boolean | `False` | Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. Note that enabling this option will lead to variable batch sizes. If you want to use a constant batch size with dynamic filtering, you should use the `should_accept_fn` parameter in `prepare_batch`. 
| -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| Parameter | Type | Default | Description | +| --------------------------- | --------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"disk"` | - | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | - | +| `group_size` | integer | `1` | Number of sequences in each group | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `temperature` | float | `1.0` | Temperature during generation. | +| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. 
| +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | +| `behav_imp_weight_cap` | float \| None | `None` | Filter out tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing loss. Must be > 1.0. use_decoupled_loss must be true. | +| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | +| `dynamic_sampling` | boolean | `False` | Enable dynamic sampling (within DAPO). If enabled, groups with the same reward will be masked out. Note that enabling this option will lead to variable batch sizes. If you want to use a constant batch size with dynamic filtering, you should use the `should_accept_fn` parameter in `prepare_batch`. | +| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | +| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | +| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | (section-ppo-critic)= @@ -374,32 +378,34 @@ Configuration for PPO actor model, a subclass of a TrainEngine. Configuration for PPO critic model, a subclass of a TrainEngine. -| Parameter | Type | Default | Description | -| ------------------------ | ------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. 
Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"disk"` | - | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.5` | Clipping factor for value loss | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"disk"` | - | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. 
| +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | - | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.5` | Clipping factor for value loss | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | (section-train-engine)= @@ -407,29 +413,31 @@ Configuration for PPO critic model, a subclass of a TrainEngine. Core configuration for model training, including optimization and backend settings. -| Parameter | Type | Default | Description | -| ------------------------ | ------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. **Choices:** `flash_attention_2` | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"disk"` | - | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| Parameter | Type | Default | Description | +| ------------------------ | --------------------------------------------------- | --------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. 
**Choices:** `flash_attention_2` | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"disk"` | - | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | train engine schedule specs | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | - | (section-generation-hyperparameters)= @@ -457,21 +465,23 @@ Controls text generation behavior for rollout. Configuration for inference servers, including offpolicyness control. -| Parameter | Type | Default | Description | -| ------------------------- | --------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string \| None | `None` | - | -| `trial_name` | string \| None | `None` | - | -| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | -| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | -| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | -| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | -| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | -| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | -| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | -| `setup_timeout` | float | `120.0` | Timeout in seconds of connecting to remote servers or launching local servers. | -| `request_timeout` | float | `3600` | Timeout for HTTP requests. | -| `request_retries` | integer | `3` | Number of retries for failed requests. 
| -| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| Parameter | Type | Default | Description | +| ------------------------- | --------------------------------------------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `experiment_name` | string \| None | `None` | - | +| `trial_name` | string \| None | `None` | - | +| `max_concurrent_rollouts` | integer \| None | `None` | Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size. | +| `queue_size` | integer \| None | `None` | Input/Output queue size for async rollout. | +| `consumer_batch_size` | integer | `1` | Batch size for consuming rollouts from the queue. | +| `max_head_offpolicyness` | integer | `0` | Maximum off-policyness for the head. If the current version is more than this many versions behind, the request will not be accepted. | +| `enable_rollout_tracing` | boolean | `False` | Whether to output verbose tracing messages for each generation request. | +| `check_trajectory_format` | boolean | `False` | Whether to check the format of produced trajectories of a customized workflow. Useful when debugging the workflow in isolation. Should be False during RL training. | +| `schedule_policy` | string | `"round_robin"` | Request scheduling policy **Choices:** `round_robin` | +| `setup_timeout` | float | `120.0` | Timeout in seconds of connecting to remote servers or launching local servers. | +| `request_timeout` | float | `3600` | Timeout for HTTP requests. | +| `request_retries` | integer | `3` | Number of retries for failed requests. | +| `pause_grace_period` | float | `0.0` | The grace period after calling /pause_generation. Wait until all requests have been dropped. | +| `scheduling_spec` | [`SchedulingSpec`](section-scheduling) | **Required** | inference engine schedule specs | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | - | (section-sg-lang)= @@ -832,6 +842,40 @@ Configuration for worker scheduling. Used in the single-controller mode. Experim | `reward_model_path` | string | `""` | - | | `reward_model_service_url` | string | `"http://localhost:30000/classify"` | - | +(section-scheduling)= + +## Scheduling Specification + +Configuration class: SchedulingSpec + +| Parameter | Type | Default | Description | +| ------------ | -------------- | ------------ | ------------------------------------------------------------------------ | +| `cpu` | integer | `0` | Number of CPU cores required | +| `gpu` | integer | `0` | Number of GPU units required | +| `mem` | integer | `0` | Amount of memory (GB) required | +| `port_count` | integer | `2` | Number of ports to expose | +| `image` | string | `""` | Docker/Singularity container image to use | +| `type` | string | `"worker"` | Task type (e.g., worker, engine) **Choices:** `worker`, `engine` | +| `env_vars` | `dict` | **Required** | Environment variables for the container | +| `cmd` | string \| None | `None` | Command to execute inside the container. Defaults to AReaL's RPC server. 
| +| `nodelist` | string \| None | `None` | - | +| `exclude` | string \| None | `None` | - | +| `partition` | string \| None | `None` | - | +| `time_limit` | string \| None | `None` | - | +| `begin` | string \| None | `None` | - | +| `deadline` | string \| None | `None` | - | + +(section-scheduling-strategy)= + +## SchedulingStrategy + +Configuration class: SchedulingStrategy + +| Parameter | Type | Default | Description | +| --------- | -------------- | -------------- | ----------------------------------------- | +| `type` | string | `"separation"` | - **Choices:** `separation`, `colocation` | +| `target` | string \| None | `None` | The target role to be colocated with | + (section-session-tracer)= ## SessionTracer Configuration diff --git a/examples/single-controller/gsm8k_grpo.py b/examples/single-controller/gsm8k_grpo.py new file mode 100644 index 000000000..415028966 --- /dev/null +++ b/examples/single-controller/gsm8k_grpo.py @@ -0,0 +1,216 @@ +import os +import sys + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import GRPOConfig, SGLangConfig, load_expr_config +from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta +from areal.controller.rollout_controller import RolloutController +from areal.controller.train_controller import TrainController +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.scheduler.local import LocalScheduler +from areal.utils import stats_tracker +from areal.utils.data import ( + cycle_dataloader, +) +from areal.utils.dataloader import create_dataloader +from areal.utils.device import log_gpu_stats +from areal.utils.evaluator import Evaluator +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.recover import RecoverHandler +from areal.utils.saver import Saver +from areal.utils.stats_logger import StatsLogger + + +def main(args): + config, _ = load_expr_config(args, GRPOConfig) + config: GRPOConfig + + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + # Create dataset and dataloaders + train_dataset = get_custom_dataset( + split="train", dataset_config=config.train_dataset, tokenizer=tokenizer + ) + + train_dataloader = create_dataloader( + train_dataset, + rank=0, + world_size=1, + dataset_config=config.train_dataset, + ) + + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + + # Initialize scheduler + scheduler = LocalScheduler(exp_config=config) + + # Initialize train controller + allocation_mode = AllocationMode.from_str(config.allocation_mode) + actor = TrainController(FSDPPPOActor, config=config.actor, scheduler=scheduler) + actor.initialize( + role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None + ) + + # Initialize inference engine + rollout = RolloutController( + RemoteSGLangEngine, config=config.rollout, scheduler=scheduler + ) + rollout.initialize( + role="rollout", + alloc_mode=allocation_mode, + engine_args=SGLangConfig.build_args( + sglang_config=config.sglang, + tp_size=allocation_mode.gen.tp_size, + base_gpu_id=0, + ), + ) + + weight_update_meta = WeightUpdateMeta.from_disk( + experiment_name=config.experiment_name, + trial_name=config.trial_name, + file_root=config.cluster.fileroot, + ) + actor.connect_engine(rollout, weight_update_meta) + + ref = None + if config.actor.kl_ctl > 0 and config.ref is 
not None: + ref = TrainController(FSDPPPOActor, config=config.ref, scheduler=scheduler) + ref.initialize( + role="ref", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None + ) + + # Run training. + saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + + try: + recover_info = recover_handler.load( + actor, + saver, + evaluator, + stats_logger, + train_dataloader, + inference_engine=rollout, + weight_update_meta=weight_update_meta, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + steps_per_epoch = len(train_dataloader) + max_steps = total_epochs * steps_per_epoch + + data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): + epoch = global_step // steps_per_epoch + step = global_step % steps_per_epoch + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=steps_per_epoch, + ) + + with stats_tracker.record_timing("rollout"): + if config.async_training: + batch = actor.prepare_batch( + train_dataloader, + workflow_path="areal.workflow.rlvr.RLVRWorkflow", + workflow_kwargs=dict( + reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), + "generated", + ), + ), + ) + else: + batch = actor.rollout_batch( + next(data_generator), + workflow_path="areal.workflow.rlvr.RLVRWorkflow", + workflow_kwargs=dict( + reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), + "generated", + ), + ), + ) + + if config.actor.recompute_logprob or config.actor.use_decoupled_loss: + with stats_tracker.record_timing("recompute_logp"): + logp = actor.compute_logp(batch) + batch["prox_logp"] = logp + log_gpu_stats("recompute logp") + + if ref is not None: + with stats_tracker.record_timing("ref_logp"): + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") + + with stats_tracker.record_timing("compute_advantage"): + batch = actor.compute_advantages(batch) + log_gpu_stats("compute advantages") + + with stats_tracker.record_timing("train_step"): + actor.ppo_update(batch) + actor.step_lr_scheduler() + log_gpu_stats("ppo update") + + # pause inference for updating weights, save, and evaluation + rollout.pause() + + with stats_tracker.record_timing("update_weights"): + actor.update_weights(weight_update_meta) + + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) + + with stats_tracker.record_timing("save"): + saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + actor, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + # Upload statistics to the logger (e.g., wandb) + stats_logger.commit(epoch, step, global_step, actor.export_stats()) + + # Resume rollout + rollout.resume() + + finally: + stats_logger.close() + rollout.destroy() + if ref is not None: + ref.destroy() + actor.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/single-controller/gsm8k_grpo.yaml 
diff --git a/examples/single-controller/gsm8k_grpo.yaml b/examples/single-controller/gsm8k_grpo.yaml
new file mode 100644
index 000000000..63107f19b
--- /dev/null
+++ b/examples/single-controller/gsm8k_grpo.yaml
@@ -0,0 +1,171 @@
+experiment_name: gsm8k-grpo
+trial_name: trial0
+
+seed: 1
+total_train_epochs: 10
+tokenizer_path: ${actor.path}
+async_training: true
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
+  fileroot: /storage/openpsi/experiments
+  name_resolve:
+    type: nfs
+    nfs_record_root: /storage/openpsi/name_resolve
+
+allocation_mode: sglang.d4p1t1+d4p1t1
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 16
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+  scheduling_spec:
+    type: worker
+    port_count: 1
+    gpu: 1
+    cmd: python3 -m areal.scheduler.rpc.rpc_server
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_new_tokens: 1024
+  greedy: false
+  temperature: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /storage/openpsi/models/Qwen__Qwen2.5-1.5B-Instruct
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: false
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer:
+    type: adam
+    lr: 1.70e-5
+    weight_decay: 0.017
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  backend: fsdp
+  group_size: ${gconfig.n_samples}
+  eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behav_imp_weight_cap: 5.0
+  dynamic_sampling: false
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  max_new_tokens: ${gconfig.max_new_tokens}
+  scheduling_spec:
+    type: worker
+    port_count: 1
+    gpu: 1
+    cmd: python3 -m areal.scheduler.rpc.rpc_server
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer: null
+  backend: fsdp
+  scheduling_strategy:
+    type: colocation
+    target: actor
+  scheduling_spec:
+    type: worker
+    port_count: 1
+    gpu: 1
+    cmd: python3 -m areal.scheduler.rpc.rpc_server
+
+# SGLang
+sglang:
+  model_path: ${actor.path}
+  random_seed: ${seed}
+  skip_tokenizer_init: true
+  dtype: ${actor.dtype}
+  max_running_requests: null
+  context_length: 32768
+  mem_fraction_static: 0.8
+
+# datasets
+train_dataset:
+  batch_size: 16
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: /storage/openpsi/data/gsm8k
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 16
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: /storage/openpsi/data/gsm8k
+  type: rl
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+
+launcher:
+  inference_server_cpus_per_gpu: 4
+  inference_server_mem_per_gpu: 32768
+  trainer_cpus_per_gpu: 4
+  trainer_mem_per_gpu: 32768
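In the YAML above, every engine section carries a scheduling_spec, and the ref section additionally asks to be colocated with the actor. As a rough sketch, and only under the assumption that TrainController forwards these fields to the scheduler unchanged, such a section maps onto the Job structure introduced by this change roughly as follows (the replica count here is our guess from the d4p1t1 trainer allocation):

from areal.api.cli_args import SchedulingSpec, SchedulingStrategy
from areal.api.scheduler_api import Job

# Hypothetical wiring for the `ref` section above.
ref_job = Job(
    role="ref",
    replicas=4,  # assumption: one worker per data-parallel rank in d4p1t1
    tasks=[
        SchedulingSpec(
            type="worker",
            gpu=1,
            port_count=1,
            cmd="python3 -m areal.scheduler.rpc.rpc_server",
        )
    ],
    scheduling_strategy=SchedulingStrategy(type="colocation", target="actor"),
)
# A scheduler implementation would then place these workers on the machines
# already hosting the "actor" role, e.g. via scheduler.create_workers(ref_job).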
diff --git a/examples/single-controller/gsm8k_sft.py b/examples/single-controller/gsm8k_sft.py
new file mode 100644
index 000000000..6bf231874
--- /dev/null
+++ b/examples/single-controller/gsm8k_sft.py
@@ -0,0 +1,148 @@
+import sys
+
+from areal.api.alloc_mode import AllocationMode
+from areal.api.cli_args import SFTConfig, load_expr_config
+from areal.api.io_struct import FinetuneSpec, StepInfo
+from areal.controller.batch import DistributedBatchMemory
+from areal.controller.train_controller import TrainController
+from areal.dataset import get_custom_dataset
+from areal.engine.sft.lm_engine import FSDPLMEngine
+from areal.scheduler.local import LocalScheduler
+from areal.utils import logging, stats_tracker
+from areal.utils.data import (
+    pad_sequences_to_tensors,
+)
+from areal.utils.dataloader import create_dataloader
+from areal.utils.evaluator import Evaluator
+from areal.utils.hf_utils import load_hf_tokenizer
+from areal.utils.recover import RecoverHandler
+from areal.utils.saver import Saver
+from areal.utils.stats_logger import StatsLogger
+
+logger = logging.getLogger("Trainer")
+
+
+def main(args):
+    config, _ = load_expr_config(args, SFTConfig)
+    config: SFTConfig
+
+    tokenizer = load_hf_tokenizer(config.tokenizer_path)
+
+    # Create dataset and dataloaders
+    train_dataset = get_custom_dataset(
+        split="train", dataset_config=config.train_dataset, tokenizer=tokenizer
+    )
+    valid_dataset = get_custom_dataset(
+        split="test", dataset_config=config.valid_dataset, tokenizer=tokenizer
+    )
+    train_dataloader = create_dataloader(
+        train_dataset,
+        rank=0,
+        world_size=1,
+        dataset_config=config.train_dataset,
+        collate_fn=pad_sequences_to_tensors,
+    )
+    valid_dataloader = create_dataloader(
+        valid_dataset,
+        rank=0,
+        world_size=1,
+        dataset_config=config.valid_dataset,
+        collate_fn=pad_sequences_to_tensors,
+    )
+
+    ft_spec = FinetuneSpec(
+        total_train_epochs=config.total_train_epochs,
+        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
+        train_batch_size=config.train_dataset.batch_size,
+    )
+
+    # Initialize scheduler
+    scheduler = LocalScheduler(
+        fileroot=config.cluster.fileroot,
+        experiment_name=config.experiment_name,
+        trial_name=config.trial_name,
+    )
+    # Initialize train controller
+    allocation_mode = AllocationMode.from_str(config.allocation_mode)
+    engine = TrainController(FSDPLMEngine, config=config.model, scheduler=scheduler)
+    engine.initialize(
+        role="default",
+        alloc_mode=allocation_mode,
+        ft_spec=ft_spec,
+        addr=None,
+    )
+
+    saver = Saver(config.saver, ft_spec)
+    stats_logger = StatsLogger(config, ft_spec)
+    evaluator = Evaluator(config.evaluator, ft_spec)
+
+    recover_handler = RecoverHandler(config.recover, ft_spec)
+
+    try:
+        # Run training.
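+        # Try to resume from the latest recover checkpoint; start_step falls
+        # back to 0 when no checkpoint is found.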
+        recover_info = recover_handler.load(
+            engine,
+            saver,
+            evaluator,
+            stats_logger,
+            train_dataloader,
+        )
+        start_step = (
+            recover_info.last_step_info.next().global_step
+            if recover_info is not None
+            else 0
+        )
+
+        total_epochs = config.total_train_epochs
+
+        global_step = 0
+        for epoch in range(total_epochs):
+            for step, data in enumerate(train_dataloader):
+                if global_step < start_step:
+                    global_step += 1
+                    continue
+                step_info = StepInfo(
+                    global_step=global_step,
+                    epoch=epoch,
+                    epoch_step=step,
+                    steps_per_epoch=len(train_dataloader),
+                )
+
+                with (
+                    stats_tracker.record_timing("train_step"),
+                ):
+                    engine.train_lm(DistributedBatchMemory.from_dict(data))
+                    engine.step_lr_scheduler()
+
+                with stats_tracker.record_timing("save"):
+                    saver.save(engine, epoch, step, global_step, tokenizer=tokenizer)
+
+                with stats_tracker.record_timing("checkpoint_for_recover"):
+                    recover_handler.dump(
+                        engine,
+                        step_info,
+                        saver,
+                        evaluator,
+                        stats_logger,
+                        train_dataloader,
+                        tokenizer=tokenizer,
+                    )
+
+                with stats_tracker.record_timing("eval"):
+
+                    def evaluate_fn():
+                        for data in valid_dataloader:
+                            engine.evaluate_lm(DistributedBatchMemory.from_dict(data))
+
+                    evaluator.evaluate(evaluate_fn, epoch, step, global_step)
+
+                stats_logger.commit(epoch, step, global_step, engine.export_stats())
+                global_step += 1
+
+    finally:
+        stats_logger.close()
+        engine.destroy()
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
diff --git a/examples/single-controller/gsm8k_sft.yaml b/examples/single-controller/gsm8k_sft.yaml
new file mode 100644
index 000000000..ee3ea0726
--- /dev/null
+++ b/examples/single-controller/gsm8k_sft.yaml
@@ -0,0 +1,90 @@
+experiment_name: gsm8k-sft
+trial_name: trial0
+
+seed: 1
+total_train_epochs: 1
+tokenizer_path: ${model.path}
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
+  fileroot: /storage/openpsi/experiments
+  name_resolve:
+    type: nfs
+    nfs_record_root: /storage/openpsi/name_resolve
+
+allocation_mode: d8p1t1
+
+model:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /storage/openpsi/models/Qwen__Qwen3-1.7B
+  init_from_scratch: false
+  gradient_checkpointing: false
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 4096
+  optimizer:
+    type: adam
+    lr: 2e-5
+    weight_decay: 0.05
+    beta1: 0.9
+    beta2: 0.95
+    eps: 1e-5
+    lr_scheduler_type: cosine
+    gradient_clipping: 1.0
+  backend: fsdp
+  scheduling_spec:
+    type: worker
+    port_count: 1
+    gpu: 1
+    cmd: python3 -m areal.scheduler.rpc.rpc_server
+
+train_dataset:
+  batch_size: 128
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: /storage/openpsi/data/gsm8k
+  type: sft
+
+valid_dataset:
+  batch_size: 128
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: /storage/openpsi/data/gsm8k
+  type: sft
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: null
+  freq_steps: 1
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
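Both examples above drive training through TrainController, but the workers it creates can also be inspected or cleaned up through the scheduler interface added in this change. A small usage sketch, assuming LocalScheduler implements the abstract Scheduler API and that the controller does not already tear down its own workers (that division of responsibility is an assumption on our part):

# Sketch: inspecting the workers behind TrainController.initialize(role="default", ...).
workers = scheduler.get_workers("default", timeout=300)
for w in workers:
    # Each Worker records where it runs and which ports were allocated to its
    # RPC server and to the engine it hosts.
    print(w.id, w.ip, w.worker_ports, w.engine_ports)

# Explicitly stop the workers for this role once training is finished.
scheduler.delete_workers("default")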