@@ -35,11 +35,28 @@ filterP2PDevices(ur_device_handle_t hSourceDevice,
3535}
3636
3737static std::vector<std::vector<ur_device_handle_t >>
38- populateP2PDevices (size_t maxDevices,
39- const std::vector<ur_device_handle_t > &devices) {
40- std::vector<std::vector<ur_device_handle_t >> p2pDevices (maxDevices);
38+ populateP2PDevices (const std::vector<ur_device_handle_t > &devices) {
39+ std::vector<ur_device_handle_t > allDevices;
40+ std::function<void (ur_device_handle_t )> collectDeviceAndSubdevices =
41+ [&allDevices, &collectDeviceAndSubdevices](ur_device_handle_t device) {
42+ allDevices.push_back (device);
43+ for (auto &subDevice : device->SubDevices ) {
44+ collectDeviceAndSubdevices (subDevice);
45+ }
46+ };
47+
4148 for (auto &device : devices) {
42- p2pDevices[device->Id .value ()] = filterP2PDevices (device, devices);
49+ collectDeviceAndSubdevices (device);
50+ }
51+
52+ uint64_t maxDeviceId = 0 ;
53+ for (auto &device : allDevices) {
54+ maxDeviceId = std::max (maxDeviceId, device->Id .value ());
55+ }
56+
57+ std::vector<std::vector<ur_device_handle_t >> p2pDevices (maxDeviceId + 1 );
58+ for (auto &device : allDevices) {
59+ p2pDevices[device->Id .value ()] = filterP2PDevices (device, allDevices);
4360 }
4461 return p2pDevices;
4562}
@@ -83,8 +100,7 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
83100 nativeEventsPool(this , std::make_unique<v2::provider_normal>(
84101 this , v2::QUEUE_IMMEDIATE,
85102 v2::EVENT_FLAGS_PROFILING_ENABLED)),
86- p2pAccessDevices(populateP2PDevices(
87- phDevices[0 ]->Platform->getNumDevices (), this->hDevices)),
103+ p2pAccessDevices(populateP2PDevices(this ->hDevices)),
88104 defaultUSMPool(this , nullptr ), asyncPool(this , nullptr ) {}
89105
90106ur_result_t ur_context_handle_t_::retain () {
0 commit comments