diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp
index 3346b932260cf..54b430a7aa498 100644
--- a/sycl/source/detail/program_manager/program_manager.cpp
+++ b/sycl/source/detail/program_manager/program_manager.cpp
@@ -843,40 +843,37 @@ static void setSpecializationConstants(device_image_impl &InputImpl,
 ur_program_handle_t ProgramManager::getBuiltURProgram(
     context_impl &ContextImpl, device_impl &DeviceImpl,
     KernelNameStrRefT KernelName, const NDRDescT &NDRDesc) {
-  device_impl *RootDevImpl;
-  ur_bool_t MustBuildOnSubdevice = true;
-
+  device_impl *BuildDev = &DeviceImpl;
   // Check if we can optimize program builds for sub-devices by using a program
   // built for the root device
-  if (!DeviceImpl.isRootDevice()) {
-    RootDevImpl = &DeviceImpl;
-    while (!RootDevImpl->isRootDevice()) {
-      device_impl &ParentDev = *detail::getSyclObjImpl(
-          RootDevImpl->get_info());
-      // Sharing is allowed within a single context only
-      if (!ContextImpl.hasDevice(ParentDev))
-        break;
-      RootDevImpl = &ParentDev;
-    }
+  if (!BuildDev->isRootDevice()) {
+    device_impl *CandidateRoot = BuildDev;
+    while (!CandidateRoot->isRootDevice())
+      CandidateRoot = &*detail::getSyclObjImpl(
+          CandidateRoot->get_info());
 
+    // Must be ur_bool_t: the query below writes sizeof(ur_bool_t) bytes into it.
+    ur_bool_t MustBuildOnSubdevice = true;
     ContextImpl.getAdapter().call(
-        RootDevImpl->getHandleRef(), UR_DEVICE_INFO_BUILD_ON_SUBDEVICE,
+        CandidateRoot->getHandleRef(), UR_DEVICE_INFO_BUILD_ON_SUBDEVICE,
         sizeof(ur_bool_t), &MustBuildOnSubdevice, nullptr);
-  }
 
-  device_impl &RootOrSubDevImpl =
-      MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl;
+    // Build for the root device instead only if the backend does not require
+    // per-sub-device builds and the root device belongs to this context.
+    if (!MustBuildOnSubdevice && ContextImpl.hasDevice(*CandidateRoot))
+      BuildDev = CandidateRoot;
+  }
 
   const RTDeviceBinaryImage &Img =
-      getDeviceImage(KernelName, ContextImpl, RootOrSubDevImpl);
+      getDeviceImage(KernelName, ContextImpl, *BuildDev);
 
   // Check that device supports all aspects used by the kernel
   if (auto exception =
-          checkDevSupportDeviceRequirements(RootOrSubDevImpl, Img, NDRDesc))
+          checkDevSupportDeviceRequirements(*BuildDev, Img, NDRDesc))
     throw *exception;
 
   std::set DeviceImagesToLink =
-      collectDeviceImageDeps(Img, {RootOrSubDevImpl});
+      collectDeviceImageDeps(Img, {*BuildDev});
 
   // Decompress all DeviceImagesToLink
   for (const RTDeviceBinaryImage *BinImg : DeviceImagesToLink)
@@ -888,8 +885,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
   std::copy(DeviceImagesToLink.begin(), DeviceImagesToLink.end(),
             std::back_inserter(AllImages));
 
-  return getBuiltURProgram(std::move(AllImages), ContextImpl,
-                           {RootOrSubDevImpl});
+  return getBuiltURProgram(std::move(AllImages), ContextImpl, {*BuildDev});
 }
 
 ur_program_handle_t
diff --git a/sycl/unittests/program_manager/SubDevices.cpp b/sycl/unittests/program_manager/SubDevices.cpp
index bcfa67d7f55ae..8c9913f412178 100644
--- a/sycl/unittests/program_manager/SubDevices.cpp
+++ b/sycl/unittests/program_manager/SubDevices.cpp
@@ -14,11 +14,24 @@
 
 #include
 
