@@ -560,17 +560,21 @@ def _init_distributed_environment_by_device(self, device_str: str):
def _init_ascend_distributed_environment(self, driver_ip):
    """Init ascend distributed environment.

    Reorders ``self.workers`` and, depending on the environment settings
    read from ``_envs``, assigns a device to every worker.

    Args:
        driver_ip: IP of the driver; passed to the rank-mapping lookup when
            a rank table file is configured, otherwise to
            ``self._sort_workers``.
    """
    rank_table_path = _envs.ascend_rank_table_file
    devices_set_by_ray = _envs.ascend_set_rt_visable_devices_by_ray

    if not rank_table_path:
        # No rank table file: treat as single node.
        self.workers = self._sort_workers(driver_ip, self.workers)
        if not devices_set_by_ray:
            # Single node, multiple devices: the device id is simply the
            # worker's position in the sorted list.
            ray.get([
                actor.set_device.remote(position)
                for position, actor in enumerate(self.workers)
            ])
        # NOTE(review): when the flag is set no device is assigned here,
        # presumably because Ray already controls device visibility --
        # confirm against the flag's definition.
    else:
        # Rank table file is set: use it to get the rank mapping
        # (multiple nodes).
        rank_mapping, worker_ips, envs = get_ascend_device_rank_mapping(driver_ip)
        self.workers = self._sort_workers_by_ip(worker_ips, self.workers)

        device_refs = [
            actor.set_device.remote(rank_mapping[position])
            for position, actor in enumerate(self.workers)
        ]
        ray.get(device_refs)

        env_refs = [actor.set_env.remote(envs) for actor in self.workers]
        ray.get(env_refs)
574578
575579 """ PD Disaggregation API Begin """
576580
0 commit comments