
Commit 2ef2a04

[ascend] fix dp multinode rank_table mapping (#4268)
1 parent 0cd064d commit 2ef2a04

File tree

2 files changed, +30 -11 lines changed

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 26 additions & 6 deletions
@@ -54,15 +54,24 @@ def get_ascend_device_rank_mapping(master_addr):
         rank_table = json.load(f)
     try:
         assert master_addr == rank_table['server_list'][0]['server_id'], 'Master address does not match rank table'
-        rank_mapping = {}
-        worker_ips = []
+        rank_mapping: Dict[int, int] = {}
+        worker_ip_by_rank: Dict[int, str] = {}
         for server in rank_table['server_list']:
             node_ip = server['server_id']
             for idx, device in enumerate(server['device']):
-                local_rank = idx
+                # Prefer explicit device_id if present; fall back to enumeration order.
+                local_rank = int(device.get('device_id', idx))
                 global_rank = int(device['rank_id'])
                 rank_mapping[global_rank] = local_rank
-                worker_ips.append(node_ip)
+                worker_ip_by_rank[global_rank] = node_ip
+
+        if len(worker_ip_by_rank) == 0:
+            raise ValueError('Rank table contains no devices.')
+
+        ranks = sorted(worker_ip_by_rank.keys())
+        if ranks[0] != 0 or ranks[-1] != len(ranks) - 1:
+            raise ValueError(f'Rank ids are not contiguous starting from 0: {ranks[:8]}...{ranks[-8:]}')
+        worker_ips = [worker_ip_by_rank[r] for r in range(len(ranks))]
     except Exception as e:
         logger.error(f'Parse rank table file({rank_table}) failed')
         raise e
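
For reference, a minimal sketch of the rank table layout this parser expects. The field names (server_list, server_id, device, device_id, rank_id) come from the diff above; the two-node topology and the IP addresses are illustrative only.

# Hypothetical rank_table contents; only the fields read by
# get_ascend_device_rank_mapping are shown.
rank_table = {
    'server_list': [
        {'server_id': '10.0.0.1',          # must match master_addr
         'device': [{'device_id': '0', 'rank_id': '0'},
                    {'device_id': '1', 'rank_id': '1'}]},
        {'server_id': '10.0.0.2',
         'device': [{'device_id': '0', 'rank_id': '2'},
                    {'device_id': '1', 'rank_id': '3'}]},
    ],
}

# Same mapping logic as the patched code: global rank -> local device id,
# global rank -> node ip, then an ip list ordered by contiguous rank ids.
rank_mapping, worker_ip_by_rank = {}, {}
for server in rank_table['server_list']:
    for idx, device in enumerate(server['device']):
        global_rank = int(device['rank_id'])
        rank_mapping[global_rank] = int(device.get('device_id', idx))
        worker_ip_by_rank[global_rank] = server['server_id']
worker_ips = [worker_ip_by_rank[r] for r in range(len(worker_ip_by_rank))]

print(rank_mapping)  # {0: 0, 1: 1, 2: 0, 3: 1}
print(worker_ips)    # ['10.0.0.1', '10.0.0.1', '10.0.0.2', '10.0.0.2']
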
@@ -625,8 +634,19 @@ def _init_ascend_distributed_environment(self, driver_ip):
         if rank_table_file:
             # if rank table file is set, use it to get rank mapping, multiple nodes
             rank_mapping, worker_ips, envs = get_ascend_device_rank_mapping(driver_ip)
-            self.workers = self._sort_workers_by_ip(worker_ips, self.workers)
-            ray.get([worker.set_device.remote(rank_mapping[idx]) for idx, worker in enumerate(self.workers)])
+            rank_start = self.rank_offset
+            rank_end = rank_start + len(self.workers)
+            if rank_end > len(worker_ips):
+                raise ValueError(
+                    'Rank table world_size is smaller than required ranks for current dp_rank. '
+                    f'rank_table_world_size={len(worker_ips)}, required_rank_range=[{rank_start}, {rank_end})')
+
+            # In dp mode each process only owns a slice of global ranks.
+            expected_worker_ips = worker_ips[rank_start:rank_end]
+            self.workers = self._sort_workers_by_ip(expected_worker_ips, self.workers)
+
+            ray.get(
+                [worker.set_device.remote(rank_mapping[rank_start + idx]) for idx, worker in enumerate(self.workers)])
             ray.get([worker.set_env.remote(envs) for worker in self.workers])
         elif not set_rt_visable_devices_by_ray:
             # if rank table file is not set, treat as single node
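
As a worked example of the per-dp_rank slicing above, assuming (as the diff implies but does not spell out) that self.rank_offset equals dp_rank times the number of workers each dp_rank process owns; the cluster shape below is illustrative only.

# Illustrative shape: 2 nodes x 4 devices, dp=2, so each dp_rank owns 4 global ranks.
worker_ips = ['10.0.0.1'] * 4 + ['10.0.0.2'] * 4  # from the rank table, 8 ranks total
rank_mapping = {r: r % 4 for r in range(8)}       # global rank -> local device id

workers_per_dp_rank = 4
for dp_rank in range(2):
    rank_start = dp_rank * workers_per_dp_rank    # plays the role of self.rank_offset
    rank_end = rank_start + workers_per_dp_rank
    expected_worker_ips = worker_ips[rank_start:rank_end]
    devices = [rank_mapping[rank_start + idx] for idx in range(workers_per_dp_rank)]
    print(dp_rank, expected_worker_ips, devices)
# 0 ['10.0.0.1', '10.0.0.1', '10.0.0.1', '10.0.0.1'] [0, 1, 2, 3]
# 1 ['10.0.0.2', '10.0.0.2', '10.0.0.2', '10.0.0.2'] [0, 1, 2, 3]
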

lmdeploy/pytorch/ray.py

Lines changed: 4 additions & 5 deletions
@@ -126,11 +126,10 @@ def init_ray_cluster(world_size: int, ray_address: str = None, dp: int = 1, devi
     # Create a new placement group
     placement_group_specs: List[Dict[str, float]] = ([{device_str: 1.0} for _ in range(world_size)])

-    gcs_addr = ray.get_runtime_context().gcs_address
-    master_addr = gcs_addr.split(':')[0]
-    current_ip = master_addr
-    # This way, at least bundle is required to be created in a current
-    # node.
+    # Pin at least one bundle to the local node.
+    # This helps multi-node DP keep each dp_rank process's workers co-located with
+    # the node where the process is launched.
+    current_ip = ray.util.get_node_ip_address()
     placement_group_specs[0][f'node:{current_ip}'] = 0.001

     # By default, Ray packs resources as much as possible.
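
A minimal standalone sketch of the node-pinning trick used here, assuming a plain Ray cluster and substituting Ray's built-in GPU resource for the repo's device_str; the bundle count is illustrative.

import ray
from ray.util.placement_group import placement_group

ray.init()

# Every Ray node exports a `node:<ip>` resource, so requesting a sliver of it
# in the first bundle forces that bundle onto the local node.
current_ip = ray.util.get_node_ip_address()
bundles = [{'GPU': 1.0} for _ in range(4)]
bundles[0][f'node:{current_ip}'] = 0.001

pg = placement_group(bundles, strategy='PACK')
ray.get(pg.ready())
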
