diff --git a/sycl/source/detail/device_impl.cpp b/sycl/source/detail/device_impl.cpp index 629aa72f04dde..b45d4d27b1981 100644 --- a/sycl/source/detail/device_impl.cpp +++ b/sycl/source/detail/device_impl.cpp @@ -167,6 +167,12 @@ std::vector device_impl::create_sub_devices( MPlatform.getOrMakeDeviceImpl(a_ur_device)); res.push_back(sycl_device); }); + // urDevicePartition returns devices with their reference counts + // incremented. Each device_impl wrapper increments the reference count and + // decrements it on destruction (shared ownership). So, we have to decrement + // the reference count once here to release temporary handles. + for (ur_device_handle_t &SubDevice : SubDevices) + Adapter.call(SubDevice); return res; } diff --git a/sycl/unittests/context_device/DeviceRefCounter.cpp b/sycl/unittests/context_device/DeviceRefCounter.cpp index ea8c38b6fdbb7..81b53e094e4b2 100644 --- a/sycl/unittests/context_device/DeviceRefCounter.cpp +++ b/sycl/unittests/context_device/DeviceRefCounter.cpp @@ -29,6 +29,39 @@ static ur_result_t redefinedDeviceReleaseAfter(void *) { return UR_RESULT_SUCCESS; } +ur_result_t redefinedDevicePartitionAfter(void *pParams) { + auto params = *static_cast(pParams); + if (*params.pphSubDevices) { + for (size_t I = 0; I < *params.pNumDevices; ++I) { + *params.pphSubDevices[I] = reinterpret_cast(1000 + I); + } + } + if (*params.ppNumDevicesRet) + **params.ppNumDevicesRet = *params.pNumDevices; + + DevRefCounter += *params.pNumDevices; + return UR_RESULT_SUCCESS; +} + +static constexpr size_t NumSubDevices = 2; + +ur_result_t redefinedDeviceGetInfoAfter(void *pParams) { + auto params = *static_cast(pParams); + if (*params.ppropName == UR_DEVICE_INFO_SUPPORTED_PARTITIONS) { + if (*params.ppPropValue) { + auto *Result = + reinterpret_cast(*params.ppPropValue); + *Result = UR_DEVICE_PARTITION_EQUALLY; + } + if (*params.ppPropSizeRet) + **params.ppPropSizeRet = sizeof(ur_device_partition_t); + } else if (*params.ppropName == UR_DEVICE_INFO_MAX_COMPUTE_UNITS) { + auto *Result = reinterpret_cast(*params.ppPropValue); + *Result = NumSubDevices; + } + return UR_RESULT_SUCCESS; +} + TEST(DevRefCounter, DevRefCounter) { { sycl::unittest::UrMock<> Mock; @@ -52,3 +85,32 @@ TEST(DevRefCounter, DevRefCounter) { } EXPECT_EQ(DevRefCounter, 0); } + +TEST(SubDevRefCounter, SubDevRefCounter) { + { + DevRefCounter = 0; + sycl::unittest::UrMock<> Mock; + mock::getCallbacks().set_after_callback("urDeviceGet", + &redefinedDevicesGetAfter); + mock::getCallbacks().set_after_callback("urDeviceRetain", + &redefinedDeviceRetainAfter); + mock::getCallbacks().set_after_callback("urDeviceRelease", + &redefinedDeviceReleaseAfter); + mock::getCallbacks().set_before_callback("urDevicePartition", + &redefinedDevicePartitionAfter); + mock::getCallbacks().set_after_callback("urDeviceGetInfo", + &redefinedDeviceGetInfoAfter); + sycl::platform Plt = sycl::platform(); + + auto Devs = Plt.get_devices(); + if (!Devs.empty()) { + auto Subdevs = Devs[0] + .create_sub_devices< + sycl::info::partition_property::partition_equally>( + NumSubDevices); + } + EXPECT_NE(DevRefCounter, 0); + sycl::detail::GlobalHandler::instance().getPlatformCache().clear(); + } + EXPECT_EQ(DevRefCounter, 0); +}