diff --git a/source/adapters/level_zero/context.hpp b/source/adapters/level_zero/context.hpp index c2fbba633f..e7c0d784a0 100644 --- a/source/adapters/level_zero/context.hpp +++ b/source/adapters/level_zero/context.hpp @@ -100,6 +100,9 @@ struct ur_context_handle_t_ : _ur_object { l0_command_list_cache_info>>> ZeCopyCommandListCache; + std::unordered_map> + P2PDeviceCache; + // Store USM pool for USM shared and device allocations. There is 1 memory // pool per each pair of (context, device) per each memory type. std::unordered_map diff --git a/source/adapters/level_zero/usm.cpp b/source/adapters/level_zero/usm.cpp index 5296391794..b5e7a0242b 100644 --- a/source/adapters/level_zero/usm.cpp +++ b/source/adapters/level_zero/usm.cpp @@ -154,15 +154,26 @@ static ur_result_t USMAllocationMakeResident( } else { Devices.push_back(Device); if (ForceResidency == USMAllocationForceResidencyType::P2PDevices) { - ze_bool_t P2P; - for (const auto &D : Context->Devices) { - if (D == Device) - continue; - // TODO: Cache P2P devices for a context - ZE2UR_CALL(zeDeviceCanAccessPeer, - (D->ZeDevice, Device->ZeDevice, &P2P)); - if (P2P) - Devices.push_back(D); + // Check if the P2P devices are already cached + auto it = Context->P2PDeviceCache.find(Device); + if (it != Context->P2PDeviceCache.end()) { + // Use cached P2P devices + Devices.insert(Devices.end(), it->second.begin(), it->second.end()); + } else { + // Query for P2P devices and update the cache + std::list P2PDevices; + ze_bool_t P2P; + for (const auto &D : Context->Devices) { + if (D == Device) + continue; + ZE2UR_CALL(zeDeviceCanAccessPeer, + (D->ZeDevice, Device->ZeDevice, &P2P)); + if (P2P) + P2PDevices.push_back(D); + } + // Update the cache + Context->P2PDeviceCache[Device] = P2PDevices; + Devices.insert(Devices.end(), P2PDevices.begin(), P2PDevices.end()); } } }