diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 59e0387172..9f9c799872 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -129,7 +129,8 @@ def run(config_path: str): data_config.dj_config_path or data_config.dj_process_desc ): activate_data_module(data_config.data_workflow_url, config_path) - ray.init() + if not ray.is_initialized(): + ray.init() if config.mode == "explore": explore(config) elif config.mode == "train": diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 13fd786a94..f42d843c3f 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -5,7 +5,7 @@ import torch.distributed from vllm.worker.worker import Worker -from trinity.utils.distributed import init_process_group +from trinity.utils.distributed import init_process_group, is_ipv6_address from trinity.utils.log import get_logger logger = get_logger(__name__) @@ -43,9 +43,15 @@ def init_process_group( ) self._weight_update_rank = torch.distributed.get_rank() + rank_offset + if is_ipv6_address(master_address): + # using tcp://ipv6:port will lead to ValueError + init_method = f"tcp://[{master_address}]:{master_port}" + else: + init_method = f"tcp://{master_address}:{master_port}" + self._model_update_group = init_process_group( backend=backend, - init_method=f"tcp://{master_address}:{master_port}", + init_method=init_method, world_size=world_size, rank=self._weight_update_rank, group_name=group_name, diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index ce3365354b..94f5939f53 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -51,7 +51,7 @@ from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from trinity.common.constants import AlgorithmType -from trinity.utils.distributed import init_process_group +from trinity.utils.distributed import init_process_group, is_ipv6_address logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) @@ -592,9 +592,15 @@ def setup_weight_sync_group(self): setup_ref = explorer.setup_weight_sync_group.remote( master_address, master_port, self.state_dict_meta ) + if is_ipv6_address(master_address): + # using tcp://ipv6:port will lead to ValueError + init_method = f"tcp://[{master_address}]:{master_port}" + else: + init_method = f"tcp://{master_address}:{master_port}" + self._model_update_group = init_process_group( backend=backend, - init_method=f"tcp://{master_address}:{master_port}", + init_method=init_method, world_size=world_size, rank=0, group_name=group_name, diff --git a/trinity/utils/distributed.py b/trinity/utils/distributed.py index 898c763e28..5111b41449 100644 --- a/trinity/utils/distributed.py +++ b/trinity/utils/distributed.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """For distributed training with multiple process groups.""" +import ipaddress from datetime import timedelta from typing import Any, Optional, Union @@ -15,6 +16,14 @@ ) +def is_ipv6_address(ip_str: str) -> bool: + try: + ip = ipaddress.ip_address(ip_str) + return isinstance(ip, ipaddress.IPv6Address) + except ValueError: + return False + + def init_process_group( backend: Union[str, Backend] = None, init_method: Optional[str] = None,