diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h index c2c08f8831307..d76180ce97e9e 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h @@ -232,6 +232,11 @@ namespace SpecConstantOpOperands { #include "SPIRVGenTables.inc" } // namespace SpecConstantOpOperands +namespace FPEncoding { +#define GET_FPEncoding_DECL +#include "SPIRVGenTables.inc" +} // namespace FPEncoding + struct ExtendedBuiltin { StringRef Name; InstructionSet::InstructionSet Set; diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index cfe24c84941a9..115766ce886c7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -203,6 +203,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, }); } +SPIRVType * +SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, + MachineIRBuilder &MIRBuilder, + SPIRV::FPEncoding::FPEncoding FPEncode) { + return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { + return MIRBuilder.buildInstr(SPIRV::OpTypeFloat) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(Width) + .addImm(FPEncode); + }); +} + SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) @@ -1041,8 +1053,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( return Width == 1 ? getOpTypeBool(MIRBuilder) : getOpTypeInt(Width, MIRBuilder, false); } - if (Ty->isFloatingPointTy()) - return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); + if (Ty->isFloatingPointTy()) { + if (Ty->isBFloatTy()) { + return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, + SPIRV::FPEncoding::BFloat16KHR); + } else { + return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); + } + } if (Ty->isVoidTy()) return getOpTypeVoid(MIRBuilder); if (Ty->isVectorTy()) { diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 7ef812828b7cc..a648defa0a888 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -438,6 +438,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder); + SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, + SPIRV::FPEncoding::FPEncoding FPEncode); + SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder); SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 8d10cd0ffb3dd..496dcba17c10d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">; def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">; def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness), "$type = OpTypeInt $width $signedness">; -def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width), +def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops), "$type = OpTypeFloat $width">; def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount), "$type = OpTypeVector $compType $compCount">; diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index d2824ee2d2caf..ed933f872d136 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory; def CooperativeMatrixOperandsOperand : OperandCategory; def SpecConstantOpOperandsOperand : OperandCategory; def MatrixMultiplyAccumulateOperandsOperand : OperandCategory; +def FPEncodingOperand : OperandCategory; //===----------------------------------------------------------------------===// // Definition of the Environments @@ -1996,3 +1997,28 @@ defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400, defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>; defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>; defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>; + +//===----------------------------------------------------------------------===// +// Multiclass used to define FPEncoding enum values and at the +// same time SymbolicOperand entries with extensions. +//===----------------------------------------------------------------------===// +def FPEncoding : GenericEnum, Operand { + let FilterClass = "FPEncoding"; + let NameField = "Name"; + let ValueField = "Value"; + let PrintMethod = !strconcat("printSymbolicOperand"); +} + +class FPEncoding value> { + string Name = name; + bits<32> Value = value; +} + +multiclass FPEncodingOperand value, list reqExtensions>{ + def NAME : FPEncoding; + defm : SymbolicOperandWithRequirements< + FPEncodingOperand, value, NAME, 0, 0, + reqExtensions, [], []>; +} + +defm BFloat16KHR : FPEncodingOperand<0, []>; diff --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll index dfee1ace2205d..486f6358ce5de 100644 --- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll +++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll @@ -1,6 +1,6 @@ ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s -; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} +; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} define void @main() { entry: @@ -8,7 +8,8 @@ entry: ; CHECK-DAG: OpCapability Float16 ; CHECK-DAG: OpCapability Float64 -; CHECK-DAG: %[[#half:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#half:]] = OpTypeFloat 16{{$}} +; CHECK-DAG: %[[#bfloat:]] = OpTypeFloat 16 0{{$}} ; CHECK-DAG: %[[#float:]] = OpTypeFloat 32 ; CHECK-DAG: %[[#double:]] = OpTypeFloat 64 @@ -16,6 +17,10 @@ entry: ; CHECK-DAG: %[[#v3half:]] = OpTypeVector %[[#half]] 3 ; CHECK-DAG: %[[#v4half:]] = OpTypeVector %[[#half]] 4 +; CHECK-DAG: %[[#v2bfloat:]] = OpTypeVector %[[#bfloat]] 2 +; CHECK-DAG: %[[#v3bfloat:]] = OpTypeVector %[[#bfloat]] 3 +; CHECK-DAG: %[[#v4bfloat:]] = OpTypeVector %[[#bfloat]] 4 + ; CHECK-DAG: %[[#v2float:]] = OpTypeVector %[[#float]] 2 ; CHECK-DAG: %[[#v3float:]] = OpTypeVector %[[#float]] 3 ; CHECK-DAG: %[[#v4float:]] = OpTypeVector %[[#float]] 4 @@ -25,11 +30,15 @@ entry: ; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4 ; CHECK-DAG: %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]] +; CHECK-DAG: %[[#ptr_Function_bfloat:]] = OpTypePointer Function %[[#bfloat]] ; CHECK-DAG: %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]] ; CHECK-DAG: %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]] ; CHECK-DAG: %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]] ; CHECK-DAG: %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]] ; CHECK-DAG: %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]] +; CHECK-DAG: %[[#ptr_Function_v2bfloat:]] = OpTypePointer Function %[[#v2bfloat]] +; CHECK-DAG: %[[#ptr_Function_v3bfloat:]] = OpTypePointer Function %[[#v3bfloat]] +; CHECK-DAG: %[[#ptr_Function_v4bfloat:]] = OpTypePointer Function %[[#v4bfloat]] ; CHECK-DAG: %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]] ; CHECK-DAG: %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]] ; CHECK-DAG: %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]] @@ -40,6 +49,9 @@ entry: ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function %half_Val = alloca half, align 2 +; CHECK: %[[#]] = OpVariable %[[#ptr_Function_bfloat]] Function + %bfloat_Val = alloca bfloat, align 2 + ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function %float_Val = alloca float, align 4 @@ -55,6 +67,15 @@ entry: ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function %half4_Val = alloca <4 x half>, align 8 +; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2bfloat]] Function + %bfloat2_Val = alloca <2 x bfloat>, align 4 + +; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v3bfloat]] Function + %bfloat3_Val = alloca <3 x bfloat>, align 8 + +; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4bfloat]] Function + %bfloat4_Val = alloca <4 x bfloat>, align 8 + ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function %float2_Val = alloca <2 x float>, align 8