-static ur_device_handle_t rootDevice;
-static ur_device_handle_t urSubDev1 = (ur_device_handle_t)0x1;
-static ur_device_handle_t urSubDev2 = (ur_device_handle_t)0x2;
+static ur_device_handle_t rootDevice = (ur_device_handle_t)0x1;
+// Sub-devices under rootDevice
+static ur_device_handle_t urSubDev1 = (ur_device_handle_t)0x11;
+static ur_device_handle_t urSubDev2 = (ur_device_handle_t)0x12;
+// Sub-sub-devices under urSubDev1
+static ur_device_handle_t urSubSubDev1 = (ur_device_handle_t)0x111;
+static ur_device_handle_t urSubSubDev2 = (ur_device_handle_t)0x112;
 
 namespace {
+ur_result_t redefinedDeviceGet(void *pParams) {
+  auto params = *static_cast(pParams);
+  if (*params.ppNumDevices)
+    **params.ppNumDevices = 1;
+  if (*params.pphDevices)
+    (*params.pphDevices)[0] = rootDevice;
+  return UR_RESULT_SUCCESS;
+}
+
 ur_result_t redefinedDeviceGetInfo(void *pParams) {
   auto params = *static_cast(pParams);
   if (*params.ppropName == UR_DEVICE_INFO_SUPPORTED_PARTITIONS) {
@@ -41,13 +54,32 @@ ur_result_t redefinedDeviceGetInfo(void *pParams) {
     }
   }
   if (*params.ppropName == UR_DEVICE_INFO_PARTITION_MAX_SUB_DEVICES) {
-    ((uint32_t *)*params.ppPropValue)[0] = 2;
+    if (!*params.ppPropValue)
+      **params.ppPropSizeRet = sizeof(uint32_t);
+    else
+      ((uint32_t *)*params.ppPropValue)[0] = 2;
   }
   if (*params.ppropName == UR_DEVICE_INFO_PARENT_DEVICE) {
-    if (*params.phDevice == urSubDev1 || *params.phDevice == urSubDev2)
-      ((ur_device_handle_t *)*params.ppPropValue)[0] = rootDevice;
+    if (!*params.ppPropValue) {
+      **params.ppPropSizeRet = sizeof(ur_device_handle_t);
+    } else {
+      ur_device_handle_t &ret =
+          *static_cast(*params.ppPropValue);
+      if (*params.phDevice == urSubDev1 || *params.phDevice == urSubDev2) {
+        ret = rootDevice;
+      } else if (*params.phDevice == urSubSubDev1 ||
+                 *params.phDevice == urSubSubDev2) {
+        ret = urSubDev1;
+      } else {
+        ret = nullptr;
+      }
+    }
+  }
+  if (*params.ppropName == UR_DEVICE_INFO_BUILD_ON_SUBDEVICE) {
+    if (!*params.ppPropValue)
+      **params.ppPropSizeRet = sizeof(ur_bool_t);
     else
-      ((ur_device_handle_t *)*params.ppPropValue)[0] = nullptr;
+      ((ur_bool_t *)*params.ppPropValue)[0] = false;
   }
   return UR_RESULT_SUCCESS;
 }
@@ -77,6 +109,13 @@ ur_result_t redefinedProgramBuild(void *) {
   return UR_RESULT_SUCCESS;
 }
 
+static int buildCallCount = 0;
+
+ur_result_t redefinedProgramBuildExp(void *) {
+  buildCallCount++;
+  return UR_RESULT_SUCCESS;
+}
+
 ur_result_t redefinedContextCreate(void *) { return UR_RESULT_SUCCESS; }
 
 } // anonymous namespace
@@ -128,3 +167,33 @@ TEST(SubDevices, DISABLED_BuildProgramForSubdevices) {
       *sycl::detail::getSyclObjImpl(Ctx), subDev2,
       sycl::detail::KernelInfo::getName());
 }
+
+// Check that program is built once for all sub-sub-devices
+TEST(SubDevices, BuildProgramForSubSubDevices) {
+  sycl::unittest::UrMock<> Mock;
+  mock::getCallbacks().set_after_callback("urDeviceGet", &redefinedDeviceGet);
+  mock::getCallbacks().set_after_callback("urDeviceGetInfo",
+                                          &redefinedDeviceGetInfo);
+  mock::getCallbacks().set_after_callback("urProgramBuildExp",
+                                          &redefinedProgramBuildExp);
+  sycl::platform Plt = sycl::platform();
+  sycl::device root = Plt.get_devices()[0];
+  sycl::detail::platform_impl &PltImpl = *sycl::detail::getSyclObjImpl(Plt);
+  // Initialize sub-sub-devices
+  sycl::detail::device_impl &SubSub1 =
+      PltImpl.getOrMakeDeviceImpl(urSubSubDev1);
+  sycl::detail::device_impl &SubSub2 =
+      PltImpl.getOrMakeDeviceImpl(urSubSubDev2);
+
+  sycl::context Ctx{root};
+  buildCallCount = 0;
+  sycl::detail::ProgramManager::getInstance().getBuiltURProgram(
+      *sycl::detail::getSyclObjImpl(Ctx), SubSub1,
+      sycl::detail::KernelInfo::getName());
+  sycl::detail::ProgramManager::getInstance().getBuiltURProgram(
+      *sycl::detail::getSyclObjImpl(Ctx), SubSub2,
+      sycl::detail::KernelInfo::getName());
+
+  // Check that program is built only once.
+  EXPECT_EQ(buildCallCount, 1);
+}