@@ -70,18 +70,25 @@ ::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
70
70
std::lock_guard<std::mutex> guard (queue_registering_mutex);
71
71
if (not_use_default_selector) {
72
72
DeviceRegister& device_register = GetDevicesRegister ();
73
- const int device_idx =
74
- collective::IsDistributed () ? collective::GetRank () : device_spec.ordinal ;
75
73
if (device_spec.IsSyclDefault ()) {
76
74
auto & devices = device_register.devices ;
75
+ const int device_idx = collective::IsDistributed ()
76
+ ? collective::GetRank () % devices.size ()
77
+ : device_spec.ordinal ;
77
78
CHECK_LT (device_idx, devices.size ());
78
79
queue_register[device_spec.Name ()] = ::sycl::queue (devices[device_idx]);
79
80
} else if (device_spec.IsSyclCPU ()) {
80
81
auto & cpu_devices = device_register.cpu_devices ;
82
+ const int device_idx = collective::IsDistributed ()
83
+ ? collective::GetRank () % cpu_devices.size ()
84
+ : device_spec.ordinal ;
81
85
CHECK_LT (device_idx, cpu_devices.size ());
82
86
queue_register[device_spec.Name ()] = ::sycl::queue (cpu_devices[device_idx]);
83
87
} else if (device_spec.IsSyclGPU ()) {
84
88
auto & gpu_devices = device_register.gpu_devices ;
89
+ const int device_idx = collective::IsDistributed ()
90
+ ? collective::GetRank () % gpu_devices.size ()
91
+ : device_spec.ordinal ;
85
92
CHECK_LT (device_idx, gpu_devices.size ());
86
93
queue_register[device_spec.Name ()] = ::sycl::queue (gpu_devices[device_idx]);
87
94
}
0 commit comments