Commit 28de41a

Fix bugs in multi-node environments (#103)
1 parent 670e49a commit 28de41a

File tree: 7 files changed (+19, -15 lines)

scripts/docker/Dockerfile

Lines changed: 3 additions & 9 deletions
@@ -11,15 +11,9 @@ WORKDIR /workspace
 
 RUN apt update && apt install -y \
     build-essential \
-    curl \
-    git \
-    wget \
-    vim \
-    tmux \
-    python3 \
-    python3-pip \
-    python3-dev \
-    python3-packaging \
+    curl git wget vim tmux net-tools \
+    python3 python3-pip python3-dev python3-packaging \
+    libomp-dev infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \
     && rm -rf /var/lib/apt/lists/* \
     && ln -sf /usr/bin/python3 /usr/bin/python
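The consolidated apt line also pulls in the RDMA/InfiniBand userspace stack (libibverbs, librdmacm, rdma-core, perftest), which multi-node NCCL needs to use InfiniBand transports instead of falling back to TCP. As a quick sanity check, one could confirm the diagnostics are visible inside the container; the helper below is illustrative only and not part of this repo:

# Illustrative check, not from this commit: confirm InfiniBand diagnostics
# (installed via infiniband-diags above) report at least one active port.
import shutil
import subprocess

def rdma_port_active() -> bool:
    if shutil.which("ibstat") is None:
        return False  # diagnostics not installed; NCCL will fall back to TCP
    result = subprocess.run(["ibstat"], capture_output=True, text=True)
    return result.returncode == 0 and "State: Active" in result.stdout

print("RDMA ready:", rdma_port_active())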

trinity/buffer/queue.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def get_actor(cls, storage_config: StorageConfig, config: BufferConfig):
             ray.remote(cls)
             .options(
                 name=f"queue-{storage_config.name}",
-                namespace=ray.get_runtime_context().namespace,
+                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
             .remote(storage_config, config)
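This fix lets a run attach to a queue actor created by an earlier experiment: when `storage_config.ray_namespace` is set, the named actor is resolved in that namespace instead of the caller's own, and `get_if_exists=True` returns the existing actor rather than creating a duplicate. A minimal sketch of the same lookup pattern, with hypothetical actor and namespace names:

import ray

class DemoQueue:
    def __init__(self):
        self.items = []

    def put(self, item):
        self.items.append(item)

ray.init()
override_namespace = "earlier-experiment"  # plays the role of storage_config.ray_namespace

queue = (
    ray.remote(DemoQueue)
    .options(
        name="queue-demo",
        # Explicit override wins; None falls back to the current job's namespace.
        namespace=override_namespace or ray.get_runtime_context().namespace,
        get_if_exists=True,
    )
    .remote()
)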

trinity/buffer/ray_wrapper.py

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
             ray.remote(cls)
             .options(
                 name=f"sql-{storage_config.name}",
-                namespace=ray.get_runtime_context().namespace,
+                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
             .remote(storage_config, config)
@@ -171,7 +171,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
             ray.remote(cls)
             .options(
                 name=f"json-{storage_config.name}",
-                namespace=ray.get_runtime_context().namespace,
+                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
             .remote(storage_config, config)
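The same namespace fallback is applied uniformly to the SQL and JSON storage wrappers, so all three named storage actor types resolve in one consistent namespace; the sketch after the queue.py diff above illustrates the lookup.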

trinity/common/config.py

Lines changed: 3 additions & 0 deletions
@@ -93,6 +93,9 @@ class StorageConfig:
     rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
     workflow_args: dict = field(default_factory=dict)
 
+    # get storage from existing experiment
+    ray_namespace: Optional[str] = None
+
     # ! DO NOT SET, automatically set from algorithm.algorithm_type
     algorithm_type: Optional[str] = None
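With a `None` default, existing configs keep their current behavior; the field only needs to be set when a run should read storage owned by another experiment. A hypothetical usage sketch (the namespace value and any fields other than `ray_namespace` are assumed for illustration):

from trinity.common.config import StorageConfig

# Attach this run's buffer to storage created by a previous experiment
# that ran under the Ray namespace "exp-baseline" (hypothetical name).
storage = StorageConfig(
    name="rollout-buffer",
    ray_namespace="exp-baseline",  # default None keeps the caller's namespace
)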

trinity/common/verl_config.py

Lines changed: 1 addition & 0 deletions
@@ -315,6 +315,7 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times
         self.critic.ppo_mini_batch_size = config.buffer.batch_size
         self.critic.rollout_n = self.actor_rollout_ref.rollout.n
+        self.critic.synchronizer = config.synchronizer
 
         if config.trainer.actor_grad_clip is not None:
             self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip
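Plausibly, copying `config.synchronizer` onto the critic config exists so that a worker built from it can read `self.config.synchronizer.sync_timeout` when initializing its process group, as the fsdp_workers.py change below does.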

trinity/trainer/verl/fsdp_workers.py

Lines changed: 6 additions & 1 deletion
@@ -21,6 +21,7 @@
 import os
 import warnings
 from dataclasses import asdict
+from datetime import timedelta
 
 import psutil
 import torch
@@ -96,6 +97,7 @@ def __init__(self, config: DictConfig, role: str):
             backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl",
             rank=rank,
             world_size=world_size,
+            timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
         )
 
         # build device mesh for FSDP
@@ -832,7 +834,10 @@ def __init__(self, config):
         import torch.distributed
 
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl")
+            torch.distributed.init_process_group(
+                backend="nccl" if is_cuda_available else "hccl",
+                timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
+            )
         self.config = config
 
         # build device mesh for Ulysses Sequence Parallel
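Both init sites now honor `synchronizer.sync_timeout` rather than PyTorch's default process-group timeout (commonly 10 or 30 minutes, depending on version and backend), which multi-node weight synchronization can exceed. A hedged standalone sketch of the pattern, with illustrative names; the commit itself reads the value from its synchronizer config:

from datetime import timedelta

import torch.distributed as dist

def init_with_sync_timeout(sync_timeout_s: int = 3600) -> None:
    # Assumes torchrun-style env vars (RANK, WORLD_SIZE, MASTER_ADDR, ...)
    # so init_process_group can use the default env:// rendezvous.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl",  # the commit picks "hccl" on NPU platforms
            timeout=timedelta(seconds=sync_timeout_s),
        )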

trinity/utils/dlc_utils.py

Lines changed: 3 additions & 2 deletions
@@ -74,7 +74,7 @@ def setup_ray_cluster(namespace: str):
         ray.init(namespace=namespace, ignore_reinit_error=True)
     else:
         if is_master:
-            cmd = f"ray start --head --port={env_vars['MASTER_PORT']}"
+            cmd = f"ray start --head --port={env_vars['MASTER_PORT']} --node-ip-address={env_vars['MASTER_ADDR']}"
         else:
             cmd = f"ray start --address={env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}"
         ret = subprocess.run(cmd, shell=True, capture_output=True)
@@ -86,6 +86,7 @@ def setup_ray_cluster(namespace: str):
             sys.exit(1)
 
         wait_for_ray_setup()
+        time.sleep(5)
         ray.init(
             address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
             namespace=namespace,
@@ -95,7 +96,7 @@ def setup_ray_cluster(namespace: str):
             # master wait for worker nodes to join
             wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
         else:
-            # woker wait on the cluster status actor
+            # worker wait on the cluster status actor
             cluster_status = (
                 ray.remote(ClusterStatus)
                 .options(
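On multi-homed nodes, Ray may autodetect a different interface than the `MASTER_ADDR` that workers dial, so pinning `--node-ip-address` makes the head advertise the address workers actually connect to; the added `time.sleep(5)` gives the freshly started head a short grace period before `ray.init` attaches. A hedged standalone sketch of the command construction (env-var handling is assumed to mirror the module's `env_vars` dict, and treating rank 0 as the head node is an assumption):

import os

master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", "6379")
is_master = os.environ.get("RANK", "0") == "0"  # assumption: rank 0 is the head node

if is_master:
    # Advertise MASTER_ADDR explicitly so workers on other interfaces can connect.
    cmd = f"ray start --head --port={master_port} --node-ip-address={master_addr}"
else:
    cmd = f"ray start --address={master_addr}:{master_port}"
print(cmd)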
