Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions trinity/common/models/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torch.distributed

from trinity.utils.distributed import init_process_group, is_ipv6_address
from trinity.utils.distributed import init_process_group
from trinity.utils.log import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -38,20 +38,16 @@ def init_process_group(
f" > rank_offset={rank_offset}\n"
f" > world_size={world_size}"
)
if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
else:
init_method = f"tcp://{master_address}:{master_port}"

self._model_update_group = init_process_group(
host=master_address,
port=master_port,
group_name=group_name,
backend=backend,
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=self._weight_update_rank,
group_name=group_name,
)
torch.distributed.barrier(group=self._model_update_group)
logger.info("vLLM init_process_group finished.")
self._explorer_name = explorer_name
self._namespace = namespace
Expand All @@ -78,6 +74,6 @@ def update_weight(self):
weight = weight.type(self.model_config.dtype)
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
torch.distributed.barrier()
torch.distributed.barrier(group=self._model_update_group)
torch.cuda.synchronize()
torch.cuda.empty_cache()
20 changes: 19 additions & 1 deletion trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self, config: Config):
self.state_dict_meta = []
self.status = RunningStatus.RUNNING
self.logger.info("Finished initializing Explorer.")
self._ready_to_sync_condition = asyncio.Condition()

async def setup_weight_sync_group(
self, master_address: str, master_port: int, state_dict_meta: List = None
Expand Down Expand Up @@ -158,10 +159,28 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> No

async def _nccl_weights_update(self):
    """Wait until the explorer is flagged WAITING_SYNC, then sync all models.

    The coroutine blocks on ``self._ready_to_sync_condition`` until
    ``self.status`` equals ``RunningStatus.WAITING_SYNC`` or until
    ``self.config.synchronizer.sync_timeout`` seconds have elapsed. Once the
    status is observed, ``sync_model`` is invoked remotely on every entry of
    ``self.models`` with the current ``self.explore_step_num``.

    Raises:
        asyncio.TimeoutError: If the status does not become WAITING_SYNC
            within the configured timeout (logged before re-raising).
    """
    assert self.state_dict_meta is not None

    def _is_waiting_sync() -> bool:
        # Predicate re-evaluated by the condition each time it is notified.
        return self.status == RunningStatus.WAITING_SYNC

    async with self._ready_to_sync_condition:
        try:
            await asyncio.wait_for(
                self._ready_to_sync_condition.wait_for(_is_waiting_sync),
                timeout=self.config.synchronizer.sync_timeout,
            )
        except asyncio.TimeoutError:
            self.logger.error(
                f"Trainer is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds."
            )
            raise

    # Fan the sync request out to every model and wait for all of them.
    sync_requests = [model.sync_model.remote(self.explore_step_num) for model in self.models]
    await asyncio.gather(*sync_requests)

async def ready_to_sync(self):
    """Set the status to WAITING_SYNC and wake all waiters on the sync condition."""
    sync_condition = self._ready_to_sync_condition
    async with sync_condition:
        # The status is changed while holding the condition's lock so that
        # waiters re-checking their predicate see the new value.
        self.status = RunningStatus.WAITING_SYNC
        sync_condition.notify_all()

async def prepare(self) -> None:
"""Preparation before running."""
if self.use_checkpoint_weights_update:
Expand Down Expand Up @@ -330,7 +349,6 @@ async def sync_weight(self) -> None:
"""Synchronize model weights."""
# call this method before training start to load the latest model weights
self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.")
self.status = RunningStatus.WAITING_SYNC
if self.use_checkpoint_weights_update:
await self._checkpoint_weights_update()
else: # nccl weights update
Expand Down
3 changes: 2 additions & 1 deletion trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ def sync_weight(self) -> None:
if explorer_status == RunningStatus.STOPPED:
self.logger.warning("Explorer has already stopped. Skipping sync weight.")
return
ray.get(self.explorer_ref.ready_to_sync.remote())
self.engine.sync_weight()
self.logger.info(
f"Trainer synchronizing weights at step {self.engine.train_step_num} end."
)
self.engine.sync_weight()

def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
Expand Down
21 changes: 9 additions & 12 deletions trinity/trainer/verl/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@

from trinity.common.config import AlgorithmConfig
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
from trinity.utils.distributed import init_process_group, is_ipv6_address
from trinity.utils.distributed import init_process_group

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
Expand Down Expand Up @@ -579,25 +579,20 @@ def setup_weight_sync_group(self):
setup_ref = explorer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)
if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
else:
init_method = f"tcp://{master_address}:{master_port}"
timeout = self.config.synchronizer.sync_timeout

