Skip to content

Commit 5f8d959

Browse files
pan-x-c and chenyushuo authored
Optimize model weight sync process group (#112)
Co-authored-by: chenyushuo <[email protected]>
1 parent 4370578 commit 5f8d959

File tree

5 files changed

+60
-56
lines changed

5 files changed

+60
-56
lines changed

trinity/common/models/vllm_worker.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import torch
55
import torch.distributed
66

7-
from trinity.utils.distributed import init_process_group, is_ipv6_address
7+
from trinity.utils.distributed import init_process_group
88
from trinity.utils.log import get_logger
99

1010
logger = get_logger(__name__)
@@ -38,20 +38,16 @@ def init_process_group(
3838
f" > rank_offset={rank_offset}\n"
3939
f" > world_size={world_size}"
4040
)
41-
if is_ipv6_address(master_address):
42-
# using tcp://ipv6:port will lead to ValueError
43-
init_method = f"tcp://[{master_address}]:{master_port}"
44-
else:
45-
init_method = f"tcp://{master_address}:{master_port}"
46-
4741
self._model_update_group = init_process_group(
42+
host=master_address,
43+
port=master_port,
44+
group_name=group_name,
4845
backend=backend,
49-
init_method=init_method,
5046
timeout=timeout,
5147
world_size=world_size,
5248
rank=self._weight_update_rank,
53-
group_name=group_name,
5449
)
50+
torch.distributed.barrier(group=self._model_update_group)
5551
logger.info("vLLM init_process_group finished.")
5652
self._explorer_name = explorer_name
5753
self._namespace = namespace
@@ -78,6 +74,6 @@ def update_weight(self):
7874
weight = weight.type(self.model_config.dtype)
7975
self.model_runner.model.load_weights(weights=[(name, weight)])
8076
del weight
81-
torch.distributed.barrier()
77+
torch.distributed.barrier(group=self._model_update_group)
8278
torch.cuda.synchronize()
8379
torch.cuda.empty_cache()

trinity/explorer/explorer.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@ def __init__(self, config: Config):
7777
self.state_dict_meta = []
7878
self.status = RunningStatus.RUNNING
7979
self.logger.info("Finished initializing Explorer.")
80+
self._ready_to_sync_condition = asyncio.Condition()
8081

8182
async def setup_weight_sync_group(
8283
self, master_address: str, master_port: int, state_dict_meta: List = None
@@ -158,10 +159,28 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> No
158159

159160
async def _nccl_weights_update(self):
160161
assert self.state_dict_meta is not None
162+
async with self._ready_to_sync_condition:
163+
try:
164+
await asyncio.wait_for(
165+
self._ready_to_sync_condition.wait_for(
166+
lambda: self.status == RunningStatus.WAITING_SYNC,
167+
),
168+
timeout=self.config.synchronizer.sync_timeout,
169+
)
170+
except asyncio.TimeoutError as e:
171+
self.logger.error(
172+
f"Trainer is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds."
173+
)
174+
raise e
161175
await asyncio.gather(
162176
*[model.sync_model.remote(self.explore_step_num) for model in self.models]
163177
)
164178

179+
async def ready_to_sync(self):
180+
async with self._ready_to_sync_condition:
181+
self.status = RunningStatus.WAITING_SYNC
182+
self._ready_to_sync_condition.notify_all()
183+
165184
async def prepare(self) -> None:
166185
"""Preparation before running."""
167186
if self.use_checkpoint_weights_update:
@@ -330,7 +349,6 @@ async def sync_weight(self) -> None:
330349
"""Synchronize model weights."""
331350
# call this method before training start to load the latest model weights
332351
self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.")
333-
self.status = RunningStatus.WAITING_SYNC
334352
if self.use_checkpoint_weights_update:
335353
await self._checkpoint_weights_update()
336354
else: # nccl weights update

trinity/trainer/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,10 +67,11 @@ def sync_weight(self) -> None:
6767
if explorer_status == RunningStatus.STOPPED:
6868
self.logger.warning("Explorer has already stopped. Skipping sync weight.")
6969
return
70+
ray.get(self.explorer_ref.ready_to_sync.remote())
71+
self.engine.sync_weight()
7072
self.logger.info(
7173
f"Trainer synchronizing weights at step {self.engine.train_step_num} end."
7274
)
73-
self.engine.sync_weight()
7475

7576
def flush_log(self, step: int) -> None:
7677
"""Flush the log of the current step."""

trinity/trainer/verl/fsdp_workers.py

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@
7373

7474
from trinity.common.config import AlgorithmConfig
7575
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
76-
from trinity.utils.distributed import init_process_group, is_ipv6_address
76+
from trinity.utils.distributed import init_process_group
7777

7878
logger = logging.getLogger(__file__)
7979
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -579,25 +579,20 @@ def setup_weight_sync_group(self):
579579
setup_ref = explorer.setup_weight_sync_group.remote(
580580
master_address, master_port, self.state_dict_meta
581581
)
582-
if is_ipv6_address(master_address):
583-
# using tcp://ipv6:port will lead to ValueError
584-
init_method = f"tcp://[{master_address}]:{master_port}"
585-
else:
586-
init_method = f"tcp://{master_address}:{master_port}"
587582
timeout = self.config.synchronizer.sync_timeout
588583

589584
self._model_update_group = init_process_group(
585+
host=master_address,
586+
port=master_port,
587+
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
590588
backend="nccl",
591-
init_method=init_method,
592589
timeout=timeout,
593590
world_size=world_size,
594591
rank=0,
595-
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
596592
)
593+
torch.distributed.barrier(group=self._model_update_group)
597594
ray.get(setup_ref)
598595

599-
torch.distributed.barrier()
600-
601596
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
602597
def sync_weight(self):
603598
for name_prefix, module in self.named_modules:
@@ -608,9 +603,11 @@ def sync_weight(self):
608603
continue
609604
torch.distributed.broadcast(param, 0, group=self._model_update_group)
610605
param = None
611-
torch.distributed.barrier()
606+
if torch.distributed.get_rank() == 0:
607+
torch.distributed.barrier(group=self._model_update_group)
612608
torch.cuda.synchronize()
613-
torch.cuda.empty_cache()
609+
torch.distributed.barrier()
610+
torch.cuda.empty_cache()
614611

615612
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
616613
def set_algorithm(self, algo_config: AlgorithmConfig):

trinity/utils/distributed.py

Lines changed: 24 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
from torch.distributed.distributed_c10d import (
99
Backend,
1010
PrefixStore,
11-
Store,
1211
_new_process_group_helper,
1312
_world,
1413
default_pg_timeout,
@@ -25,58 +24,51 @@ def is_ipv6_address(ip_str: str) -> bool:
2524

2625

2726
def init_process_group(
28-
backend: Union[str, Backend] = None,
29-
init_method: Optional[str] = None,
27+
host: str,
28+
port: int,
29+
group_name: str,
30+
backend: Union[str, Backend] = "nccl",
3031
timeout: Optional[float] = None,
3132
world_size: int = -1,
3233
rank: int = -1,
33-
store: Optional[Store] = None,
34-
group_name: Optional[str] = None,
3534
pg_options: Optional[Any] = None,
3635
):
37-
assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
36+
assert backend == "nccl", "Only nccl backend is supported for now."
3837

39-
if store is not None:
40-
assert world_size > 0, "world_size must be positive if using store"
41-
assert rank >= 0, "rank must be non-negative if using store"
42-
elif init_method is None:
43-
init_method = "env://"
38+
from torch.distributed.distributed_c10d import is_nccl_available
4439

45-
if backend:
46-
backend = Backend(backend)
47-
else:
48-
backend = Backend("undefined")
40+
assert is_nccl_available()
41+
42+
init_method = (
43+
f"tcp://[{host}]:{port}" if is_ipv6_address(ip_str=host) else f"tcp://{host}:{port}"
44+
)
45+
46+
backend = Backend(backend)
4947

5048
if timeout is None:
5149
timeout = default_pg_timeout
5250
else:
5351
timeout = timedelta(seconds=timeout)
5452

5553
# backward compatible API
56-
if store is None:
57-
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
58-
store, rank, world_size = next(rendezvous_iterator)
59-
store.set_timeout(timeout)
54+
store, rank, world_size = next(rendezvous(init_method, rank, world_size, timeout=timeout))
55+
store.set_timeout(timeout)
6056

61-
# Use a PrefixStore to avoid accidental overrides of keys used by
62-
# different systems (e.g. RPC) in case the store is multi-tenant.
63-
store = PrefixStore(group_name, store)
57+
# Use a PrefixStore to avoid accidental overrides of keys used by
58+
# different systems (e.g. RPC) in case the store is multi-tenant.
59+
prefix_store = PrefixStore(group_name, store)
6460

65-
# NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
66-
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
67-
# We need to determine the appropriate parameter name based on PyTorch version
6861
pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
6962
pg, _ = _new_process_group_helper(
70-
world_size,
71-
rank,
72-
[],
73-
backend,
74-
store,
63+
group_size=world_size,
64+
group_rank=rank,
65+
global_ranks_in_group=[],
66+
backend=backend,
67+
store=prefix_store,
7568
group_name=group_name,
76-
**{pg_options_param_name: pg_options},
7769
timeout=timeout,
70+
**{pg_options_param_name: pg_options},
7871
)
7972

8073
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
81-
8274
return pg

0 commit comments

Comments
 (0)