Skip to content

Commit 98a7e5c

Browse files
committed
support bfloat16 native spv
Signed-off-by: jinge90 <[email protected]>
1 parent c310a45 commit 98a7e5c

File tree

5 files changed

+89
-49
lines changed

5 files changed

+89
-49
lines changed

llvm/include/llvm/SYCLLowerIR/SYCLRequiredDeviceLibs.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ enum class DeviceLibExt : std::uint32_t {
4040
cl_intel_devicelib_bfloat16,
4141
};
4242

43+
enum class DeviceLibIsNative : std::uint32_t { Yes, No, Ignore };
44+
45+
struct SYCLDeviceLibSPVMeta {
46+
DeviceLibExt SPVExt;
47+
const char *SPVFileName;
48+
DeviceLibIsNative IsNative;
49+
};
50+
4351
struct SYCLDeviceLibSPVBinary {
4452
typedef uint8_t value_type;
4553
value_type *SPVRawBytes;
@@ -53,8 +61,6 @@ struct SYCLDeviceLibSPVBinary {
5361
};
5462

5563
void getRequiredSYCLDeviceLibs(const Module &M,
56-
SmallVector<DeviceLibExt, 16> &ReqLibs);
57-
58-
const char *getDeviceLibFileName(DeviceLibExt RequiredDeviceLibExt);
64+
SmallVector<SYCLDeviceLibSPVMeta, 16> &ReqLibs);
5965

6066
} // namespace llvm

