Skip to content

Commit 8ad95c0

Browse files
add support for the SPIR-V extension SPV_KHR_bfloat16
1 parent 29b6433 commit 8ad95c0

File tree

12 files changed

+185
-11
lines changed

12 files changed

+185
-11
lines changed

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2765,7 +2765,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
27652765
}
27662766

27672767
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
2768-
if (containsBF16Type(U))
2768+
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
27692769
return false;
27702770

27712771
const CallInst &CI = cast<CallInst>(U);

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,11 @@ namespace SpecConstantOpOperands {
232232
#include "SPIRVGenTables.inc"
233233
} // namespace SpecConstantOpOperands
234234

235+
namespace FPEncoding {
236+
#define GET_FPEncoding_DECL
237+
#include "SPIRVGenTables.inc"
238+
} // namespace FPEncoding
239+
235240
struct ExtendedBuiltin {
236241
StringRef Name;
237242
InstructionSet::InstructionSet Set;

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
147147
{"SPV_KHR_float_controls2",
148148
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
149149
{"SPV_INTEL_tensor_float32_conversion",
150-
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
150+
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
151+
{"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
151152

152153
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
153154
StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
203203
});
204204
}
205205

206+
SPIRVType *
207+
SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
208+
MachineIRBuilder &MIRBuilder,
209+
SPIRV::FPEncoding::FPEncoding FPEncode) {
210+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
211+
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
212+
.addDef(createTypeVReg(MIRBuilder))
213+
.addImm(Width)
214+
.addImm(FPEncode);
215+
});
216+
}
217+
206218
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
207219
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
208220
return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
@@ -1041,8 +1053,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
10411053
return Width == 1 ? getOpTypeBool(MIRBuilder)
10421054
: getOpTypeInt(Width, MIRBuilder, false);
10431055
}
1044-
if (Ty->isFloatingPointTy())
1045-
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
1056+
if (Ty->isFloatingPointTy()) {
1057+
if (Ty->isBFloatTy()) {
1058+
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder,
1059+
SPIRV::FPEncoding::BFloat16KHR);
1060+
} else {
1061+
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
1062+
}
1063+
}
10461064
if (Ty->isVoidTy())
10471065
return getOpTypeVoid(MIRBuilder);
10481066
if (Ty->isVectorTy()) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
438438

439439
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
440440

441+
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
442+
SPIRV::FPEncoding::FPEncoding FPEncode);
443+
441444
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
442445

443446
SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
167167
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
168168
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
169169
"$type = OpTypeInt $width $signedness">;
170-
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
170+
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
171171
"$type = OpTypeFloat $width">;
172172
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
173173
"$type = OpTypeVector $compType $compCount">;

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
12221222
}
12231223
}
12241224

