Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions scripts/docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,9 @@ WORKDIR /workspace

RUN apt update && apt install -y \
build-essential \
curl \
git \
wget \
vim \
tmux \
python3 \
python3-pip \
python3-dev \
python3-packaging \
curl git wget vim tmux net-tools \
python3 python3-pip python3-dev python3-packaging \
libomp-dev infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \
&& rm -rf /var/lib/apt/lists/* \
&& ln -sf /usr/bin/python3 /usr/bin/python

Expand Down
2 changes: 1 addition & 1 deletion trinity/buffer/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_actor(cls, storage_config: StorageConfig, config: BufferConfig):
ray.remote(cls)
.options(
name=f"queue-{storage_config.name}",
namespace=ray.get_runtime_context().namespace,
namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
get_if_exists=True,
)
.remote(storage_config, config)
Expand Down
4 changes: 2 additions & 2 deletions trinity/buffer/ray_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
ray.remote(cls)
.options(
name=f"sql-{storage_config.name}",
namespace=ray.get_runtime_context().namespace,
namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
get_if_exists=True,
)
.remote(storage_config, config)
Expand Down Expand Up @@ -171,7 +171,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
ray.remote(cls)
.options(
name=f"json-{storage_config.name}",
namespace=ray.get_runtime_context().namespace,
namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
get_if_exists=True,
)
.remote(storage_config, config)
Expand Down
3 changes: 3 additions & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ class StorageConfig:
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
workflow_args: dict = field(default_factory=dict)

# get storage from existing experiment
ray_namespace: Optional[str] = None

# ! DO NOT SET, automatically set from algorithm.algorithm_type
algorithm_type: Optional[str] = None

Expand Down
1 change: 1 addition & 0 deletions trinity/common/verl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times
self.critic.ppo_mini_batch_size = config.buffer.batch_size
self.critic.rollout_n = self.actor_rollout_ref.rollout.n
self.critic.synchronizer = config.synchronizer

if config.trainer.actor_grad_clip is not None:
self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip
Expand Down
7 changes: 6 additions & 1 deletion trinity/trainer/verl/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import warnings
from dataclasses import asdict
from datetime import timedelta

import psutil
import torch
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(self, config: DictConfig, role: str):
backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl",
rank=rank,
world_size=world_size,
timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
)

# build device mesh for FSDP
Expand Down Expand Up @@ -832,7 +834,10 @@ def __init__(self, config):
import torch.distributed

if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl")
torch.distributed.init_process_group(
backend="nccl" if is_cuda_available else "hccl",
timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
)
self.config = config

# build device mesh for Ulysses Sequence Parallel
Expand Down
5 changes: 3 additions & 2 deletions trinity/utils/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def setup_ray_cluster(namespace: str):
ray.init(namespace=namespace, ignore_reinit_error=True)
else:
if is_master:
cmd = f"ray start --head --port={env_vars['MASTER_PORT']}"
cmd = f"ray start --head --port={env_vars['MASTER_PORT']} --node-ip-address={env_vars['MASTER_ADDR']}"
else:
cmd = f"ray start --address={env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}"
ret = subprocess.run(cmd, shell=True, capture_output=True)
Expand All @@ -86,6 +86,7 @@ def setup_ray_cluster(namespace: str):
sys.exit(1)

wait_for_ray_setup()
time.sleep(5)
ray.init(
address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
namespace=namespace,
Expand All @@ -95,7 +96,7 @@ def setup_ray_cluster(namespace: str):
# master wait for worker nodes to join
wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
else:
# woker wait on the cluster status actor
# worker wait on the cluster status actor
cluster_status = (
ray.remote(ClusterStatus)
.options(
Expand Down