@@ -19,20 +19,27 @@ ::sycl::queue* DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
19
19
size_t queue_idx;
20
20
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal ) ||
21
21
(collective::IsDistributed ());
22
- DeviceRegister& device_register = GetDevicesRegister ();
23
22
if (not_use_default_selector) {
24
- const int device_idx =
25
- collective::IsDistributed () ? collective::GetRank () : device_spec.ordinal ;
23
+ DeviceRegister& device_register = GetDevicesRegister ();
26
24
if (device_spec.IsSyclDefault ()) {
27
25
auto & devices = device_register.devices ;
26
+ const int device_idx = collective::IsDistributed ()
27
+ ? collective::GetRank () % devices.size ()
28
+ : device_spec.ordinal ;
28
29
CHECK_LT (device_idx, devices.size ());
29
30
queue_idx = device_idx;
30
31
} else if (device_spec.IsSyclCPU ()) {
31
32
auto & cpu_devices_idxes = device_register.cpu_devices_idxes ;
33
+ const int device_idx = collective::IsDistributed ()
34
+ ? collective::GetRank () % cpu_devices_idxes.size ()
35
+ : device_spec.ordinal ;
32
36
CHECK_LT (device_idx, cpu_devices_idxes.size ());
33
37
queue_idx = cpu_devices_idxes[device_idx];
34
38
} else if (device_spec.IsSyclGPU ()) {
35
39
auto & gpu_devices_idxes = device_register.gpu_devices_idxes ;
40
+ const int device_idx = collective::IsDistributed ()
41
+ ? collective::GetRank () % gpu_devices_idxes.size ()
42
+ : device_spec.ordinal ;
36
43
CHECK_LT (device_idx, gpu_devices_idxes.size ());
37
44
queue_idx = gpu_devices_idxes[device_idx];
38
45
} else {
0 commit comments