1225+
static bool isBFloat16Type(const SPIRVType *TypeDef) {
1226+
return TypeDef && TypeDef->getNumOperands() == 3 &&
1227+
TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1228+
TypeDef->getOperand(1).getImm() == 16 &&
1229+
TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1230+
}
1231+
12251232
void addInstrRequirements(const MachineInstr &MI,
12261233
SPIRV::RequirementHandler &Reqs,
12271234
const SPIRVSubtarget &ST) {
@@ -1261,12 +1268,29 @@ void addInstrRequirements(const MachineInstr &MI,
12611268
Reqs.addCapability(SPIRV::Capability::Int8);
12621269
break;
12631270
}
1271+
case SPIRV::OpDot: {
1272+
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1273+
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1274+
if (isBFloat16Type(TypeDef))
1275+
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
1276+
break;
1277+
}
12641278
case SPIRV::OpTypeFloat: {
12651279
unsigned BitWidth = MI.getOperand(1).getImm();
12661280
if (BitWidth == 64)
12671281
Reqs.addCapability(SPIRV::Capability::Float64);
1268-
else if (BitWidth == 16)
1269-
Reqs.addCapability(SPIRV::Capability::Float16);
1282+
else if (BitWidth == 16) {
1283+
if (isBFloat16Type(&MI)) {
1284+
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
1285+
report_fatal_error("OpTypeFloat type with bfloat requires the "
1286+
"following SPIR-V extension: SPV_KHR_bfloat16",
1287+
false);
1288+
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
1289+
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
1290+
} else {
1291+
Reqs.addCapability(SPIRV::Capability::Float16);
1292+
}
1293+
}
12701294
break;
12711295
}
12721296
case SPIRV::OpTypeVector: {
@@ -1286,8 +1310,9 @@ void addInstrRequirements(const MachineInstr &MI,
12861310
assert(MI.getOperand(2).isReg());
12871311
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
12881312
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1289-
if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1290-
TypeDef->getOperand(1).getImm() == 16)
1313+
if ((TypeDef->getNumOperands() == 2) &&
1314+
(TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1315+
(TypeDef->getOperand(1).getImm() == 16))
12911316
Reqs.addCapability(SPIRV::Capability::Float16Buffer);
12921317
break;
12931318
}
@@ -1593,15 +1618,20 @@ void addInstrRequirements(const MachineInstr &MI,
15931618
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
15941619
}
15951620
break;
1596-
case SPIRV::OpTypeCooperativeMatrixKHR:
1621+
case SPIRV::OpTypeCooperativeMatrixKHR: {
15971622
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
15981623
report_fatal_error(
15991624
"OpTypeCooperativeMatrixKHR type requires the "
16001625
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
16011626
false);
16021627
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
16031628
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1629+
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1630+
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1631+
if (isBFloat16Type(TypeDef))
1632+
Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
16041633
break;
1634+
}
16051635
case SPIRV::OpArithmeticFenceEXT:
16061636
if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
16071637
report_fatal_error("OpArithmeticFenceEXT requires the "

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory;
210210
def CooperativeMatrixOperandsOperand : OperandCategory;
211211
def SpecConstantOpOperandsOperand : OperandCategory;
212212
def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
213+
def FPEncodingOperand : OperandCategory;
213214

214215
//===----------------------------------------------------------------------===//
215216
// Definition of the Environments
@@ -382,6 +383,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
382383
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
383384
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
384385
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
386+
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
385387

386388
//===----------------------------------------------------------------------===//
387389
// Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +596,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
594596
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
595597
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
596598
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
599+
defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
600+
defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
601+
defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
597602

598603
//===----------------------------------------------------------------------===//
599604
// Multiclass used to define SourceLanguage enum values and at the same time
@@ -1996,3 +2001,28 @@ defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400,
19962001
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
19972002
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
19982003
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
2004+
2005+
//===----------------------------------------------------------------------===//
2006+
// Multiclass used to define FPEncoding enum values and at the
2007+
// same time SymbolicOperand entries with extensions.
2008+
//===----------------------------------------------------------------------===//
2009+
def FPEncoding : GenericEnum, Operand<i32> {
2010+
let FilterClass = "FPEncoding";
2011+
let NameField = "Name";
2012+
let ValueField = "Value";
2013+
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
2014+
}
2015+
2016+
class FPEncoding<string name, bits<32> value> {
2017+
string Name = name;
2018+
bits<32> Value = value;
2019+
}
2020+
2021+
multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
2022+
def NAME : FPEncoding<NAME, value>;
2023+
defm : SymbolicOperandWithRequirements<
2024+
FPEncodingOperand, value, NAME, 0, 0,
2025+
reqExtensions, [], []>;
2026+
}
2027+
2028+
defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>;

llvm/test/CodeGen/SPIRV/basic_float_types.ll

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,21 @@ entry:
77

88
; CHECK-DAG: OpCapability Float16
99
; CHECK-DAG: OpCapability Float64
10+
; CHECK-DAG: OpCapability BFloat16TypeKHR
1011

11-
; CHECK-DAG: %[[#half:]] = OpTypeFloat 16
12+
; CHECK-DAG: %[[#half:]] = OpTypeFloat 16{{$}}
13+
; CHECK-DAG: %[[#bfloat:]] = OpTypeFloat 16 0{{$}}
1214
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
1315
; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
1416

1517
; CHECK-DAG: %[[#v2half:]] = OpTypeVector %[[#half]] 2
1618
; CHECK-DAG: %[[#v3half:]] = OpTypeVector %[[#half]] 3
1719
; CHECK-DAG: %[[#v4half:]] = OpTypeVector %[[#half]] 4
1820

21+
; CHECK-DAG: %[[#v2bfloat:]] = OpTypeVector %[[#bfloat]] 2
22+
; CHECK-DAG: %[[#v3bfloat:]] = OpTypeVector %[[#bfloat]] 3
23+
; CHECK-DAG: %[[#v4bfloat:]] = OpTypeVector %[[#bfloat]] 4
24+
1925
; CHECK-DAG: %[[#v2float:]] = OpTypeVector %[[#float]] 2
2026
; CHECK-DAG: %[[#v3float:]] = OpTypeVector %[[#float]] 3
2127
; CHECK-DAG: %[[#v4float:]] = OpTypeVector %[[#float]] 4
@@ -25,11 +31,15 @@ entry:
2531
; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4
2632

2733
; CHECK-DAG: %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]]
34+
; CHECK-DAG: %[[#ptr_Function_bfloat:]] = OpTypePointer Function %[[#bfloat]]
2835
; CHECK-DAG: %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]]
2936
; CHECK-DAG: %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]]
3037
; CHECK-DAG: %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]]
3138
; CHECK-DAG: %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]]
3239
; CHECK-DAG: %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]]
40+
; CHECK-DAG: %[[#ptr_Function_v2bfloat:]] = OpTypePointer Function %[[#v2bfloat]]
41+
; CHECK-DAG: %[[#ptr_Function_v3bfloat:]] = OpTypePointer Function %[[#v3bfloat]]
42+
; CHECK-DAG: %[[#ptr_Function_v4bfloat:]] = OpTypePointer Function %[[#v4bfloat]]
3343
; CHECK-DAG: %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]]
3444
; CHECK-DAG: %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]]
3545
; CHECK-DAG: %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
@@ -40,6 +50,9 @@ entry:
4050
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function
4151
%half_Val = alloca half, align 2
4252

53+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_bfloat]] Function
54+
%bfloat_Val = alloca bfloat, align 2
55+
4356
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function
4457
%float_Val = alloca float, align 4
4558

@@ -55,6 +68,15 @@ entry:
5568
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function
5669
%half4_Val = alloca <4 x half>, align 8
5770

71+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2bfloat]] Function
72+
%bfloat2_Val = alloca <2 x bfloat>, align 4
73+
74+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v3bfloat]] Function
75+
%bfloat3_Val = alloca <3 x bfloat>, align 8
76+
77+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4bfloat]] Function
78+
%bfloat4_Val = alloca <4 x bfloat>, align 8
79+
5880
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function
5981
%float2_Val = alloca <2 x float>, align 8
6082

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
3+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
4+
5+
; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16
6+
7+
; CHECK-DAG: OpCapability BFloat16TypeKHR
8+
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
9+
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
10+
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
11+
12+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
13+
target triple = "spir64-unknown-unknown"
14+
15+
define spir_kernel void @test() {
16+
entry:
17+
%addr1 = alloca bfloat
18+
%addr2 = alloca <2 x bfloat>
19+
%data1 = load bfloat, ptr %addr1
20+
%data2 = load <2 x bfloat>, ptr %addr2
21+
ret void
22+
}

0 commit comments

Comments
 (0)