Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,7 +2765,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}

bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (containsBF16Type(U))
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
return false;

const CallInst &CI = cast<CallInst>(U);
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
{"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};

bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
Expand Down
40 changes: 35 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
}
}

static bool isBFloat16Type(const SPIRVType *TypeDef) {
return TypeDef && TypeDef->getNumOperands() == 3 &&
TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
TypeDef->getOperand(1).getImm() == 16 &&
TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also check the bitwidth?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! I have added the check for bitwidth here.

}

void addInstrRequirements(const MachineInstr &MI,
SPIRV::RequirementHandler &Reqs,
const SPIRVSubtarget &ST) {
Expand Down Expand Up @@ -1261,12 +1268,29 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Int8);
break;
}
case SPIRV::OpDot: {
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
if (isBFloat16Type(TypeDef))
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
break;
}
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
else if (BitWidth == 16) {
if (isBFloat16Type(&MI)) {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
report_fatal_error("OpTypeFloat type with bfloat requires the "
"following SPIR-V extension: SPV_KHR_bfloat16",
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
} else {
Reqs.addCapability(SPIRV::Capability::Float16);
}
}
break;
}
case SPIRV::OpTypeVector: {
Expand All @@ -1286,8 +1310,9 @@ void addInstrRequirements(const MachineInstr &MI,
assert(MI.getOperand(2).isReg());
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
TypeDef->getOperand(1).getImm() == 16)
if ((TypeDef->getNumOperands() == 2) &&
(TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
(TypeDef->getOperand(1).getImm() == 16))
Reqs.addCapability(SPIRV::Capability::Float16Buffer);
break;
}
Expand Down Expand Up @@ -1593,15 +1618,20 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
}
break;
case SPIRV::OpTypeCooperativeMatrixKHR:
case SPIRV::OpTypeCooperativeMatrixKHR: {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
report_fatal_error(
"OpTypeCooperativeMatrixKHR type requires the "
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
if (isBFloat16Type(TypeDef))
Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
break;
}
case SPIRV::OpArithmeticFenceEXT:
if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
report_fatal_error("OpArithmeticFenceEXT requires the "
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -595,6 +596,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down Expand Up @@ -2021,4 +2025,4 @@ multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
reqExtensions, [], []>;
}

defm BFloat16KHR : FPEncodingOperand<0, []>;
defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>;
7 changes: 4 additions & 3 deletions llvm/test/CodeGen/SPIRV/basic_float_types.ll
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}

define void @main() {
entry:

; CHECK-DAG: OpCapability Float16
; CHECK-DAG: OpCapability Float64
; CHECK-DAG: OpCapability BFloat16TypeKHR

; CHECK-DAG: %[[#half:]] = OpTypeFloat 16{{$}}
; CHECK-DAG: %[[#bfloat:]] = OpTypeFloat 16 0{{$}}
Expand Down
22 changes: 22 additions & 0 deletions llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}

; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16

; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define spir_kernel void @test() {
entry:
%addr1 = alloca bfloat
%addr2 = alloca <2 x bfloat>
%data1 = load bfloat, ptr %addr1
%data2 = load <2 x bfloat>, ptr %addr2
ret void
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpCapability CooperativeMatrixKHR
; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-DAG: OpExtension "SPV_KHR_cooperative_matrix"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]] %[[#]] %[[#]] %[[#]] %[[#]]
; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]]

define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
entry:
%addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
%res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4
%m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 1.0)
store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4
ret void
}

declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat)
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpCapability BFloat16DotProductKHR
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
; CHECK: OpDot

declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)

define spir_kernel void @test() {
entry:
%addrA = alloca <2 x bfloat>
%addrB = alloca <2 x bfloat>
%dataA = load <2 x bfloat>, ptr %addrA
%dataB = load <2 x bfloat>, ptr %addrB
%call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
ret void
}
Loading