llvm/lib/SYCLLowerIR/ComputeModuleRuntimeInfo.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,11 @@ PropSetRegTy computeModuleProperties(const Module &M,
171171
// If sycl-post-link doesn't specify a valid fallback spv path, the param
172172
// 'DeviceLibSPVLoc' is set to an empty string.
173173
if (!DeviceLibSPVLoc.empty()) {
174-
SmallVector<llvm::DeviceLibExt, 16> RequiredLibs;
174+
SmallVector<llvm::SYCLDeviceLibSPVMeta, 16> RequiredLibs;
175175
llvm::getRequiredSYCLDeviceLibs(M, RequiredLibs);
176-
for (auto Ext : RequiredLibs) {
177-
const char *SPVFileName = llvm::getDeviceLibFileName(Ext);
176+
for (auto ExtMeta : RequiredLibs) {
178177
std::string SPVPath =
179-
DeviceLibSPVLoc.str() + "/" + std::string(SPVFileName);
178+
DeviceLibSPVLoc.str() + "/" + std::string(ExtMeta.SPVFileName);
180179
if (!llvm::sys::fs::exists(SPVPath))
181180
continue;
182181

@@ -188,31 +187,27 @@ PropSetRegTy computeModuleProperties(const Module &M,
188187
// std::aligned_alloc is not available in some pre-ci Windows machine.
189188
#if defined(_WIN32) || defined(_WIN64)
190189
uint8_t *SPVBuffer = reinterpret_cast<uint8_t *>(
191-
_aligned_malloc(alignof(uint32_t), SPVSize + sizeof(uint32_t)));
190+
_aligned_malloc(alignof(uint32_t), SPVSize + sizeof(uint32_t) * 2));
192191
#else
193-
uint8_t *SPVBuffer = reinterpret_cast<uint8_t *>(
194-
std::aligned_alloc(alignof(uint32_t), SPVSize + sizeof(uint32_t)));
192+
uint8_t *SPVBuffer = reinterpret_cast<uint8_t *>(std::aligned_alloc(
193+
alignof(uint32_t), SPVSize + sizeof(uint32_t) * 2));
195194
#endif
196195

197196
if (!SPVBuffer)
198197
continue;
199198

200-
// The data embedded consists of 2 parts, first 4 bytes are corresponding
201-
// DeviceLib extension and the following bytes are raw data of fallback
202-
// spv files. There is 1 exception for native bfloat16 spv, it is used
203-
// to support native bfloat16 conversions on some devices and it doesn't
204-
// fully comply to fallback device library mechanism, the extension
205-
// 'cl_intel_devicelib_bfloat16' corresponds to 2 fallback spvs: native
206-
// version used for devices which supports native bfloat16 conversion and
207-
// generic version for all other devices, so we have to embed one field
208-
// to distinguish.
209-
*(reinterpret_cast<uint32_t *>(SPVBuffer)) = static_cast<uint32_t>(Ext);
210-
size_t RawSPVOffset = sizeof(uint32_t);
211-
std::memcpy(SPVBuffer + RawSPVOffset, (*SPVMB)->getBufferStart(),
199+
// The embedded data consists of 3 parts; the overall layout is as follows:
200+
// |--devicelib ext(4 byte)--|--IsNative Flag(4 byte)--|--spv raw data--|
201+
*(reinterpret_cast<uint32_t *>(SPVBuffer)) =
202+
static_cast<uint32_t>(ExtMeta.SPVExt);
203+
204+
*(reinterpret_cast<uint32_t *>(SPVBuffer + sizeof(uint32_t))) =
205+
static_cast<uint32_t>(ExtMeta.IsNative);
206+
std::memcpy(SPVBuffer + sizeof(uint32_t) * 2, (*SPVMB)->getBufferStart(),
212207
SPVSize);
213208
llvm::SYCLDeviceLibSPVBinary SPVBinaryObj(SPVBuffer,
214-
SPVSize + sizeof(uint32_t));
215-
PropSet.add(PropSetRegTy::SYCL_DEVICELIB_REQ_BINS, SPVFileName,
209+
SPVSize + sizeof(uint32_t) * 2);
210+
PropSet.add(PropSetRegTy::SYCL_DEVICELIB_REQ_BINS, ExtMeta.SPVFileName,
216211
SPVBinaryObj);
217212
#if defined(_WIN32) || defined(_WIN64)
218213
_aligned_free(SPVBuffer);

llvm/lib/SYCLLowerIR/SYCLRequiredDeviceLibs.cpp

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -735,32 +735,52 @@ SYCLDeviceLibFuncMap SDLMap = {
735735

736736
} // namespace
737737

738-
// Each fallback device library corresponds to one SPV file whose name is kept
739-
// in DeviceLibSPVExtMap.
740-
static std::unordered_map<DeviceLibExt, const char *> DeviceLibSPVExtMap = {
741-
{DeviceLibExt::cl_intel_devicelib_assert, "libsycl-fallback-cassert.spv"},
742-
{DeviceLibExt::cl_intel_devicelib_math, "libsycl-fallback-cmath.spv"},
738+
// One devicelib extension may correspond to multiple spv files; the following
739+
// map stores the corresponding SPVMetaList indices for each extension.
740+
static std::unordered_map<DeviceLibExt, std::vector<size_t>>
741+
DeviceLibSPVExtMap = {{DeviceLibExt::cl_intel_devicelib_assert, {0}},
742+
{DeviceLibExt::cl_intel_devicelib_math, {1}},
743+
{DeviceLibExt::cl_intel_devicelib_math_fp64, {2}},
744+
{DeviceLibExt::cl_intel_devicelib_complex, {3}},
745+
{DeviceLibExt::cl_intel_devicelib_complex_fp64, {4}},
746+
{DeviceLibExt::cl_intel_devicelib_cstring, {5}},
747+
{DeviceLibExt::cl_intel_devicelib_imf, {6}},
748+
{DeviceLibExt::cl_intel_devicelib_imf_fp64, {7}},
749+
{DeviceLibExt::cl_intel_devicelib_imf_bf16, {8}},
750+
{DeviceLibExt::cl_intel_devicelib_bfloat16, {9, 10}}};
751+
752+
static SYCLDeviceLibSPVMeta SPVMetaList[] = {
753+
{DeviceLibExt::cl_intel_devicelib_assert, "libsycl-fallback-cassert.spv",
754+
DeviceLibIsNative::Ignore},
755+
{DeviceLibExt::cl_intel_devicelib_math, "libsycl-fallback-cmath.spv",
756+
DeviceLibIsNative::Ignore},
743757
{DeviceLibExt::cl_intel_devicelib_math_fp64,
744-
"libsycl-fallback-cmath-fp64.spv"},
745-
{DeviceLibExt::cl_intel_devicelib_complex, "libsycl-fallback-complex.spv"},
758+
"libsycl-fallback-cmath-fp64.spv", DeviceLibIsNative::Ignore},
759+
{DeviceLibExt::cl_intel_devicelib_complex, "libsycl-fallback-complex.spv",
760+
DeviceLibIsNative::Ignore},
746761
{DeviceLibExt::cl_intel_devicelib_complex_fp64,
747-
"libsycl-fallback-complex-fp64.spv"},
748-
{DeviceLibExt::cl_intel_devicelib_cstring, "libsycl-fallback-cstring.spv"},
749-
{DeviceLibExt::cl_intel_devicelib_imf, "libsycl-fallback-imf.spv"},
750-
{DeviceLibExt::cl_intel_devicelib_imf_fp64,
751-
"libsycl-fallback-imf-fp64.spv"},
752-
{DeviceLibExt::cl_intel_devicelib_imf_bf16,
753-
"libsycl-fallback-imf-bf16.spv"},
754-
{DeviceLibExt::cl_intel_devicelib_bfloat16,
755-
"libsycl-fallback-bfloat16.spv"}};
762+
"libsycl-fallback-complex-fp64.spv", DeviceLibIsNative::Ignore},
763+
{DeviceLibExt::cl_intel_devicelib_cstring, "libsycl-fallback-cstring.spv",
764+
DeviceLibIsNative::Ignore},
765+
{DeviceLibExt::cl_intel_devicelib_imf, "libsycl-fallback-imf.spv",
766+
DeviceLibIsNative::Ignore},
767+
{DeviceLibExt::cl_intel_devicelib_imf_fp64, "libsycl-fallback-imf-fp64.spv",
768+
DeviceLibIsNative::Ignore},
769+
{DeviceLibExt::cl_intel_devicelib_imf_bf16, "libsycl-fallback-imf-bf16.spv",
770+
DeviceLibIsNative::Ignore},
771+
{DeviceLibExt::cl_intel_devicelib_bfloat16, "libsycl-fallback-bfloat16.spv",
772+
DeviceLibIsNative::No},
773+
{DeviceLibExt::cl_intel_devicelib_bfloat16, "libsycl-native-bfloat16.spv",
774+
DeviceLibIsNative::Yes}};
756775

757776
namespace llvm {
758777
// For each device image module, we go through all functions which meet
759778
// 1. The function name has prefix "__devicelib_"
760779
// 2. The function is declaration which means it doesn't have function body
761780
// And we don't expect non-spirv functions with "__devicelib_" prefix.
762781
void getRequiredSYCLDeviceLibs(
763-
const Module &M, llvm::SmallVector<DeviceLibExt, 16> &ReqDeviceLibs) {
782+
const Module &M,
783+
llvm::SmallVector<SYCLDeviceLibSPVMeta, 16> &ReqDeviceLibs) {
764784
// Device libraries will be enabled only for spir-v module.
765785
if (!Triple(M.getTargetTriple()).isSPIROrSPIRV())
766786
return;
@@ -776,12 +796,10 @@ void getRequiredSYCLDeviceLibs(
776796
continue;
777797

778798
DeviceLibUsed.insert(DeviceLibFuncIter->second);
779-
ReqDeviceLibs.push_back(DeviceLibFuncIter->second);
799+
for (size_t idx : DeviceLibSPVExtMap[DeviceLibFuncIter->second]) {
800+
ReqDeviceLibs.push_back(SPVMetaList[idx]);
801+
}
780802
}
781803
}
782804
}
783-
784-
const char *getDeviceLibFileName(DeviceLibExt RequiredDeviceLibExt) {
785-
return DeviceLibSPVExtMap[RequiredDeviceLibExt];
786-
}
787805
} // namespace llvm

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,6 +1734,12 @@ getDeviceLibPrograms(const ContextImplPtr Context,
17341734
!fp64Support)
17351735
continue;
17361736

1737+
DeviceLibByteArray.dropBytes(4);
1738+
DeviceLibIsNative IsNativeSPV = static_cast<DeviceLibIsNative>(
1739+
(static_cast<uint32_t>(DeviceLibByteArray[3]) << 24) |
1740+
(static_cast<uint32_t>(DeviceLibByteArray[2]) << 16) |
1741+
(static_cast<uint32_t>(DeviceLibByteArray[1]) << 8) |
1742+
DeviceLibByteArray[0]);
17371743
auto DeviceLibExtReqName = getDeviceLibExtensionStr(DeviceLibExtReq);
17381744
bool InhibitNativeImpl = false;
17391745
if (const char *Env = getenv("SYCL_DEVICELIB_INHIBIT_NATIVE")) {
@@ -1746,11 +1752,24 @@ getDeviceLibPrograms(const ContextImplPtr Context,
17461752
// 1. underlying device doesn't support the extension
17471753
// 2. user explicitly ask to inhibit usage of native support
17481754
if (!ExtReqAvailable || InhibitNativeImpl) {
1755+
if (IsNativeSPV == DeviceLibIsNative::Yes)
1756+
continue;
17491757
DeviceLibByteArray.dropBytes(4);
17501758
Programs.push_back(loadDeviceLibFallback(
17511759
Context, DeviceLibExtReq, Devices,
17521760
/*UseNativeLib=*/false, false, DeviceLibByteArray.begin(),
17531761
DeviceLibByteArray.size()));
1762+
} else {
1763+
// bfloat16 spv has native and generic version, if native support is
1764+
// available in underlying device, we should use native version and
1765+
// ignore generic version.
1766+
if (IsNativeSPV != DeviceLibIsNative::Yes)
1767+
continue;
1768+
DeviceLibByteArray.dropBytes(4);
1769+
Programs.push_back(loadDeviceLibFallback(
1770+
Context, DeviceLibExtReq, Devices,
1771+
/*UseNativeLib=*/true, false, DeviceLibByteArray.begin(),
1772+
DeviceLibByteArray.size()));
17541773
}
17551774
}
17561775
}

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ class device_impl;
7171
using DeviceImplPtr = std::shared_ptr<device_impl>;
7272
class queue_impl;
7373
class event_impl;
74-
// DeviceLibExt is shared between sycl runtime and sycl-post-link tool.
75-
// If any update is made here, need to sync with DeviceLibExt definition
76-
// in llvm/tools/sycl-post-link/sycl-post-link.cpp
74+
// DeviceLibExt and DeviceLibIsNative are shared between sycl runtime and
75+
// SYCL Post Link tool. If any update is made here, please sync with definition
76+
// in llvm/llvm/include/llvm/SYCLLowerIR/SYCLRequiredDeviceLibs.h
7777
enum class DeviceLibExt : std::uint32_t {
7878
cl_intel_devicelib_assert,
7979
cl_intel_devicelib_math,
@@ -87,6 +87,8 @@ enum class DeviceLibExt : std::uint32_t {
8787
cl_intel_devicelib_bfloat16,
8888
};
8989

90+
enum class DeviceLibIsNative : std::uint32_t {Yes, No, Ignore};
91+
9092
// Provides single loading and building OpenCL programs with unique contexts
9193
// that is necessary for no interoperability cases with lambda.
9294
class ProgramManager {

0 commit comments

Comments
 (0)