Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def run(config_path: str):
data_config.dj_config_path or data_config.dj_process_desc
):
activate_data_module(data_config.data_workflow_url, config_path)
ray.init()
if not ray.is_initialized():
ray.init()
if config.mode == "explore":
explore(config)
elif config.mode == "train":
Expand Down
10 changes: 8 additions & 2 deletions trinity/common/models/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.distributed
from vllm.worker.worker import Worker

from trinity.utils.distributed import init_process_group
from trinity.utils.distributed import init_process_group, is_ipv6_address
from trinity.utils.log import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -43,9 +43,15 @@ def init_process_group(
)
self._weight_update_rank = torch.distributed.get_rank() + rank_offset

if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
else:
init_method = f"tcp://{master_address}:{master_port}"

self._model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
init_method=init_method,
world_size=world_size,
rank=self._weight_update_rank,
group_name=group_name,
Expand Down
10 changes: 8 additions & 2 deletions trinity/trainer/verl/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

from trinity.common.constants import AlgorithmType
from trinity.utils.distributed import init_process_group
from trinity.utils.distributed import init_process_group, is_ipv6_address

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
Expand Down Expand Up @@ -592,9 +592,15 @@ def setup_weight_sync_group(self):
setup_ref = explorer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)
if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
else:
init_method = f"tcp://{master_address}:{master_port}"

self._model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
init_method=init_method,
world_size=world_size,
rank=0,
group_name=group_name,
Expand Down
9 changes: 9 additions & 0 deletions trinity/utils/distributed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""For distributed training with multiple process groups."""
import ipaddress
from datetime import timedelta
from typing import Any, Optional, Union

Expand All @@ -15,6 +16,14 @@
)


def is_ipv6_address(ip_str: str) -> bool:
try:
ip = ipaddress.ip_address(ip_str)
return isinstance(ip, ipaddress.IPv6Address)
except ValueError:
return False


def init_process_group(
backend: Union[str, Backend] = None,
init_method: Optional[str] = None,
Expand Down