self._model_update_group = init_process_group(
host=master_address,
port=master_port,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
backend="nccl",
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=0,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
)
torch.distributed.barrier(group=self._model_update_group)
ray.get(setup_ref)

torch.distributed.barrier()

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def sync_weight(self):
for name_prefix, module in self.named_modules:
Expand All @@ -608,9 +603,11 @@ def sync_weight(self):
continue
torch.distributed.broadcast(param, 0, group=self._model_update_group)
param = None
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
torch.distributed.barrier(group=self._model_update_group)
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.distributed.barrier()
torch.cuda.empty_cache()

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def set_algorithm(self, algo_config: AlgorithmConfig):
Expand Down
56 changes: 24 additions & 32 deletions trinity/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
Expand All @@ -25,58 +24,51 @@ def is_ipv6_address(ip_str: str) -> bool:


def init_process_group(
    host: str,
    port: int,
    group_name: str,
    backend: Union[str, Backend] = "nccl",
    timeout: Optional[float] = None,
    world_size: int = -1,
    rank: int = -1,
    pg_options: Optional[Any] = None,
):
    """Create a named NCCL process group via a TCP rendezvous.

    Args:
        host: Address of the rendezvous master. IPv6 addresses are wrapped in
            square brackets when the ``tcp://`` init method is built.
        port: Port of the rendezvous master.
        group_name: Name of the new group; also used as the ``PrefixStore``
            prefix so keys of different groups do not collide.
        backend: Communication backend. Only ``"nccl"`` is accepted.
        timeout: Timeout in seconds. ``None`` selects ``default_pg_timeout``.
        world_size: Number of participants, forwarded to ``rendezvous``.
        rank: Rank of this process, forwarded to ``rendezvous``.
        pg_options: Backend-specific options forwarded to
            ``_new_process_group_helper``.

    Returns:
        The process group created by ``_new_process_group_helper``.

    Raises:
        AssertionError: If ``backend`` is not ``"nccl"`` or NCCL is unavailable.
    """
    assert backend == "nccl", "Only nccl backend is supported for now."

    from torch.distributed.distributed_c10d import is_nccl_available

    assert is_nccl_available()

    # using tcp://ipv6:port (without brackets) will lead to ValueError
    init_method = (
        f"tcp://[{host}]:{port}" if is_ipv6_address(ip_str=host) else f"tcp://{host}:{port}"
    )

    backend = Backend(backend)

    if timeout is None:
        timeout = default_pg_timeout
    else:
        timeout = timedelta(seconds=timeout)

    store, rank, world_size = next(rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # We need to determine the appropriate parameter name based on PyTorch version.
    # Compare (major, minor) numerically: a plain string comparison would sort
    # "2.10" before "2.6" and select the old keyword on newer releases.
    torch_major, torch_minor = (int(part) for part in str(torch.__version__).split(".")[:2])
    pg_options_param_name = (
        "backend_options" if (torch_major, torch_minor) >= (2, 6) else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        group_size=world_size,
        group_rank=rank,
        global_ranks_in_group=[],
        backend=backend,
        store=prefix_store,
        group_name=group_name,
        timeout=timeout,
        **{pg_options_param_name: pg_options},
    )

    # Register an identity rank mapping for the new group.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg