diff --git a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile
index 61221f5983..5bf2b4a2e9 100644
--- a/scripts/docker/Dockerfile
+++ b/scripts/docker/Dockerfile
@@ -11,15 +11,9 @@ WORKDIR /workspace
 
 RUN apt update && apt install -y \
     build-essential \
-    curl \
-    git \
-    wget \
-    vim \
-    tmux \
-    python3 \
-    python3-pip \
-    python3-dev \
-    python3-packaging \
+    curl git wget vim tmux net-tools \
+    python3 python3-pip python3-dev python3-packaging \
+    libomp-dev infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \
     && rm -rf /var/lib/apt/lists/* \
     && ln -sf /usr/bin/python3 /usr/bin/python
 
diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py
index 9cdd99a592..1037c1cc73 100644
--- a/trinity/buffer/queue.py
+++ b/trinity/buffer/queue.py
@@ -89,7 +89,7 @@ def get_actor(cls, storage_config: StorageConfig, config: BufferConfig):
             ray.remote(cls)
             .options(
                 name=f"queue-{storage_config.name}",
-                namespace=ray.get_runtime_context().namespace,
+                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
             .remote(storage_config, config)
diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py
index ba736b02bc..33a9d8633b 100644
--- a/trinity/buffer/ray_wrapper.py
+++ b/trinity/buffer/ray_wrapper.py
@@ -57,7 +57,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
             ray.remote(cls)
             .options(
                 name=f"sql-{storage_config.name}",
-                namespace=ray.get_runtime_context().namespace,
+                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
             .remote(storage_config, config)
@@ -171,7 +171,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
             ray.remote(cls)
             .options(
                 name=f"json-{storage_config.name}",
-                namespace=ray.get_runtime_context().namespace,
+                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
             .remote(storage_config, config)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index e6130395b2..540b004db4 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -93,6 +93,9 @@ class StorageConfig:
     rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
     workflow_args: dict = field(default_factory=dict)
 
+    # get storage from existing experiment
+    ray_namespace: Optional[str] = None
+
     # ! DO NOT SET, automatically set from algorithm.algorithm_type
     algorithm_type: Optional[str] = None
 
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 0d2d3bf8c1..27f95c1383 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -315,6 +315,7 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times
         self.critic.ppo_mini_batch_size = config.buffer.batch_size
         self.critic.rollout_n = self.actor_rollout_ref.rollout.n
+        self.critic.synchronizer = config.synchronizer
 
         if config.trainer.actor_grad_clip is not None:
             self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 5e76375315..ecfa4ca3f3 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -21,6 +21,7 @@
 import os
 import warnings
 from dataclasses import asdict
+from datetime import timedelta
 
 import psutil
 import torch
@@ -96,6 +97,7 @@ def __init__(self, config: DictConfig, role: str):
                 backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl",
                 rank=rank,
                 world_size=world_size,
+                timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
             )
 
         # build device mesh for FSDP
@@ -832,7 +834,10 @@ def __init__(self, config):
         import torch.distributed
 
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl")
+            torch.distributed.init_process_group(
+                backend="nccl" if is_cuda_available else "hccl",
+                timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
+            )
         self.config = config
 
         # build device mesh for Ulysses Sequence Parallel
diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py
index 10f6f3b8bd..380e7d05b9 100644
--- a/trinity/utils/dlc_utils.py
+++ b/trinity/utils/dlc_utils.py
@@ -74,7 +74,7 @@ def setup_ray_cluster(namespace: str):
         ray.init(namespace=namespace, ignore_reinit_error=True)
     else:
         if is_master:
-            cmd = f"ray start --head --port={env_vars['MASTER_PORT']}"
+            cmd = f"ray start --head --port={env_vars['MASTER_PORT']} --node-ip-address={env_vars['MASTER_ADDR']}"
         else:
             cmd = f"ray start --address={env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}"
         ret = subprocess.run(cmd, shell=True, capture_output=True)
@@ -86,6 +86,7 @@
                 sys.exit(1)
 
         wait_for_ray_setup()
+        time.sleep(5)
         ray.init(
             address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
             namespace=namespace,
@@ -95,7 +96,7 @@
             # master wait for worker nodes to join
             wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
         else:
-            # woker wait on the cluster status actor
+            # worker wait on the cluster status actor
             cluster_status = (
                 ray.remote(ClusterStatus)
                 .options(
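
The namespace fallback added to the queue/SQL/JSON wrappers above can be tried in isolation with Ray's named-actor API. Below is a minimal, illustrative sketch, not part of the patch: the Counter actor, the get_counter() helper, and the exp-1/exp-2 namespace names are invented for the example, while the fallback expression mirrors the one in the diff (an explicitly configured namespace, cf. StorageConfig.ray_namespace, takes precedence over the caller's own namespace, and get_if_exists=True attaches to the actor if it already exists there).

from typing import Optional

import ray


@ray.remote
class Counter:
    # Stand-in for a buffer storage actor (queue/SQL/JSON wrapper).
    def __init__(self) -> None:
        self._value = 0

    def incr(self) -> int:
        self._value += 1
        return self._value


def get_counter(name: str, ray_namespace: Optional[str] = None):
    # Same fallback as the patch: prefer the configured namespace,
    # otherwise stay in the namespace the current job runs in.
    return Counter.options(
        name=f"queue-{name}",
        namespace=ray_namespace or ray.get_runtime_context().namespace,
        get_if_exists=True,
    ).remote()


if __name__ == "__main__":
    ray.init(namespace="exp-2", ignore_reinit_error=True)
    # With ray_namespace set, the handle points at the actor registered in
    # "exp-1" (creating it there if absent) rather than a fresh one in "exp-2".
    counter = get_counter("demo", ray_namespace="exp-1")
    print(ray.get(counter.incr.remote()))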