Skip to content

Commit 4585556

Browse files
Merge branch 'sycl' into fixdevicekernelinfo
2 parents 830b562 + 61de220 commit 4585556

File tree

13 files changed

+368
-61
lines changed

13 files changed

+368
-61
lines changed

clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,48 @@ class BinaryWrapper {
12931293
appendToGlobalDtors(M, Func, /*Priority*/ 1);
12941294
}
12951295

1296+
void createSyclRegisterWithAtexitUnregister(GlobalVariable *BinDesc) {
1297+
auto *UnregFuncTy =
1298+
FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
1299+
auto *UnregFunc =
1300+
Function::Create(UnregFuncTy, GlobalValue::InternalLinkage,
1301+
"sycl.descriptor_unreg.atexit", &M);
1302+
UnregFunc->setSection(".text.startup");
1303+
1304+
// Declaration for __sycl_unregister_lib(void*).
1305+
auto *UnregTargetTy =
1306+
FunctionType::get(Type::getVoidTy(C), getPtrTy(), /*isVarArg=*/false);
1307+
FunctionCallee UnregTargetC =
1308+
M.getOrInsertFunction("__sycl_unregister_lib", UnregTargetTy);
1309+
1310+
IRBuilder<> UnregBuilder(BasicBlock::Create(C, "entry", UnregFunc));
1311+
UnregBuilder.CreateCall(UnregTargetC, BinDesc);
1312+
UnregBuilder.CreateRetVoid();
1313+
1314+
auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
1315+
auto *RegFunc = Function::Create(RegFuncTy, GlobalValue::InternalLinkage,
1316+
"sycl.descriptor_reg", &M);
1317+
RegFunc->setSection(".text.startup");
1318+
1319+
auto *RegTargetTy =
1320+
FunctionType::get(Type::getVoidTy(C), getPtrTy(), false);
1321+
FunctionCallee RegTargetC =
1322+
M.getOrInsertFunction("__sycl_register_lib", RegTargetTy);
1323+
1324+
// `atexit` takes a `void(*)()` function pointer arg and returns an i32.
1325+
FunctionType *AtExitTy =
1326+
FunctionType::get(Type::getInt32Ty(C), getPtrTy(), false);
1327+
FunctionCallee AtExitC = M.getOrInsertFunction("atexit", AtExitTy);
1328+
1329+
IRBuilder<> RegBuilder(BasicBlock::Create(C, "entry", RegFunc));
1330+
RegBuilder.CreateCall(RegTargetC, BinDesc);
1331+
RegBuilder.CreateCall(AtExitC, UnregFunc);
1332+
RegBuilder.CreateRetVoid();
1333+
1334+
// Add this function to global destructors.
1335+
appendToGlobalCtors(M, RegFunc, /*Priority*/ 1);
1336+
}
1337+
12961338
public:
12971339
BinaryWrapper(StringRef Target, StringRef ToolName,
12981340
StringRef SymPropBCFiles = "")
@@ -1370,8 +1412,13 @@ class BinaryWrapper {
13701412

13711413
if (EmitRegFuncs) {
13721414
GlobalVariable *Desc = *DescOrErr;
1373-
createRegisterFunction(Kind, Desc);
1374-
createUnregisterFunction(Kind, Desc);
1415+
if (Kind == OffloadKind::SYCL &&
1416+
Triple(M.getTargetTriple()).isOSWindows()) {
1417+
createSyclRegisterWithAtexitUnregister(Desc);
1418+
} else {
1419+
createRegisterFunction(Kind, Desc);
1420+
createUnregisterFunction(Kind, Desc);
1421+
}
13751422
}
13761423
}
13771424
return &M;

