Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ namespace SpecConstantOpOperands {
#include "SPIRVGenTables.inc"
} // namespace SpecConstantOpOperands

namespace FPEncoding {
#define GET_FPEncoding_DECL
#include "SPIRVGenTables.inc"
} // namespace FPEncoding

struct ExtendedBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
Expand Down
22 changes: 20 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
});
}

SPIRVType *
SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
MachineIRBuilder &MIRBuilder,
SPIRV::FPEncoding::FPEncoding FPEncode) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
.addImm(FPEncode);
});
}

SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
Expand Down Expand Up @@ -1041,8 +1053,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
if (Ty->isFloatingPointTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
if (Ty->isFloatingPointTy()) {
if (Ty->isBFloatTy()) {
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder,
SPIRV::FPEncoding::BFloat16KHR);
} else {
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
}
}
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {

SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);

SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
SPIRV::FPEncoding::FPEncoding FPEncode);

SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);

SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
"$type = OpTypeFloat $width">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
Expand Down
26 changes: 26 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory;
def CooperativeMatrixOperandsOperand : OperandCategory;
def SpecConstantOpOperandsOperand : OperandCategory;
def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
def FPEncodingOperand : OperandCategory;

//===----------------------------------------------------------------------===//
// Definition of the Environments
Expand Down Expand Up @@ -1996,3 +1997,28 @@ defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400,
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define FPEncoding enum values and at the
// same time SymbolicOperand entries with extensions.
//===----------------------------------------------------------------------===//
def FPEncoding : GenericEnum, Operand<i32> {
let FilterClass = "FPEncoding";
let NameField = "Name";
let ValueField = "Value";
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
}

class FPEncoding<string name, bits<32> value> {
string Name = name;
bits<32> Value = value;
}

multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
def NAME : FPEncoding<NAME, value>;
defm : SymbolicOperandWithRequirements<
FPEncodingOperand, value, NAME, 0, 0,
reqExtensions, [], []>;
}

defm BFloat16KHR : FPEncodingOperand<0, []>;
25 changes: 23 additions & 2 deletions llvm/test/CodeGen/SPIRV/basic_float_types.ll
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}

define void @main() {
entry:

; CHECK-DAG: OpCapability Float16
; CHECK-DAG: OpCapability Float64

; CHECK-DAG: %[[#half:]] = OpTypeFloat 16
; CHECK-DAG: %[[#half:]] = OpTypeFloat 16{{$}}
; CHECK-DAG: %[[#bfloat:]] = OpTypeFloat 16 0{{$}}
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
; CHECK-DAG: %[[#double:]] = OpTypeFloat 64

; CHECK-DAG: %[[#v2half:]] = OpTypeVector %[[#half]] 2
; CHECK-DAG: %[[#v3half:]] = OpTypeVector %[[#half]] 3
; CHECK-DAG: %[[#v4half:]] = OpTypeVector %[[#half]] 4

; CHECK-DAG: %[[#v2bfloat:]] = OpTypeVector %[[#bfloat]] 2
; CHECK-DAG: %[[#v3bfloat:]] = OpTypeVector %[[#bfloat]] 3
; CHECK-DAG: %[[#v4bfloat:]] = OpTypeVector %[[#bfloat]] 4

; CHECK-DAG: %[[#v2float:]] = OpTypeVector %[[#float]] 2
; CHECK-DAG: %[[#v3float:]] = OpTypeVector %[[#float]] 3
; CHECK-DAG: %[[#v4float:]] = OpTypeVector %[[#float]] 4
Expand All @@ -25,11 +30,15 @@ entry:
; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4

; CHECK-DAG: %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]]
; CHECK-DAG: %[[#ptr_Function_bfloat:]] = OpTypePointer Function %[[#bfloat]]
; CHECK-DAG: %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]]
; CHECK-DAG: %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]]
; CHECK-DAG: %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]]
; CHECK-DAG: %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]]
; CHECK-DAG: %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]]
; CHECK-DAG: %[[#ptr_Function_v2bfloat:]] = OpTypePointer Function %[[#v2bfloat]]
; CHECK-DAG: %[[#ptr_Function_v3bfloat:]] = OpTypePointer Function %[[#v3bfloat]]
; CHECK-DAG: %[[#ptr_Function_v4bfloat:]] = OpTypePointer Function %[[#v4bfloat]]
; CHECK-DAG: %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]]
; CHECK-DAG: %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]]
; CHECK-DAG: %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
Expand All @@ -40,6 +49,9 @@ entry:
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function
%half_Val = alloca half, align 2

; CHECK: %[[#]] = OpVariable %[[#ptr_Function_bfloat]] Function
%bfloat_Val = alloca bfloat, align 2

; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function
%float_Val = alloca float, align 4

Expand All @@ -55,6 +67,15 @@ entry:
; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function
%half4_Val = alloca <4 x half>, align 8

; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2bfloat]] Function
%bfloat2_Val = alloca <2 x bfloat>, align 4

; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v3bfloat]] Function
%bfloat3_Val = alloca <3 x bfloat>, align 8

; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4bfloat]] Function
%bfloat4_Val = alloca <4 x bfloat>, align 8

; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function
%float2_Val = alloca <2 x float>, align 8

Expand Down
Loading