@@ -54,15 +54,24 @@ def get_ascend_device_rank_mapping(master_addr):
5454 rank_table = json .load (f )
5555 try :
5656 assert master_addr == rank_table ['server_list' ][0 ]['server_id' ], 'Master address does not match rank table'
57- rank_mapping = {}
58- worker_ips = []
57+ rank_mapping : Dict [ int , int ] = {}
58+ worker_ip_by_rank : Dict [ int , str ] = {}
5959 for server in rank_table ['server_list' ]:
6060 node_ip = server ['server_id' ]
6161 for idx , device in enumerate (server ['device' ]):
62- local_rank = idx
62+ # Prefer explicit device_id if present; fall back to enumeration order.
63+ local_rank = int (device .get ('device_id' , idx ))
6364 global_rank = int (device ['rank_id' ])
6465 rank_mapping [global_rank ] = local_rank
65- worker_ips .append (node_ip )
66+ worker_ip_by_rank [global_rank ] = node_ip
67+
68+ if len (worker_ip_by_rank ) == 0 :
69+ raise ValueError ('Rank table contains no devices.' )
70+
71+ ranks = sorted (worker_ip_by_rank .keys ())
72+ if ranks [0 ] != 0 or ranks [- 1 ] != len (ranks ) - 1 :
73+ raise ValueError (f'Rank ids are not contiguous starting from 0: { ranks [:8 ]} ...{ ranks [- 8 :]} ' )
74+ worker_ips = [worker_ip_by_rank [r ] for r in range (len (ranks ))]
6675 except Exception as e :
6776 logger .error (f'Parse rank table file({ rank_table } ) failed' )
6877 raise e
@@ -625,8 +634,19 @@ def _init_ascend_distributed_environment(self, driver_ip):
625634 if rank_table_file :
626635 # if rank table file is set, use it to get rank mapping, multiple nodes
627636 rank_mapping , worker_ips , envs = get_ascend_device_rank_mapping (driver_ip )
628- self .workers = self ._sort_workers_by_ip (worker_ips , self .workers )
629- ray .get ([worker .set_device .remote (rank_mapping [idx ]) for idx , worker in enumerate (self .workers )])
637+ rank_start = self .rank_offset
638+ rank_end = rank_start + len (self .workers )
639+ if rank_end > len (worker_ips ):
640+ raise ValueError (
641+ 'Rank table world_size is smaller than required ranks for current dp_rank. '
642+ f'rank_table_world_size={ len (worker_ips )} , required_rank_range=[{ rank_start } , { rank_end } )' )
643+
644+ # In dp mode each process only owns a slice of global ranks.
645+ expected_worker_ips = worker_ips [rank_start :rank_end ]
646+ self .workers = self ._sort_workers_by_ip (expected_worker_ips , self .workers )
647+
648+ ray .get (
649+ [worker .set_device .remote (rank_mapping [rank_start + idx ]) for idx , worker in enumerate (self .workers )])
630650 ray .get ([worker .set_env .remote (envs ) for worker in self .workers ])
631651 elif not set_rt_visable_devices_by_ray :
632652 # if rank table file is not set, treat as single node
0 commit comments