Skip to content

Commit 1186004

Browse files
[SYCL] Unify program build paths in program manager
1 parent c5845a7 commit 1186004

File tree

2 files changed

+105
-193
lines changed

2 files changed

+105
-193
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 97 additions & 188 deletions
Original file line numberDiff line numberDiff line change
@@ -732,9 +732,6 @@ static void
732732
setSpecializationConstants(const std::shared_ptr<device_image_impl> &InputImpl,
733733
ur_program_handle_t Prog,
734734
const AdapterPtr &Adapter) {
735-
// Set ITT annotation specialization constant if needed.
736-
enableITTAnnotationsIfNeeded(Prog, Adapter);
737-
738735
std::lock_guard<std::mutex> Lock{InputImpl->get_spec_const_data_lock()};
739736
const std::map<std::string, std::vector<device_image_impl::SpecConstDescT>>
740737
&SpecConstData = InputImpl->get_spec_const_data_ref();
@@ -769,15 +766,6 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
769766
const ContextImplPtr &ContextImpl, const DeviceImplPtr &DeviceImpl,
770767
const std::string &KernelName, const NDRDescT &NDRDesc,
771768
bool JITCompilationIsRequired) {
772-
KernelProgramCache &Cache = ContextImpl->getKernelProgramCache();
773-
774-
std::string CompileOpts;
775-
std::string LinkOpts;
776-
777-
applyOptionsFromEnvironment(CompileOpts, LinkOpts);
778-
779-
SerializedObj SpecConsts;
780-
781769
// Check if we can optimize program builds for sub-devices by using a program
782770
// built for the root device
783771
DeviceImplPtr RootDevImpl = DeviceImpl;
@@ -824,19 +812,36 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
824812
AllImages.push_back(&Img);
825813
std::copy(ImageDeps.begin(), ImageDeps.end(), std::back_inserter(AllImages));
826814

827-
auto BuildF = [this, &Img, &Context, &ContextImpl, &Device, &CompileOpts,
815+
return getBuiltURProgram(Img, Context, {Device}, DeviceImagesToLink,
816+
AllImages);
817+
}
818+
819+
ur_program_handle_t ProgramManager::getBuiltURProgram(
820+
const RTDeviceBinaryImage &Img, const context &Context,
821+
const std::vector<device> &Devs,
822+
const std::set<RTDeviceBinaryImage *> &DeviceImagesToLink,
823+
const std::vector<const RTDeviceBinaryImage *> &AllImages,
824+
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
825+
const SerializedObj &SpecConsts) {
826+
std::string CompileOpts;
827+
std::string LinkOpts;
828+
applyOptionsFromEnvironment(CompileOpts, LinkOpts);
829+
auto BuildF = [this, &Img, &DeviceImageImpl, &Context, &Devs, &CompileOpts,
828830
&LinkOpts, SpecConsts, &DeviceImagesToLink, &AllImages] {
831+
const ContextImplPtr &ContextImpl = getSyclObjImpl(Context);
829832
const AdapterPtr &Adapter = ContextImpl->getAdapter();
830-
applyOptionsFromImage(CompileOpts, LinkOpts, Img, {Device}, Adapter);
833+
applyOptionsFromImage(CompileOpts, LinkOpts, Img, Devs, Adapter);
831834
// Should always come last!
832835
appendCompileEnvironmentVariablesThatAppend(CompileOpts);
833836
appendLinkEnvironmentVariablesThatAppend(LinkOpts);
837+
834838
auto [NativePrg, DeviceCodeWasInCache] = getOrCreateURProgram(
835-
Img, AllImages, Context, {Device}, CompileOpts + LinkOpts, SpecConsts);
839+
Img, {AllImages}, Context, Devs, CompileOpts + LinkOpts, SpecConsts);
836840

837-
if (!DeviceCodeWasInCache) {
838-
if (Img.supportsSpecConstants())
839-
enableITTAnnotationsIfNeeded(NativePrg, Adapter);
841+
if (!DeviceCodeWasInCache && Img.supportsSpecConstants()) {
842+
enableITTAnnotationsIfNeeded(NativePrg, Adapter);
843+
if (DeviceImageImpl)
844+
setSpecializationConstants(DeviceImageImpl, NativePrg, Adapter);
840845
}
841846

842847
UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
@@ -864,34 +869,28 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
864869
for (RTDeviceBinaryImage *BinImg : DeviceImagesToLink) {
865870
if (UseDeviceLibs)
866871
DeviceLibReqMask |= getDeviceLibReqMask(*BinImg);
867-
device_image_plain DevImagePlain =
868-
getDeviceImageFromBinaryImage(BinImg, Context, Device);
869-
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
870-
detail::getSyclObjImpl(DevImagePlain);
871872

872-
SerializedObj ImgSpecConsts =
873-
DeviceImageImpl->get_spec_const_blob_ref();
874-
875-
ur_program_handle_t NativePrg =
876-
createURProgram(*BinImg, Context, {Device});
873+
ur_program_handle_t NativePrg = createURProgram(*BinImg, Context, Devs);
877874

878875
if (BinImg->supportsSpecConstants())
879-
setSpecializationConstants(DeviceImageImpl, NativePrg, Adapter);
876+
enableITTAnnotationsIfNeeded(NativePrg, Adapter);
880877

881878
ProgramsToLink.push_back(NativePrg);
882879
}
883880
}
884-
std::vector<ur_device_handle_t> Devs = {
885-
getSyclObjImpl(Device).get()->getHandleRef()};
886-
;
881+
882+
std::vector<ur_device_handle_t> URDevices;
883+
for (auto Dev : Devs)
884+
URDevices.push_back(getSyclObjImpl(Dev).get()->getHandleRef());
885+
887886
ProgramPtr BuiltProgram = build(
888-
std::move(ProgramManaged), ContextImpl, CompileOpts, LinkOpts, Devs,
889-
DeviceLibReqMask, ProgramsToLink,
887+
std::move(ProgramManaged), ContextImpl, CompileOpts, LinkOpts,
888+
URDevices, DeviceLibReqMask, ProgramsToLink,
890889
/*CreatedFromBinary*/ Img.getFormat() != SYCL_DEVICE_BINARY_TYPE_SPIRV);
890+
891891
// Those extra programs won't be used anymore, just the final linked result
892892
for (ur_program_handle_t Prg : ProgramsToLink)
893893
Adapter->call<UrApiKind::urProgramRelease>(Prg);
894-
895894
emitBuiltProgramInfo(BuiltProgram.get(), ContextImpl);
896895

897896
{
@@ -902,57 +901,98 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
902901
}
903902
}
904903

905-
ContextImpl->addDeviceGlobalInitializer(BuiltProgram.get(), {Device}, &Img);
904+
ContextImpl->addDeviceGlobalInitializer(BuiltProgram.get(), Devs, &Img);
906905

907906
// Save program to persistent cache if it is not there
908907
if (!DeviceCodeWasInCache) {
909-
PersistentDeviceCodeCache::putItemToDisc({Device}, AllImages, SpecConsts,
908+
PersistentDeviceCodeCache::putItemToDisc(Devs, AllImages, SpecConsts,
910909
CompileOpts + LinkOpts,
911910
BuiltProgram.get());
912911
}
912+
913913
return BuiltProgram.release();
914914
};
915915

916+
if (!SYCLConfig<SYCL_CACHE_IN_MEM>::get())
917+
return BuildF();
918+
916919
uint32_t ImgId = Img.getImageID();
917-
const ur_device_handle_t UrDevice = Dev->getHandleRef();
918-
auto CacheKey = std::make_pair(std::make_pair(std::move(SpecConsts), ImgId),
919-
std::set<ur_device_handle_t>{UrDevice});
920+
std::set<ur_device_handle_t> URDevicesSet;
921+
std::transform(Devs.begin(), Devs.end(),
922+
std::inserter(URDevicesSet, URDevicesSet.begin()),
923+
[](const device &Dev) {
924+
return getSyclObjImpl(Dev).get()->getHandleRef();
925+
});
926+
auto CacheKey =
927+
std::make_pair(std::make_pair(SpecConsts, ImgId), URDevicesSet);
920928

929+
const ContextImplPtr &ContextImpl = getSyclObjImpl(Context);
930+
KernelProgramCache &Cache = ContextImpl->getKernelProgramCache();
921931
auto GetCachedBuildF = [&Cache, &CacheKey]() {
922932
return Cache.getOrInsertProgram(CacheKey);
923933
};
924934

925-
if (!SYCLConfig<SYCL_CACHE_IN_MEM>::get())
926-
return BuildF();
927-
928935
auto BuildResult = Cache.getOrBuild<errc::build>(GetCachedBuildF, BuildF);
929936
// getOrBuild is not supposed to return nullptr
930937
assert(BuildResult != nullptr && "Invalid build result");
931938

932939
ur_program_handle_t ResProgram = BuildResult->Val;
933-
auto Adapter = ContextImpl->getAdapter();
934940

941+
// Here we have multiple devices a program is built for, so add the program to
942+
// the cache for all subsets of provided list of devices.
943+
const AdapterPtr &Adapter = ContextImpl->getAdapter();
935944
// If we linked any extra device images, then we need to
936945
// cache them as well.
937-
for (const RTDeviceBinaryImage *BImg : DeviceImagesToLink) {
938-
// CacheKey is captured by reference by GetCachedBuildF, so we can simply
939-
// update it here and re-use that lambda.
940-
CacheKey.first.second = BImg->getImageID();
941-
bool DidInsert = Cache.insertBuiltProgram(CacheKey, ResProgram);
942-
if (DidInsert) {
943-
// For every cached copy of the program, we need to increment its refcount
944-
Adapter->call<UrApiKind::urProgramRetain>(ResProgram);
946+
auto CacheLinkedImages = [&Adapter, &Cache, &CacheKey, &ResProgram,
947+
&DeviceImagesToLink] {
948+
for (const RTDeviceBinaryImage *BImg : DeviceImagesToLink) {
949+
// CacheKey is captured by reference by GetCachedBuildF, so we can simply
950+
// update it here and re-use that lambda.
951+
CacheKey.first.second = BImg->getImageID();
952+
bool DidInsert = Cache.insertBuiltProgram(CacheKey, ResProgram);
953+
if (DidInsert) {
954+
// For every cached copy of the program, we need to increment its
955+
// refcount
956+
Adapter->call<UrApiKind::urProgramRetain>(ResProgram);
957+
}
958+
}
959+
};
960+
CacheLinkedImages();
961+
962+
if (URDevicesSet.size() > 1) {
963+
// emplace all subsets of the current set of devices into the cache.
964+
// Set of all devices is not included in the loop as it was already added
965+
// into the cache.
966+
for (int Mask = 1; Mask < (1 << URDevicesSet.size()) - 1; ++Mask) {
967+
std::set<ur_device_handle_t> Subset;
968+
int Index = 0;
969+
for (auto It = URDevicesSet.begin(); It != URDevicesSet.end();
970+
++It, ++Index) {
971+
if (Mask & (1 << Index)) {
972+
Subset.insert(*It);
973+
}
974+
}
975+
// Change device in the cache key to reduce copying of spec const data.
976+
CacheKey.second = Subset;
977+
bool DidInsert = Cache.insertBuiltProgram(CacheKey, ResProgram);
978+
if (DidInsert) {
979+
// For every cached copy of the program, we need to increment its
980+
// refcount
981+
Adapter->call<UrApiKind::urProgramRetain>(ResProgram);
982+
}
983+
CacheLinkedImages();
984+
// getOrBuild is not supposed to return nullptr
985+
assert(BuildResult != nullptr && "Invalid build result");
945986
}
946987
}
947988

948989
// If caching is enabled, one copy of the program handle will be
949990
// stored in the cache, and one handle is returned to the
950991
// caller. In that case, we need to increase the ref count of the
951992
// program.
952-
ContextImpl->getAdapter()->call<UrApiKind::urProgramRetain>(ResProgram);
993+
Adapter->call<UrApiKind::urProgramRetain>(ResProgram);
953994
return ResProgram;
954995
}
955-
956996
// When caching is enabled, the returned UrProgram and UrKernel will
957997
// already have their ref count incremented.
958998
std::tuple<ur_kernel_handle_t, std::mutex *, const KernelArgMask *,
@@ -2437,8 +2477,6 @@ ProgramManager::compile(const device_image_plain &DeviceImage,
24372477
const AdapterPtr &Adapter =
24382478
getSyclObjImpl(InputImpl->get_context())->getAdapter();
24392479

2440-
// Device is not used when creating program from SPIRV, so passing only one
2441-
// device is OK.
24422480
ur_program_handle_t Prog = createURProgram(*InputImpl->get_bin_image_ref(),
24432481
InputImpl->get_context(), Devs);
24442482

@@ -2608,149 +2646,20 @@ device_image_plain ProgramManager::build(const device_image_plain &DeviceImage,
26082646

26092647
KernelProgramCache &Cache = ContextImpl->getKernelProgramCache();
26102648

2611-
std::string CompileOpts;
2612-
std::string LinkOpts;
2613-
applyOptionsFromEnvironment(CompileOpts, LinkOpts);
2614-
26152649
const RTDeviceBinaryImage *ImgPtr = InputImpl->get_bin_image_ref();
26162650
const RTDeviceBinaryImage &Img = *ImgPtr;
26172651

26182652
SerializedObj SpecConsts = InputImpl->get_spec_const_blob_ref();
26192653

2620-
// TODO: Unify this code with getBuiltPIProgram
2621-
auto BuildF = [this, &Context, &Img, &Devs, &CompileOpts, &LinkOpts,
2622-
&InputImpl, SpecConsts] {
2623-
ContextImplPtr ContextImpl = getSyclObjImpl(Context);
2624-
const AdapterPtr &Adapter = ContextImpl->getAdapter();
2625-
applyOptionsFromImage(CompileOpts, LinkOpts, Img, Devs, Adapter);
2626-
// Should always come last!
2627-
appendCompileEnvironmentVariablesThatAppend(CompileOpts);
2628-
appendLinkEnvironmentVariablesThatAppend(LinkOpts);
2629-
2630-
// Device is not used when creating program from SPIRV, so passing only one
2631-
// device is OK.
2632-
auto [NativePrg, DeviceCodeWasInCache] = getOrCreateURProgram(
2633-
Img, {&Img}, Context, Devs, CompileOpts + LinkOpts, SpecConsts);
2634-
2635-
if (!DeviceCodeWasInCache &&
2636-
InputImpl->get_bin_image_ref()->supportsSpecConstants())
2637-
setSpecializationConstants(InputImpl, NativePrg, Adapter);
2638-
2639-
UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
2640-
auto programRelease =
2641-
programReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
2642-
ProgramPtr ProgramManaged(NativePrg, programRelease);
2643-
2644-
// Link a fallback implementation of device libraries if they are not
2645-
// supported by a device compiler.
2646-
// Pre-compiled programs are supposed to be already linked.
2647-
// If device image is not SPIR-V, DeviceLibReqMask will be 0 which means
2648-
// no fallback device library will be linked.
2649-
uint32_t DeviceLibReqMask = 0;
2650-
if (Img.getFormat() == SYCL_DEVICE_BINARY_TYPE_SPIRV &&
2651-
!SYCLConfig<SYCL_DEVICELIB_NO_FALLBACK>::get())
2652-
DeviceLibReqMask = getDeviceLibReqMask(Img);
2653-
2654-
// TODO: Add support for dynamic linking with kernel bundles
2655-
std::vector<ur_program_handle_t> ExtraProgramsToLink;
2656-
std::vector<ur_device_handle_t> URDevices;
2657-
for (auto Dev : Devs) {
2658-
URDevices.push_back(getSyclObjImpl(Dev).get()->getHandleRef());
2659-
}
2660-
ProgramPtr BuiltProgram =
2661-
build(std::move(ProgramManaged), ContextImpl, CompileOpts, LinkOpts,
2662-
URDevices, DeviceLibReqMask, ExtraProgramsToLink);
2663-
2664-
emitBuiltProgramInfo(BuiltProgram.get(), ContextImpl);
2665-
2666-
{
2667-
std::lock_guard<std::mutex> Lock(MNativeProgramsMutex);
2668-
NativePrograms.insert({BuiltProgram.get(), &Img});
2669-
}
2670-
2671-
ContextImpl->addDeviceGlobalInitializer(BuiltProgram.get(), Devs, &Img);
2672-
2673-
// Save program to persistent cache if it is not there
2674-
if (!DeviceCodeWasInCache)
2675-
PersistentDeviceCodeCache::putItemToDisc(
2676-
Devs, {&Img}, SpecConsts, CompileOpts + LinkOpts, BuiltProgram.get());
2677-
2678-
return BuiltProgram.release();
2679-
};
2680-
2681-
if (!SYCLConfig<SYCL_CACHE_IN_MEM>::get()) {
2682-
auto ResProgram = BuildF();
2683-
DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl>(
2684-
InputImpl->get_bin_image_ref(), Context, Devs, bundle_state::executable,
2685-
InputImpl->get_kernel_ids_ptr(), ResProgram,
2686-
InputImpl->get_spec_const_data_ref(),
2687-
InputImpl->get_spec_const_blob_ref());
2688-
2689-
return createSyclObjFromImpl<device_image_plain>(ExecImpl);
2690-
}
2691-
2692-
uint32_t ImgId = Img.getImageID();
2693-
std::set<ur_device_handle_t> URDevicesSet;
2694-
std::transform(Devs.begin(), Devs.end(),
2695-
std::inserter(URDevicesSet, URDevicesSet.begin()),
2696-
[](const device &Dev) {
2697-
return getSyclObjImpl(Dev).get()->getHandleRef();
2698-
});
2699-
auto CacheKey = std::make_pair(std::make_pair(std::move(SpecConsts), ImgId),
2700-
URDevicesSet);
2701-
2702-
// CacheKey is captured by reference so when we overwrite it later we can
2703-
// reuse this function.
2704-
auto GetCachedBuildF = [&Cache, &CacheKey]() {
2705-
return Cache.getOrInsertProgram(CacheKey);
2706-
};
2707-
2708-
auto BuildResult = Cache.getOrBuild<errc::build>(GetCachedBuildF, BuildF);
2709-
// getOrBuild is not supposed to return nullptr
2710-
assert(BuildResult != nullptr && "Invalid build result");
2711-
2712-
ur_program_handle_t ResProgram = BuildResult->Val;
2713-
2714-
// Here we have multiple devices a program is built for, so add the program to
2715-
// the cache for all subsets of provided list of devices.
2716-
const AdapterPtr &Adapter = ContextImpl->getAdapter();
2717-
auto CacheSubsets = [ResProgram, &Adapter]() {
2718-
Adapter->call<UrApiKind::urProgramRetain>(ResProgram);
2719-
return ResProgram;
2720-
};
2721-
2722-
if (URDevicesSet.size() > 1) {
2723-
// emplace all subsets of the current set of devices into the cache.
2724-
// Set of all devices is not included in the loop as it was already added
2725-
// into the cache.
2726-
for (int Mask = 1; Mask < (1 << URDevicesSet.size()) - 1; ++Mask) {
2727-
std::set<ur_device_handle_t> Subset;
2728-
int Index = 0;
2729-
for (auto It = URDevicesSet.begin(); It != URDevicesSet.end();
2730-
++It, ++Index) {
2731-
if (Mask & (1 << Index)) {
2732-
Subset.insert(*It);
2733-
}
2734-
}
2735-
// Change device in the cache key to reduce copying of spec const data.
2736-
CacheKey.second = Subset;
2737-
Cache.getOrBuild<errc::build>(GetCachedBuildF, CacheSubsets);
2738-
// getOrBuild is not supposed to return nullptr
2739-
assert(BuildResult != nullptr && "Invalid build result");
2740-
}
2741-
}
2742-
2743-
// devive_image_impl shares ownership of PIProgram with, at least, program
2744-
// cache. The ref counter will be descremented in the destructor of
2745-
// device_image_impl
2746-
Adapter->call<UrApiKind::urProgramRetain>(ResProgram);
2747-
2654+
// TODO: Add support for dynamic linking with kernel bundles
2655+
ur_program_handle_t ResProgram =
2656+
getBuiltURProgram(Img, Context, Devs, /*DeviceImagesToLink*/ {}, {&Img},
2657+
InputImpl, SpecConsts);
27482658
DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl>(
27492659
InputImpl->get_bin_image_ref(), Context, Devs, bundle_state::executable,
27502660
InputImpl->get_kernel_ids_ptr(), ResProgram,
27512661
InputImpl->get_spec_const_data_ref(),
27522662
InputImpl->get_spec_const_blob_ref());
2753-
27542663
return createSyclObjFromImpl<device_image_plain>(ExecImpl);
27552664
}
27562665

0 commit comments

Comments
 (0)