Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def run(config_path: str):
data_config.dj_config_path or data_config.dj_process_desc
):
activate_data_module(data_config.data_workflow_url, config_path)
ray.init()
if not ray.is_initialized():
ray.init()
if config.mode == "explore":
explore(config)
elif config.mode == "train":
Expand Down
10 changes: 8 additions & 2 deletions trinity/common/models/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.distributed
from vllm.worker.worker import Worker

from trinity.utils.distributed import init_process_group
from trinity.utils.distributed import init_process_group, is_ipv6_address
from trinity.utils.log import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -43,9 +43,15 @@ def init_process_group(
)
self._weight_update_rank = torch.distributed.get_rank() + rank_offset

if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
else:
init_method = f"tcp://{master_address}:{master_port}"

self._model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
init_method=init_method,
world_size=world_size,
rank=self._weight_update_rank,
group_name=group_name,
Expand Down
10 changes: 8 additions & 2 deletions trinity/trainer/verl/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

from trinity.common.constants import AlgorithmType
from trinity.utils.distributed import init_process_group
from trinity.utils.distributed import init_process_group, is_ipv6_address

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
Expand Down Expand Up @@ -592,9 +592,15 @@ def setup_weight_sync_group(self):
setup_ref = explorer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)
if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
else:
init_method = f"tcp://{master_address}:{master_port}"

self._model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
init_method=init_method,
world_size=world_size,
rank=0,
group_name=group_name,
Expand Down
9 changes: 9 additions & 0 deletions trinity/utils/distributed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""For distributed training with multiple process groups."""
import ipaddress
from datetime import timedelta
from typing import Any, Optional, Union

Expand All @@ -15,6 +16,14 @@
)


def is_ipv6_address(ip_str: str) -> bool:
try:
ip = ipaddress.ip_address(ip_str)
return isinstance(ip, ipaddress.IPv6Address)
except ValueError:
return False


def init_process_group(
backend: Union[str, Backend] = None,
init_method: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import numpy as np
import pandas as pd
import wandb
from torch.utils.tensorboard import SummaryWriter

import wandb
from trinity.common.constants import MonitorType
from trinity.utils.log import get_logger

Expand Down
Loading