llvm/lib/Frontend/Offloading/SYCLOffloadWrapper.cpp

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/Support/ErrorHandling.h"
3535
#include "llvm/Support/LineIterator.h"
3636
#include "llvm/Support/PropertySetIO.h"
37+
#include "llvm/TargetParser/Triple.h"
3738
#include "llvm/Transforms/Utils/ModuleUtils.h"
3839
#include <memory>
3940
#include <string>
@@ -734,6 +735,50 @@ struct Wrapper {
734735
// Add this function to global destructors.
735736
appendToGlobalDtors(M, Func, /*Priority*/ 1);
736737
}
738+
739+
void createSyclRegisterWithAtexitUnregister(GlobalVariable *FatbinDesc) {
740+
auto *UnregFuncTy =
741+
FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
742+
auto *UnregFunc =
743+
Function::Create(UnregFuncTy, GlobalValue::InternalLinkage,
744+
"sycl.descriptor_unreg.atexit", &M);
745+
UnregFunc->setSection(".text.startup");
746+
747+
// Declaration for __sycl_unregister_lib(void*).
748+
auto *UnregTargetTy =
749+
FunctionType::get(Type::getVoidTy(C), PointerType::getUnqual(C), false);
750+
FunctionCallee UnregTargetC =
751+
M.getOrInsertFunction("__sycl_unregister_lib", UnregTargetTy);
752+
753+
// Body of the unregister wrapper.
754+
IRBuilder<> UnregBuilder(BasicBlock::Create(C, "entry", UnregFunc));
755+
UnregBuilder.CreateCall(UnregTargetC, FatbinDesc);
756+
UnregBuilder.CreateRetVoid();
757+
758+
auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
759+
auto *RegFunc = Function::Create(RegFuncTy, GlobalValue::InternalLinkage,
760+
"sycl.descriptor_reg", &M);
761+
RegFunc->setSection(".text.startup");
762+
763+
auto *RegTargetTy =
764+
FunctionType::get(Type::getVoidTy(C), PointerType::getUnqual(C), false);
765+
FunctionCallee RegTargetC =
766+
M.getOrInsertFunction("__sycl_register_lib", RegTargetTy);
767+
768+
// `atexit` takes a `void(*)()` function pointer arg and returns an i32.
769+
FunctionType *AtExitTy = FunctionType::get(
770+
Type::getInt32Ty(C), PointerType::getUnqual(C), false);
771+
FunctionCallee AtExitC = M.getOrInsertFunction("atexit", AtExitTy);
772+
773+
IRBuilder<> RegBuilder(BasicBlock::Create(C, "entry", RegFunc));
774+
RegBuilder.CreateCall(RegTargetC, FatbinDesc);
775+
RegBuilder.CreateCall(AtExitC, UnregFunc);
776+
RegBuilder.CreateRetVoid();
777+
778+
// Finally, add to global constructors.
779+
appendToGlobalCtors(M, RegFunc, /*Priority*/ 1);
780+
}
781+
737782
}; // end of Wrapper
738783

739784
} // anonymous namespace
@@ -747,7 +792,11 @@ Error llvm::offloading::wrapSYCLBinaries(llvm::Module &M,
747792
return createStringError(inconvertibleErrorCode(),
748793
"No binary descriptors created.");
749794

750-
W.createRegisterFatbinFunction(Desc);
751-
W.createUnregisterFunction(Desc);
795+
if (Triple(M.getTargetTriple()).isOSWindows()) {
796+
W.createSyclRegisterWithAtexitUnregister(Desc);
797+
} else {
798+
W.createRegisterFatbinFunction(Desc);
799+
W.createUnregisterFunction(Desc);
800+
}
752801
return Error::success();
753802
}

sycl/include/sycl/handler.hpp

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,25 @@ class __SYCL_EXPORT handler {
491491
"a single kernel or explicit memory operation.");
492492
}
493493

494+
template <class Kernel> void setDeviceKernelInfo(void *KernelFuncPtr) {
495+
constexpr auto Info = detail::CompileTimeKernelInfo<Kernel>;
496+
MKernelName = Info.Name;
497+
// TODO support ESIMD in no-integration-header case too.
498+
setKernelInfo(KernelFuncPtr, Info.NumParams, Info.ParamDescGetter,
499+
Info.IsESIMD, Info.HasSpecialCaptures);
500+
setDeviceKernelInfoPtr(&detail::getDeviceKernelInfo<Kernel>());
501+
setType(detail::CGType::Kernel);
502+
}
503+
504+
void setDeviceKernelInfo(kernel &&Kernel) {
505+
MKernel = detail::getSyclObjImpl(std::move(Kernel));
506+
MKernelName = getKernelName();
507+
setType(detail::CGType::Kernel);
508+
509+
// If any extra actions are added here make sure that logic around
510+
// `lambdaAndKernelHaveEqualName` calls can handle that.
511+
}
512+
494513
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
495514
// TODO: Those functions are not used anymore, remove it in the next
496515
// ABI-breaking window.
@@ -823,7 +842,6 @@ class __SYCL_EXPORT handler {
823842
detail::GetInstantiateKernelOnHostPtr<KernelType, LambdaArgType,
824843
Dims>());
825844
#endif
826-
constexpr auto Info = detail::CompileTimeKernelInfo<KernelName>;
827845

