@@ -18,11 +18,24 @@ static std::vector<ur_device_handle_t>
1818filterP2PDevices (ur_device_handle_t hSourceDevice,
1919 const std::vector<ur_device_handle_t > &devices) {
2020 std::vector<ur_device_handle_t > p2pDevices;
21+
22+ std::optional<std::unordered_set<DeviceId>> p2pDevicesEnabledByUser;
23+
2124 for (auto &device : devices) {
2225 if (device == hSourceDevice) {
2326 continue ;
2427 }
2528
29+ if (!p2pDevicesEnabledByUser.has_value ())
30+ {
31+ std::shared_lock<ur_shared_mutex> Lock (hSourceDevice->Mutex );
32+ p2pDevicesEnabledByUser.emplace (hSourceDevice->p2pDeviceIds );
33+ }
34+
35+ if (p2pDevicesEnabledByUser->count (*device->Id ) == 0 ) {
36+ continue ;
37+ }
38+
2639 ze_bool_t p2p;
2740 ZE2UR_CALL_THROWS (zeDeviceCanAccessPeer,
2841 (device->ZeDevice , hSourceDevice->ZeDevice , &p2p));
@@ -126,8 +139,39 @@ void ur_context_handle_t_::removeUsmPool(ur_usm_pool_handle_t hPool) {
126139 usmPoolHandles.remove (hPool);
127140}
128141
129- const std::vector<ur_device_handle_t > &
130- ur_context_handle_t_::getP2PDevices (ur_device_handle_t hDevice) const {
142+ void ur_context_handle_t_::addResidentDevice (ur_device_handle_t hDevice, ur_device_handle_t newPeerDevice) {
143+ std::scoped_lock<ur_shared_mutex> lock (Mutex);
144+ auto & pDevices = p2pAccessDevices.at (hDevice->Id );
145+
146+ assert (0 = std::count_if (
147+ std::begin (pDevices), std::end (pDevices),
148+ [&](const auto pDevice) { return newPeerDevice->Id == pDevice->Id ; }));
149+
150+ pDevices.push_back (newPeerDevice);
151+ }
152+ void ur_context_handle_t_::removeResidentDevice (ur_device_handle_t hDevice, ur_device_handle_t oldPeerDevice) {
153+ std::scoped_lock<ur_shared_mutex> lock (Mutex);
154+ auto & pDevices = p2pAccessDevices.at (hDevice->Id );
155+
156+ const auto & findOldDevice = [&] {
157+ return std::find_if (
158+ std::begin (pDevices), std::end (pDevices),
159+ [oldPeerDevice](const auto pDevice) { return oldPeerDevice->Id == pDevice->Id ; });
160+ };
161+
162+ auto pDeviceIt = findOldDevice ();
163+ assert (pDeviceIt != std::end (pDevices));
164+ pDevices.erase (pDeviceIt);
165+ assert (findOldDevice () == std::end (pDevices));
166+
167+ for (auto poolHandle : usmPoolHandles) {
168+ poolHandle->removeResidentDevice ();
169+ }
170+ }
171+
172+ std::vector<ur_device_handle_t >
173+ ur_context_handle_t_::getP2PDevices (ur_device_handle_t hDevice) {
174+ std::scoped_lock<ur_shared_mutex> lock (Mutex);
131175 return p2pAccessDevices[hDevice->Id .value ()];
132176}
133177
@@ -145,6 +189,10 @@ ur_result_t urContextCreate(uint32_t deviceCount,
145189
146190 *phContext =
147191 new ur_context_handle_t_ (zeContext, deviceCount, phDevices, true );
192+ {
193+ std::scoped_lock<ur_shared_mutex> Lock (hPlatform->ContextsMutex );
194+ hPlatform->Contexts .push_back (*phContext);
195+ }
148196 return UR_RESULT_SUCCESS;
149197} catch (...) {
150198 return exceptionToResult (std::current_exception ());
@@ -182,6 +230,14 @@ ur_result_t urContextRetain(ur_context_handle_t hContext) try {
182230}
183231
184232ur_result_t urContextRelease (ur_context_handle_t hContext) try {
233+ auto Platform = hContext->getPlatform ();
234+ auto &Contexts = Platform->Contexts ;
235+ {
236+ std::scoped_lock<ur_shared_mutex> Lock (Platform->ContextsMutex );
237+ auto It = std::find (Contexts.begin (), Contexts.end (), hContext);
238+ UR_ASSERT (It != Contexts.end (), UR_RESULT_ERROR_INVALID_CONTEXT);
239+ Contexts.erase (It);
240+ }
185241 return hContext->release ();
186242} catch (...) {
187243 return exceptionToResult (std::current_exception ());
0 commit comments