Skip to content

Commit 24c407e

Browse files
authored
Fix init_process_group failure when using an IPv6 master address (#24)
1 parent a3dfe19 commit 24c407e

File tree

4 files changed

+27
-5
lines changed

4 files changed

+27
-5
lines changed

trinity/cli/launcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def run(config_path: str):
129129
data_config.dj_config_path or data_config.dj_process_desc
130130
):
131131
activate_data_module(data_config.data_workflow_url, config_path)
132-
ray.init()
132+
if not ray.is_initialized():
133+
ray.init()
133134
if config.mode == "explore":
134135
explore(config)
135136
elif config.mode == "train":

trinity/common/models/vllm_worker.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.distributed
66
from vllm.worker.worker import Worker
77

8-
from trinity.utils.distributed import init_process_group
8+
from trinity.utils.distributed import init_process_group, is_ipv6_address
99
from trinity.utils.log import get_logger
1010

1111
logger = get_logger(__name__)
@@ -43,9 +43,15 @@ def init_process_group(
4343
)
4444
self._weight_update_rank = torch.distributed.get_rank() + rank_offset
4545

46+
if is_ipv6_address(master_address):
47+
# an unbracketed IPv6 literal (tcp://ipv6:port) leads to ValueError, so wrap the address in brackets
48+
init_method = f"tcp://[{master_address}]:{master_port}"
49+
else:
50+
init_method = f"tcp://{master_address}:{master_port}"
51+
4652
self._model_update_group = init_process_group(
4753
backend=backend,
48-
init_method=f"tcp://{master_address}:{master_port}",
54+
init_method=init_method,
4955
world_size=world_size,
5056
rank=self._weight_update_rank,
5157
group_name=group_name,

trinity/trainer/verl/fsdp_workers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
5252

5353
from trinity.common.constants import AlgorithmType
54-
from trinity.utils.distributed import init_process_group
54+
from trinity.utils.distributed import init_process_group, is_ipv6_address
5555

5656
logger = logging.getLogger(__file__)
5757
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
@@ -592,9 +592,15 @@ def setup_weight_sync_group(self):
592592
setup_ref = explorer.setup_weight_sync_group.remote(
593593
master_address, master_port, self.state_dict_meta
594594
)
595+
if is_ipv6_address(master_address):
596+
# an unbracketed IPv6 literal (tcp://ipv6:port) leads to ValueError, so wrap the address in brackets
597+
init_method = f"tcp://[{master_address}]:{master_port}"
598+
else:
599+
init_method = f"tcp://{master_address}:{master_port}"
600+
595601
self._model_update_group = init_process_group(
596602
backend=backend,
597-
init_method=f"tcp://{master_address}:{master_port}",
603+
init_method=init_method,
598604
world_size=world_size,
599605
rank=0,
600606
group_name=group_name,

trinity/utils/distributed.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""For distributed training with multiple process groups."""
3+
import ipaddress
34
from datetime import timedelta
45
from typing import Any, Optional, Union
56

@@ -15,6 +16,14 @@
1516
)
1617

1718

19+
def is_ipv6_address(ip_str: str) -> bool:
    """Return True if *ip_str* is a textual IPv6 address literal.

    Used to decide whether a master address must be wrapped in square
    brackets when it is embedded in a ``tcp://host:port`` init URL.

    Args:
        ip_str: Candidate address, e.g. ``"::1"``, ``"10.0.0.1"`` or a
            hostname.

    Returns:
        True only when ``ip_str`` is a string that parses as an IPv6
        address. IPv4 addresses, hostnames, already-bracketed literals
        such as ``"[::1]"``, empty strings and non-string values all
        yield False.
    """
    # ipaddress.ip_address() also accepts ints and packed bytes; those are
    # not host strings and must not be reported as IPv6 literals.
    if not isinstance(ip_str, str):
        return False
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        # Not an IP literal at all (hostname, bracketed form, garbage).
        return False
    return isinstance(ip, ipaddress.IPv6Address)
25+
26+
1827
def init_process_group(
1928
backend: Union[str, Backend] = None,
2029
init_method: Optional[str] = None,

0 commit comments

Comments
 (0)