diff --git a/gputreeshap b/gputreeshap index 40eae8c4c459..787259b412c1 160000 --- a/gputreeshap +++ b/gputreeshap @@ -1 +1 @@ -Subproject commit 40eae8c4c45974705f8053e4d3d05b88e3cfaefd +Subproject commit 787259b412c18ab8d5f24bf2b8bd6a59ff8208f3 diff --git a/plugin/sycl/device_manager.cc b/plugin/sycl/device_manager.cc index dc3939934e31..ee652065db23 100644 --- a/plugin/sycl/device_manager.cc +++ b/plugin/sycl/device_manager.cc @@ -21,18 +21,25 @@ ::sycl::queue* DeviceManager::GetQueue(const DeviceOrd& device_spec) const { (collective::IsDistributed()); DeviceRegister& device_register = GetDevicesRegister(); if (not_use_default_selector) { - const int device_idx = - collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal; if (device_spec.IsSyclDefault()) { auto& devices = device_register.devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, devices.size()); queue_idx = device_idx; } else if (device_spec.IsSyclCPU()) { auto& cpu_devices_idxes = device_register.cpu_devices_idxes; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % cpu_devices_idxes.size() + : device_spec.ordinal; CHECK_LT(device_idx, cpu_devices_idxes.size()); queue_idx = cpu_devices_idxes[device_idx]; } else if (device_spec.IsSyclGPU()) { auto& gpu_devices_idxes = device_register.gpu_devices_idxes; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % gpu_devices_idxes.size() + : device_spec.ordinal; CHECK_LT(device_idx, gpu_devices_idxes.size()); queue_idx = gpu_devices_idxes[device_idx]; } else {