Skip to content

Commit 3a5cc95

Browse files
[SPIRV] Add FPEncoding operand support for OpTypeFloat (#156871)
This PR introduces support for `FPEncoding` operand for SPIR-V instruction `OpTypeFloat`, with the following main changes: 1. Introduces `FPEncoding` enum class to represent floating-point encodings, such as `BFloat16KHR`, in SPIR-V. 2. Updates SPIR-V instruction `OpTypeFloat` to accept `FPEncoding` as its second input operand. 3. Updates SPIR-V type creation to handle new encoding requirements. This PR enables support for the BFloat floating-point type in SPIR-V.
1 parent 46fcece commit 3a5cc95

File tree

6 files changed

+78
-5
lines changed

6 files changed

+78
-5
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,11 @@ namespace SpecConstantOpOperands {
232232
#include "SPIRVGenTables.inc"
233233
} // namespace SpecConstantOpOperands
234234

235+
namespace FPEncoding {
236+
#define GET_FPEncoding_DECL
237+
#include "SPIRVGenTables.inc"
238+
} // namespace FPEncoding
239+
235240
struct ExtendedBuiltin {
236241
StringRef Name;
237242
InstructionSet::InstructionSet Set;

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
203203
});
204204
}
205205

206+
SPIRVType *
207+
SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
208+
MachineIRBuilder &MIRBuilder,
209+
SPIRV::FPEncoding::FPEncoding FPEncode) {
210+
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
211+
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
212+
.addDef(createTypeVReg(MIRBuilder))
213+
.addImm(Width)
214+
.addImm(FPEncode);
215+
});
216+
}
217+
206218
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
207219
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
208220
return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
@@ -1041,8 +1053,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
10411053
return Width == 1 ? getOpTypeBool(MIRBuilder)
10421054
: getOpTypeInt(Width, MIRBuilder, false);
10431055
}
1044-
if (Ty->isFloatingPointTy())
1045-
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
1056+
if (Ty->isFloatingPointTy()) {
1057+
if (Ty->isBFloatTy()) {
1058+
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder,
1059+
SPIRV::FPEncoding::BFloat16KHR);
1060+
} else {
1061+
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
1062+
}
1063+
}
10461064
if (Ty->isVoidTy())
10471065
return getOpTypeVoid(MIRBuilder);
10481066
if (Ty->isVectorTy()) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
438438

439439
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
440440

441+
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
442+
SPIRV::FPEncoding::FPEncoding FPEncode);
443+
441444
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
442445

443446
SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
167167
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
168168
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
169169
"$type = OpTypeInt $width $signedness">;
170-
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
170+
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
171171
"$type = OpTypeFloat $width">;
172172
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
173173
"$type = OpTypeVector $compType $compCount">;

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory;
210210
def CooperativeMatrixOperandsOperand : OperandCategory;
211211
def SpecConstantOpOperandsOperand : OperandCategory;
212212
def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
213+
def FPEncodingOperand : OperandCategory;
213214

214215
//===----------------------------------------------------------------------===//
215216
// Definition of the Environments
@@ -1996,3 +1997,28 @@ defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400,
19961997
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
19971998
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
19981999
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
2000+
2001+
//===----------------------------------------------------------------------===//
2002+
// Multiclass used to define FPEncoding enum values and at the
2003+
// same time SymbolicOperand entries with extensions.
2004+
//===----------------------------------------------------------------------===//
2005+
def FPEncoding : GenericEnum, Operand<i32> {
2006+
let FilterClass = "FPEncoding";
2007+
let NameField = "Name";
2008+
let ValueField = "Value";
2009+
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
2010+
}
2011+
2012+
class FPEncoding<string name, bits<32> value> {
2013+
string Name = name;
2014+
bits<32> Value = value;
2015+
}
2016+
2017+
multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
2018+
def NAME : FPEncoding<NAME, value>;
2019+
defm : SymbolicOperandWithRequirements<
2020+
FPEncodingOperand, value, NAME, 0, 0,
2021+
reqExtensions, [], []>;
2022+
}
2023+
2024+
defm BFloat16KHR : FPEncodingOperand<0, []>;

llvm/test/CodeGen/SPIRV/basic_float_types.ll

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
22
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
3-
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
44

55
define void @main() {
66
entry:
77

88
; CHECK-DAG: OpCapability Float16
99
; CHECK-DAG: OpCapability Float64
1010

11-
; CHECK-DAG: %[[#half:]] = OpTypeFloat 16
11+
; CHECK-DAG: %[[#half:]] = OpTypeFloat 16{{$}}
12+
; CHECK-DAG: %[[#bfloat:]] = OpTypeFloat 16 0{{$}}
1213
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
1314
; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
1415

1516
; CHECK-DAG: %[[#v2half:]] = OpTypeVector %[[#half]] 2
1617
; CHECK-DAG: %[[#v3half:]] = OpTypeVector %[[#half]] 3
1718
; CHECK-DAG: %[[#v4half:]] = OpTypeVector %[[#half]] 4
1819

20+
; CHECK-DAG: %[[#v2bfloat:]] = OpTypeVector %[[#bfloat]] 2
21+
; CHECK-DAG: %[[#v3bfloat:]] = OpTypeVector %[[#bfloat]] 3
22+
; CHECK-DAG: %[[#v4bfloat:]] = OpTypeVector %[[#bfloat]] 4
23+
1924
; CHECK-DAG: %[[#v2float:]] = OpTypeVector %[[#float]] 2
2025
; CHECK-DAG: %[[#v3float:]] = OpTypeVector %[[#float]] 3
2126
; CHECK-DAG: %[[#v4float:]] = OpTypeVector %[[#float]] 4
@@ -25,11 +30,15 @@ entry:
2530
; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4
2631

2732
; CHECK-DAG: %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]]
33+
; CHECK-DAG: %[[#ptr_Function_bfloat:]] = OpTypePointer Function %[[#bfloat]]
2834
; CHECK-DAG: %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]]
2935
; CHECK-DAG: %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]]
3036
; CHECK-DAG: %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]]
3137
; CHECK-DAG: %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]]
3238
; CHECK-DAG: %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]]
39+
; CHECK-DAG: %[[#ptr_Function_v2bfloat:]] = OpTypePointer Function %[[#v2bfloat]]
40+
; CHECK-DAG: %[[#ptr_Function_v3bfloat:]] = OpTypePointer Function %[[#v3bfloat]]
41+
; CHECK-DAG: %[[#ptr_Function_v4bfloat:]] = OpTypePointer Function %[[#v4bfloat]]
3342
; CHECK-DAG: %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]]
3443
; CHECK-DAG: %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]]
3544
; CHECK-DAG: %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
@@ -40,6 +49,9 @@ entry:
4049
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function
4150
%half_Val = alloca half, align 2
4251

52+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_bfloat]] Function
53+
%bfloat_Val = alloca bfloat, align 2
54+
4355
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function
4456
%float_Val = alloca float, align 4
4557

@@ -55,6 +67,15 @@ entry:
5567
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function
5668
%half4_Val = alloca <4 x half>, align 8
5769

70+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2bfloat]] Function
71+
%bfloat2_Val = alloca <2 x bfloat>, align 4
72+
73+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v3bfloat]] Function
74+
%bfloat3_Val = alloca <3 x bfloat>, align 8
75+
76+
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4bfloat]] Function
77+
%bfloat4_Val = alloca <4 x bfloat>, align 8
78+
5879
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function
5980
%float2_Val = alloca <2 x float>, align 8
6081

0 commit comments

Comments
 (0)