828846
// SYCL unittests are built without sycl compiler, so "host" information
829847
// about kernels isn't provided (e.g., via integration headers or compiler
@@ -836,6 +854,8 @@ class __SYCL_EXPORT handler {
836854
// don't actually execute those operation, that's why disabling
837855
// unconditional `static_asserts`s is enough for now.
838856
#ifndef __SYCL_UNITTESTS_BYPASS_KERNEL_NAME_CHECK
857+
constexpr auto Info = detail::CompileTimeKernelInfo<KernelName>;
858+
839859
static_assert(Info.Name != std::string_view{}, "Kernel must have a name!");
840860

841861
// Some host compilers may have different captures from Clang. Currently
@@ -855,12 +875,7 @@ class __SYCL_EXPORT handler {
855875
"might also help.");
856876
#endif
857877

858-
// Force hasSpecialCaptures to be evaluated at compile-time.
859-
setKernelInfo((void *)MHostKernel->getPtr(), Info.NumParams,
860-
Info.ParamDescGetter, Info.IsESIMD, Info.HasSpecialCaptures);
861-
862-
MKernelName = Info.Name;
863-
setDeviceKernelInfoPtr(&detail::getDeviceKernelInfo<KernelName>());
878+
setDeviceKernelInfo<KernelName>((void *)MHostKernel->getPtr());
864879

865880
// If the kernel lambda is callable with a kernel_handler argument, manifest
866881
// the associated kernel handler.
@@ -1302,7 +1317,6 @@ class __SYCL_EXPORT handler {
13021317
setNDRangeDescriptor(RoundedRange);
13031318
StoreLambda<KName, decltype(Wrapper), Dims, TransformedArgType>(
13041319
std::move(Wrapper));
1305-
setType(detail::CGType::Kernel);
13061320
#endif
13071321
} else
13081322
#endif // !__SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING__ &&
@@ -1324,7 +1338,6 @@ class __SYCL_EXPORT handler {
13241338
setNDRangeDescriptor(std::move(UserRange));
13251339
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
13261340
std::move(KernelFunc));
1327-
setType(detail::CGType::Kernel);
13281341
#endif
13291342
#else
13301343
(void)KernelFunc;
@@ -1346,13 +1359,11 @@ class __SYCL_EXPORT handler {
13461359
[[maybe_unused]] kernel Kernel) {
13471360
#ifndef __SYCL_DEVICE_ONLY__
13481361
throwIfActionIsCreated();
1349-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
1362+
setDeviceKernelInfo(std::move(Kernel));
13501363
detail::checkValueRange<Dims>(NumWorkItems);
13511364
setNDRangeDescriptor(std::move(NumWorkItems));
13521365
processLaunchProperties<PropertiesT>(Props);
1353-
setType(detail::CGType::Kernel);
13541366
extractArgsAndReqs();
1355-
MKernelName = getKernelName();
13561367
#endif
13571368
}
13581369

@@ -1371,13 +1382,11 @@ class __SYCL_EXPORT handler {
13711382
[[maybe_unused]] kernel Kernel) {
13721383
#ifndef __SYCL_DEVICE_ONLY__
13731384
throwIfActionIsCreated();
1374-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
1385+
setDeviceKernelInfo(std::move(Kernel));
13751386
detail::checkValueRange<Dims>(NDRange);
13761387
setNDRangeDescriptor(std::move(NDRange));
13771388
processLaunchProperties(Props);
1378-
setType(detail::CGType::Kernel);
13791389
extractArgsAndReqs();
1380-
MKernelName = getKernelName();
13811390
#endif
13821391
}
13831392

@@ -1404,7 +1413,6 @@ class __SYCL_EXPORT handler {
14041413
}
14051414
throwIfActionIsCreated();
14061415
verifyUsedKernelBundleInternal(Info.Name);
1407-
setType(detail::CGType::Kernel);
14081416

14091417
detail::checkValueRange<Dims>(params...);
14101418
if constexpr (SetNumWorkGroups) {
@@ -1450,7 +1458,6 @@ class __SYCL_EXPORT handler {
14501458
// kernel.
14511459
setHandlerKernelBundle(Kernel);
14521460
verifyUsedKernelBundleInternal(Info.Name);
1453-
setType(detail::CGType::Kernel);
14541461

14551462
detail::checkValueRange<Dims>(params...);
14561463
if constexpr (SetNumWorkGroups) {
@@ -1460,12 +1467,11 @@ class __SYCL_EXPORT handler {
14601467
setNDRangeDescriptor(std::move(params)...);
14611468
}
14621469

1463-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
1464-
if (!lambdaAndKernelHaveEqualName<NameT>()) {
1465-
extractArgsAndReqs();
1466-
MKernelName = getKernelName();
1467-
} else {
1470+
setDeviceKernelInfo(std::move(Kernel));
1471+
if (lambdaAndKernelHaveEqualName<NameT>()) {
14681472
StoreLambda<NameT, KernelType, Dims, ElementType>(std::move(KernelFunc));
1473+
} else {
1474+
extractArgsAndReqs();
14691475
}
14701476
processProperties<Info.IsESIMD, PropertiesT>(Props);
14711477
#endif
@@ -1845,10 +1851,8 @@ class __SYCL_EXPORT handler {
18451851
// No need to check if range is out of INT_MAX limits as it's compile-time
18461852
// known constant
18471853
setNDRangeDescriptor(range<1>{1});
1848-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
1849-
setType(detail::CGType::Kernel);
1854+
setDeviceKernelInfo(std::move(Kernel));
18501855
extractArgsAndReqs();
1851-
MKernelName = getKernelName();
18521856
}
18531857

18541858
void parallel_for(range<1> NumWorkItems, kernel Kernel) {
@@ -1881,12 +1885,10 @@ class __SYCL_EXPORT handler {
18811885
[[maybe_unused]] kernel Kernel) {
18821886
#ifndef __SYCL_DEVICE_ONLY__
18831887
throwIfActionIsCreated();
1884-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
1888+
setDeviceKernelInfo(std::move(Kernel));
18851889
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
18861890
setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset));
1887-
setType(detail::CGType::Kernel);
18881891
extractArgsAndReqs();
1889-
MKernelName = getKernelName();
18901892
#endif
18911893
}
18921894

@@ -1924,13 +1926,12 @@ class __SYCL_EXPORT handler {
19241926
// No need to check if range is out of INT_MAX limits as it's compile-time
19251927
// known constant
19261928
setNDRangeDescriptor(range<1>{1});
1927-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
1928-
setType(detail::CGType::Kernel);
1929-
if (!lambdaAndKernelHaveEqualName<NameT>()) {
1930-
extractArgsAndReqs();
1931-
MKernelName = getKernelName();
1932-
} else
1929+
setDeviceKernelInfo(std::move(Kernel));
1930+
if (lambdaAndKernelHaveEqualName<NameT>()) {
19331931
StoreLambda<NameT, KernelType, /*Dims*/ 1, void>(std::move(KernelFunc));
1932+
} else {
1933+
extractArgsAndReqs();
1934+
}
19341935
#else
19351936
detail::CheckDeviceCopyable<KernelType>();
19361937
#endif

sycl/source/detail/context_impl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ context_impl::~context_impl() {
125125
DeviceGlobalMapEntry *DGEntry =
126126
detail::ProgramManager::getInstance().getDeviceGlobalEntry(
127127
DeviceGlobal);
128-
DGEntry->removeAssociatedResources(this);
128+
if (DGEntry != nullptr)
129+
DGEntry->removeAssociatedResources(this);
129130
}
130131
MCachedLibPrograms.clear();
131132
// TODO catch an exception and put it to list of asynchronous exceptions

sycl/source/detail/device_global_map.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class DeviceGlobalMap {
9494
});
9595
if (findDevGlobalByValue != MPtr2DeviceGlobal.end())
9696
MPtr2DeviceGlobal.erase(findDevGlobalByValue);
97+
9798
MDeviceGlobals.erase(DevGlobalIt);
9899
}
99100
}
@@ -119,8 +120,7 @@ class DeviceGlobalMap {
119120
DeviceGlobalMapEntry *getEntry(const void *DeviceGlobalPtr) {
120121
std::lock_guard<std::mutex> DeviceGlobalsGuard(MDeviceGlobalsMutex);
121122
auto Entry = MPtr2DeviceGlobal.find(DeviceGlobalPtr);
122-
assert(Entry != MPtr2DeviceGlobal.end() && "Device global entry not found");
123-
return Entry->second;
123+
return (Entry != MPtr2DeviceGlobal.end()) ? Entry->second : nullptr;
124124
}
125125

126126
DeviceGlobalMapEntry *

0 commit comments

Comments
 (0)