Skip to content

Commit 2286118

Browse files
authored
[SPIRV] Enable bfloat16 arithmetic (#166031)
Enable the `SPV_INTEL_bfloat16_arithmetic` extension, which allows arithmetic, relational and `OpExtInst` instructions to take `bfloat16` arguments. This patch only adds support for arithmetic and relational ops; handling of `OpExtInst` (extended instructions) is left as a TODO. The extension itself is rather fresh, but `bfloat16` is ubiquitous at this point and not supporting these ops is limiting.
1 parent bcb3d2f commit 2286118

File tree

7 files changed

+595
-7
lines changed

7 files changed

+595
-7
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
173173
- Allows generating arbitrary width integer types.
174174
* - ``SPV_INTEL_bindless_images``
175175
- Adds instructions to convert unsigned integer handles to images, samplers and sampled images.
176+
* - ``SPV_INTEL_bfloat16_arithmetic``
177+
- Allows the use of 16-bit bfloat16 values in arithmetic and relational operators.
176178
* - ``SPV_INTEL_bfloat16_conversion``
177179
- Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
178180
* - ``SPV_INTEL_cache_controls``

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ void IRTranslator::addMachineCFGPred(CFGEdge Edge, MachineBasicBlock *NewPred) {
294294
MachinePreds[Edge].push_back(NewPred);
295295
}
296296

297+
static bool targetSupportsBF16Type(const MachineFunction *MF) {
298+
return MF->getTarget().getTargetTriple().isSPIRV();
299+
}
300+
297301
static bool containsBF16Type(const User &U) {
298302
// BF16 cannot currently be represented by LLT, to avoid miscompiles we
299303
// prevent any instructions using them. FIXME: This can be removed once LLT
@@ -306,7 +310,7 @@ static bool containsBF16Type(const User &U) {
306310

307311
bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
308312
MachineIRBuilder &MIRBuilder) {
309-
if (containsBF16Type(U))
313+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
310314
return false;
311315

312316
// Get or create a virtual register for each value.
@@ -328,7 +332,7 @@ bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
328332

329333
bool IRTranslator::translateUnaryOp(unsigned Opcode, const User &U,
330334
MachineIRBuilder &MIRBuilder) {
331-
if (containsBF16Type(U))
335+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
332336
return false;
333337

334338
Register Op0 = getOrCreateVReg(*U.getOperand(0));
@@ -348,7 +352,7 @@ bool IRTranslator::translateFNeg(const User &U, MachineIRBuilder &MIRBuilder) {
348352

349353
bool IRTranslator::translateCompare(const User &U,
350354
MachineIRBuilder &MIRBuilder) {
351-
if (containsBF16Type(U))
355+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
352356
return false;
353357

354358
auto *CI = cast<CmpInst>(&U);
@@ -1569,7 +1573,7 @@ bool IRTranslator::translateBitCast(const User &U,
15691573

15701574
bool IRTranslator::translateCast(unsigned Opcode, const User &U,
15711575
MachineIRBuilder &MIRBuilder) {
1572-
if (containsBF16Type(U))
1576+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
15731577
return false;
15741578

15751579
uint32_t Flags = 0;
@@ -2688,7 +2692,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
26882692

26892693
bool IRTranslator::translateInlineAsm(const CallBase &CB,
26902694
MachineIRBuilder &MIRBuilder) {
2691-
if (containsBF16Type(CB))
2695+
if (containsBF16Type(CB) && !targetSupportsBF16Type(MF))
26922696
return false;
26932697

26942698
const InlineAsmLowering *ALI = MF->getSubtarget().getInlineAsmLowering();
@@ -2779,7 +2783,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
27792783
}
27802784

27812785
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
2782-
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
2786+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
27832787
return false;
27842788

27852789
const CallInst &CI = cast<CallInst>(U);

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
107107
SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
108108
{"SPV_INTEL_bindless_images",
109109
SPIRV::Extension::Extension::SPV_INTEL_bindless_images},
110+
{"SPV_INTEL_bfloat16_arithmetic",
111+
SPIRV::Extension::Extension::SPV_INTEL_bfloat16_arithmetic},
110112
{"SPV_INTEL_bfloat16_conversion",
111113
SPIRV::Extension::Extension::SPV_INTEL_bfloat16_conversion},
112114
{"SPV_KHR_subgroup_rotate",

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,8 @@ void addInstrRequirements(const MachineInstr &MI,
14351435
addPrintfRequirements(MI, Reqs, ST);
14361436
break;
14371437
}
1438+
// TODO: handle bfloat16 extended instructions when
1439+
// SPV_INTEL_bfloat16_arithmetic is enabled.
14381440
break;
14391441
}
14401442
case SPIRV::OpAliasDomainDeclINTEL:
@@ -2060,7 +2062,64 @@ void addInstrRequirements(const MachineInstr &MI,
20602062
Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
20612063
break;
20622064
}
2063-
2065+
case SPIRV::OpFAddS:
2066+
case SPIRV::OpFSubS:
2067+
case SPIRV::OpFMulS:
2068+
case SPIRV::OpFDivS:
2069+
case SPIRV::OpFRemS:
2070+
case SPIRV::OpFMod:
2071+
case SPIRV::OpFNegate:
2072+
case SPIRV::OpFAddV:
2073+
case SPIRV::OpFSubV:
2074+
case SPIRV::OpFMulV:
2075+
case SPIRV::OpFDivV:
2076+
case SPIRV::OpFRemV:
2077+
case SPIRV::OpFNegateV: {
2078+
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2079+
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
2080+
if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2081+
TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2082+
if (isBFloat16Type(TypeDef)) {
2083+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2084+
report_fatal_error(
2085+
"Arithmetic instructions with bfloat16 arguments require the "
2086+
"following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2087+
false);
2088+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2089+
Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2090+
}
2091+
break;
2092+
}
2093+
case SPIRV::OpOrdered:
2094+
case SPIRV::OpUnordered:
2095+
case SPIRV::OpFOrdEqual:
2096+
case SPIRV::OpFOrdNotEqual:
2097+
case SPIRV::OpFOrdLessThan:
2098+
case SPIRV::OpFOrdLessThanEqual:
2099+
case SPIRV::OpFOrdGreaterThan:
2100+
case SPIRV::OpFOrdGreaterThanEqual:
2101+
case SPIRV::OpFUnordEqual:
2102+
case SPIRV::OpFUnordNotEqual:
2103+
case SPIRV::OpFUnordLessThan:
2104+
case SPIRV::OpFUnordLessThanEqual:
2105+
case SPIRV::OpFUnordGreaterThan:
2106+
case SPIRV::OpFUnordGreaterThanEqual: {
2107+
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2108+
MachineInstr *OperandDef = MRI.getVRegDef(MI.getOperand(2).getReg());
2109+
SPIRVType *TypeDef = MRI.getVRegDef(OperandDef->getOperand(1).getReg());
2110+
if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2111+
TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2112+
if (isBFloat16Type(TypeDef)) {
2113+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2114+
report_fatal_error(
2115+
"Relational instructions with bfloat16 arguments require the "
2116+
"following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2117+
false);
2118+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2119+
Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2120+
}
2121+
break;
2122+
}
20642123
default:
20652124
break;
20662125
}

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
387387
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
388388
defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
389389
defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
390+
defm SPV_INTEL_bfloat16_arithmetic
391+
: ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
390392

391393
//===----------------------------------------------------------------------===//
392394
// Multiclass used to define Capabilities enum values and at the same time
@@ -570,6 +572,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
570572
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
571573
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
572574
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
575+
defm BFloat16ArithmeticINTEL : CapabilityOperand<6226, 0, 0, [SPV_INTEL_bfloat16_arithmetic], []>;
573576
defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
574577
defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
575578
defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
3+
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
4+
5+
; CHECK-ERROR: LLVM ERROR: Arithmetic instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
6+
7+
; CHECK-DAG: OpCapability BFloat16TypeKHR
8+
; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
9+
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
10+
; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
11+
; CHECK-DAG: OpName [[NEG:%.*]] "neg"
12+
; CHECK-DAG: OpName [[NEGV:%.*]] "negv"
13+
; CHECK-DAG: OpName [[ADD:%.*]] "add"
14+
; CHECK-DAG: OpName [[ADDV:%.*]] "addv"
15+
; CHECK-DAG: OpName [[SUB:%.*]] "sub"
16+
; CHECK-DAG: OpName [[SUBV:%.*]] "subv"
17+
; CHECK-DAG: OpName [[MUL:%.*]] "mul"
18+
; CHECK-DAG: OpName [[MULV:%.*]] "mulv"
19+
; CHECK-DAG: OpName [[DIV:%.*]] "div"
20+
; CHECK-DAG: OpName [[DIVV:%.*]] "divv"
21+
; CHECK-DAG: OpName [[REM:%.*]] "rem"
22+
; CHECK-DAG: OpName [[REMV:%.*]] "remv"
23+
; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
24+
; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 4
25+
26+
; CHECK-DAG: [[NEG]] = OpFunction [[BFLOAT]]
27+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
28+
; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOAT]] [[X]]
29+
define spir_func bfloat @neg(bfloat %x) {
30+
entry:
31+
%r = fneg bfloat %x
32+
ret bfloat %r
33+
}
34+
35+
; CHECK-DAG: [[NEGV]] = OpFunction [[BFLOATV]]
36+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
37+
; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOATV]] [[X]]
38+
define spir_func <4 x bfloat> @negv(<4 x bfloat> %x) {
39+
entry:
40+
%r = fneg <4 x bfloat> %x
41+
ret <4 x bfloat> %r
42+
}
43+
44+
; CHECK-DAG: [[ADD]] = OpFunction [[BFLOAT]]
45+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
46+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
47+
; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOAT]] [[X]] [[Y]]
48+
define spir_func bfloat @add(bfloat %x, bfloat %y) {
49+
entry:
50+
%r = fadd bfloat %x, %y
51+
ret bfloat %r
52+
}
53+
54+
; CHECK-DAG: [[ADDV]] = OpFunction [[BFLOATV]]
55+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
56+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
57+
; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOATV]] [[X]] [[Y]]
58+
define spir_func <4 x bfloat> @addv(<4 x bfloat> %x, <4 x bfloat> %y) {
59+
entry:
60+
%r = fadd <4 x bfloat> %x, %y
61+
ret <4 x bfloat> %r
62+
}
63+
64+
; CHECK-DAG: [[SUB]] = OpFunction [[BFLOAT]]
65+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
66+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
67+
; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOAT]] [[X]] [[Y]]
68+
define spir_func bfloat @sub(bfloat %x, bfloat %y) {
69+
entry:
70+
%r = fsub bfloat %x, %y
71+
ret bfloat %r
72+
}
73+
74+
; CHECK-DAG: [[SUBV]] = OpFunction [[BFLOATV]]
75+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
76+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
77+
; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOATV]] [[X]] [[Y]]
78+
define spir_func <4 x bfloat> @subv(<4 x bfloat> %x, <4 x bfloat> %y) {
79+
entry:
80+
%r = fsub <4 x bfloat> %x, %y
81+
ret <4 x bfloat> %r
82+
}
83+
84+
; CHECK-DAG: [[MUL]] = OpFunction [[BFLOAT]]
85+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
86+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
87+
; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOAT]] [[X]] [[Y]]
88+
define spir_func bfloat @mul(bfloat %x, bfloat %y) {
89+
entry:
90+
%r = fmul bfloat %x, %y
91+
ret bfloat %r
92+
}
93+
94+
; CHECK-DAG: [[MULV]] = OpFunction [[BFLOATV]]
95+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
96+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
97+
; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOATV]] [[X]] [[Y]]
98+
define spir_func <4 x bfloat> @mulv(<4 x bfloat> %x, <4 x bfloat> %y) {
99+
entry:
100+
%r = fmul <4 x bfloat> %x, %y
101+
ret <4 x bfloat> %r
102+
}
103+
104+
; CHECK-DAG: [[DIV]] = OpFunction [[BFLOAT]]
105+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
106+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
107+
; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOAT]] [[X]] [[Y]]
108+
define spir_func bfloat @div(bfloat %x, bfloat %y) {
109+
entry:
110+
%r = fdiv bfloat %x, %y
111+
ret bfloat %r
112+
}
113+
114+
; CHECK-DAG: [[DIVV]] = OpFunction [[BFLOATV]]
115+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
116+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
117+
; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOATV]] [[X]] [[Y]]
118+
define spir_func <4 x bfloat> @divv(<4 x bfloat> %x, <4 x bfloat> %y) {
119+
entry:
120+
%r = fdiv <4 x bfloat> %x, %y
121+
ret <4 x bfloat> %r
122+
}
123+
124+
; CHECK-DAG: [[REM]] = OpFunction [[BFLOAT]]
125+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
126+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
127+
; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOAT]] [[X]] [[Y]]
128+
define spir_func bfloat @rem(bfloat %x, bfloat %y) {
129+
entry:
130+
%r = frem bfloat %x, %y
131+
ret bfloat %r
132+
}
133+
134+
; CHECK-DAG: [[REMV]] = OpFunction [[BFLOATV]]
135+
; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
136+
; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
137+
; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOATV]] [[X]] [[Y]]
138+
define spir_func <4 x bfloat> @remv(<4 x bfloat> %x, <4 x bfloat> %y) {
139+
entry:
140+
%r = frem <4 x bfloat> %x, %y
141+
ret <4 x bfloat> %r
142+
}

0 commit comments

Comments
 (0)