Skip to content

Commit 4febb5f

Browse files
author
Dmitry Sidorov
committed
[Bacport to 21] Implement SPV_INTEL_bfloat16_arithmetic (#3290)
The extension relaxes rules for bf16 type allowing to use it in some arithmetic operations. Spec is available here: intel/llvm#18352 Co-authered by: Michael Aziz <michael.aziz@intel.com> --------- Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
1 parent b50fc5c commit 4febb5f

File tree

9 files changed

+314
-0
lines changed

9 files changed

+314
-0
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ EXT(SPV_INTEL_bindless_images)
7777
EXT(SPV_INTEL_2d_block_io)
7878
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
7979
EXT(SPV_KHR_bfloat16)
80+
EXT(SPV_INTEL_bfloat16_arithmetic)
8081
EXT(SPV_INTEL_ternary_bitwise_function)
8182
EXT(SPV_INTEL_int4)
8283
EXT(SPV_INTEL_function_variants)

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,11 @@ ParamType lastFuncParamType(StringRef MangledName) {
525525
char Mangled = Copy.back();
526526
std::string Mangled2 = Copy.substr(Copy.size() - 2);
527527

528+
std::string Mangled6 = Copy.substr(Copy.size() - 6);
529+
if (Mangled6 == "__bf16") {
530+
return ParamType::FLOAT;
531+
}
532+
528533
if (isMangledTypeFP(Mangled) || isMangledTypeHalf(Mangled2)) {
529534
return ParamType::FLOAT;
530535
} else if (isMangledTypeUnsigned(Mangled)) {
@@ -1913,6 +1918,9 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
19131918
NumElems = VecTy->getNumElements();
19141919
Ty = VecTy->getElementType();
19151920
}
1921+
if (Ty->isBFloatTy() &&
1922+
BM->hasCapability(internal::CapabilityBFloat16ArithmeticINTEL))
1923+
return true;
19161924
if ((!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) ||
19171925
(!BM->hasCapability(CapabilityVectorAnyINTEL) &&
19181926
((NumElems > 4) && (NumElems != 8) && (NumElems != 16)))) {
@@ -1929,6 +1937,9 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
19291937
NumElems = VecTy->getNumElements();
19301938
Ty = VecTy->getElementType();
19311939
}
1940+
if (Ty->isBFloatTy() &&
1941+
BM->hasCapability(internal::CapabilityBFloat16ArithmeticINTEL))
1942+
return true;
19321943
if ((!Ty->isIntegerTy()) ||
19331944
(!BM->hasCapability(CapabilityVectorAnyINTEL) &&
19341945
((NumElems > 4) && (NumElems != 8) && (NumElems != 16)))) {

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4201,6 +4201,20 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
42014201
// -spirv-allow-unknown-intrinsics work correctly.
42024202
auto IID = II->getIntrinsicID();
42034203
switch (IID) {
4204+
case Intrinsic::fabs:
4205+
case Intrinsic::fma:
4206+
case Intrinsic::maxnum:
4207+
case Intrinsic::minnum:
4208+
case Intrinsic::fmuladd: {
4209+
Type *Ty = II->getType();
4210+
if (Ty->isBFloatTy())
4211+
BM->addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
4212+
break;
4213+
}
4214+
default:
4215+
break;
4216+
}
4217+
switch (IID) {
42044218
case Intrinsic::assume: {
42054219
// llvm.assume translation is currently supported only within
42064220
// SPV_KHR_expect_assume extension, ignore it otherwise, since it's
@@ -5485,6 +5499,11 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
54855499
SmallVector<std::string, 2> Dec;
54865500
if (isBuiltinTransToExtInst(CI->getCalledFunction(), &ExtSetKind, &ExtOp,
54875501
&Dec)) {
5502+
if (const auto *FirstArg = F->getArg(0)) {
5503+
const auto *Type = FirstArg->getType();
5504+
if (Type->isBFloatTy())
5505+
BM->addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
5506+
}
54885507
if (DemangledName.find("__spirv_ocl_printf") != StringRef::npos) {
54895508
auto *FormatStrPtr = cast<PointerType>(CI->getArgOperand(0)->getType());
54905509
if (FormatStrPtr->getAddressSpace() !=

lib/SPIRV/libSPIRV/SPIRVEntry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,8 @@ class SPIRVCapability : public SPIRVEntryNoId<OpCapability> {
916916
case CapabilityFunctionVariantsINTEL:
917917
case CapabilitySpecConditionalINTEL:
918918
return ExtensionID::SPV_INTEL_function_variants;
919+
case internal::CapabilityBFloat16ArithmeticINTEL:
920+
return ExtensionID::SPV_INTEL_bfloat16_arithmetic;
919921
default:
920922
return {};
921923
}

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
228228
{CapabilityBFloat16TypeKHR, CapabilityCooperativeMatrixKHR});
229229
ADD_VEC_INIT(CapabilityInt4CooperativeMatrixINTEL,
230230
{CapabilityInt4TypeINTEL, CapabilityCooperativeMatrixKHR});
231+
ADD_VEC_INIT(internal::CapabilityBFloat16ArithmeticINTEL,
232+
{CapabilityBFloat16TypeKHR});
231233
}
232234

233235
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,8 @@ SPIRVInstruction *SPIRVModuleImpl::addBinaryInst(Op TheOpCode, SPIRVType *Type,
17321732
SPIRVValue *Op1,
17331733
SPIRVValue *Op2,
17341734
SPIRVBasicBlock *BB) {
1735+
if (Type->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot)
1736+
addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
17351737
return addInstruction(SPIRVInstTemplateBase::create(
17361738
TheOpCode, Type, getId(),
17371739
getVec(Op1->getId(), Op2->getId()), BB, this),
@@ -1755,6 +1757,8 @@ SPIRVInstruction *SPIRVModuleImpl::addUnaryInst(Op TheOpCode,
17551757
SPIRVType *TheType,
17561758
SPIRVValue *Op,
17571759
SPIRVBasicBlock *BB) {
1760+
if (TheType->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot)
1761+
addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
17581762
return addInstruction(
17591763
SPIRVInstTemplateBase::create(TheOpCode, TheType, getId(),
17601764
getVec(Op->getId()), BB, this),

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
690690
add(CapabilityInt4CooperativeMatrixINTEL, "Int4CooperativeMatrixINTEL");
691691
add(CapabilityFunctionVariantsINTEL, "FunctionVariantsINTEL");
692692
add(CapabilitySpecConditionalINTEL, "SpecConditionalINTEL");
693+
add(internal::CapabilityBFloat16ArithmeticINTEL, "BFloat16ArithmeticINTEL");
693694
}
694695
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
695696

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ enum InternalCapability {
106106
ICapGlobalVariableDecorationsINTEL = 6146,
107107
ICapabilityTaskSequenceINTEL = 6162,
108108
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
109+
ICapabilityBFloat16ArithmeticINTEL = 6226,
109110
ICapabilityCooperativeMatrixOffsetInstructionsINTEL = 6238,
110111
ICapabilityAtomicBFloat16AddINTEL = 6255,
111112
ICapabilityAtomicBFloat16MinMaxINTEL = 6256,
@@ -268,6 +269,8 @@ constexpr Capability CapabilityTokenTypeINTEL =
268269
static_cast<Capability>(ICapTokenTypeINTEL);
269270
constexpr Capability CapabilityGlobalVariableDecorationsINTEL =
270271
static_cast<Capability>(ICapGlobalVariableDecorationsINTEL);
272+
constexpr Capability CapabilityBFloat16ArithmeticINTEL =
273+
static_cast<Capability>(ICapabilityBFloat16ArithmeticINTEL);
271274

272275
constexpr ExecutionMode ExecutionModeNamedSubgroupSizeINTEL =
273276
static_cast<ExecutionMode>(IExecModeNamedSubgroupSizeINTEL);

0 commit comments

Comments
 (0)