diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index 768e3713f78e2..12b735e053bde 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -2765,7 +2765,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB, } bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(U)) + if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U)) return false; const CallInst &CI = cast(U); diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index e7da5504b2d58..993de9e9f64ec 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -147,7 +147,8 @@ static const std::map> {"SPV_KHR_float_controls2", SPIRV::Extension::Extension::SPV_KHR_float_controls2}, {"SPV_INTEL_tensor_float32_conversion", - SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}}; + SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}, + {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}}; bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index b7e371d190866..a95f393b75605 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1222,6 +1222,13 @@ static void AddDotProductRequirements(const MachineInstr &MI, } } +static bool isBFloat16Type(const SPIRVType *TypeDef) { + return TypeDef && TypeDef->getNumOperands() == 3 && + TypeDef->getOpcode() == SPIRV::OpTypeFloat && + TypeDef->getOperand(1).getImm() == 16 && + TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR; +} + void addInstrRequirements(const MachineInstr &MI, SPIRV::RequirementHandler &Reqs, const SPIRVSubtarget &ST) { @@ -1261,12 +1268,29 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::Int8); break; } + case SPIRV::OpDot: { + const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); + SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg()); + if (isBFloat16Type(TypeDef)) + Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR); + break; + } case SPIRV::OpTypeFloat: { unsigned BitWidth = MI.getOperand(1).getImm(); if (BitWidth == 64) Reqs.addCapability(SPIRV::Capability::Float64); - else if (BitWidth == 16) - Reqs.addCapability(SPIRV::Capability::Float16); + else if (BitWidth == 16) { + if (isBFloat16Type(&MI)) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16)) + report_fatal_error("OpTypeFloat type with bfloat requires the " + "following SPIR-V extension: SPV_KHR_bfloat16", + false); + Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16); + Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR); + } else { + Reqs.addCapability(SPIRV::Capability::Float16); + } + } break; } case SPIRV::OpTypeVector: { @@ -1286,8 +1310,9 @@ void addInstrRequirements(const MachineInstr &MI, assert(MI.getOperand(2).isReg()); const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg()); - if (TypeDef->getOpcode() == SPIRV::OpTypeFloat && - TypeDef->getOperand(1).getImm() == 16) + if ((TypeDef->getNumOperands() == 2) && + (TypeDef->getOpcode() == SPIRV::OpTypeFloat) && + (TypeDef->getOperand(1).getImm() == 16)) Reqs.addCapability(SPIRV::Capability::Float16Buffer); break; } @@ -1593,7 +1618,7 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::AsmINTEL); } break; - case SPIRV::OpTypeCooperativeMatrixKHR: + case SPIRV::OpTypeCooperativeMatrixKHR: { if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix)) report_fatal_error( "OpTypeCooperativeMatrixKHR type requires the " @@ -1601,7 +1626,12 @@ void addInstrRequirements(const MachineInstr &MI, false); Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix); Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR); + const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); + SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg()); + if (isBFloat16Type(TypeDef)) + Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR); break; + } case SPIRV::OpArithmeticFenceEXT: if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence)) report_fatal_error("OpArithmeticFenceEXT requires the " diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index ed933f872d136..501bcb94af2ea 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -383,6 +383,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>; defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>; defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>; defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>; +defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -595,6 +596,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>; defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>; defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>; +defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>; +defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>; +defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time @@ -2021,4 +2025,4 @@ multiclass FPEncodingOperand value, list reqExtensions>{ reqExtensions, [], []>; } -defm BFloat16KHR : FPEncodingOperand<0, []>; +defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>; diff --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll index 486f6358ce5de..a0ba97e1d1f14 100644 --- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll +++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll @@ -1,12 +1,13 @@ -; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s -; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s -; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %} define void @main() { entry: ; CHECK-DAG: OpCapability Float16 ; CHECK-DAG: OpCapability Float64 +; CHECK-DAG: OpCapability BFloat16TypeKHR ; CHECK-DAG: %[[#half:]] = OpTypeFloat 16{{$}} ; CHECK-DAG: %[[#bfloat:]] = OpTypeFloat 16 0{{$}} diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll new file mode 100644 index 0000000000000..22668e71fb257 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll @@ -0,0 +1,22 @@ +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %} + +; CHECK-ERROR: LLVM ERROR: OpTypeFloat type with bfloat requires the following SPIR-V extension: SPV_KHR_bfloat16 + +; CHECK-DAG: OpCapability BFloat16TypeKHR +; CHECK-DAG: OpExtension "SPV_KHR_bfloat16" +; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0 +; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2 + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +define spir_kernel void @test() { +entry: + %addr1 = alloca bfloat + %addr2 = alloca <2 x bfloat> + %data1 = load bfloat, ptr %addr1 + %data2 = load <2 x bfloat>, ptr %addr2 + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll new file mode 100644 index 0000000000000..d47b5d7440d18 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_cooperative_matrix.ll @@ -0,0 +1,22 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpCapability BFloat16TypeKHR +; CHECK-DAG: OpCapability CooperativeMatrixKHR +; CHECK-DAG: OpCapability BFloat16CooperativeMatrixKHR +; CHECK-DAG: OpExtension "SPV_KHR_bfloat16" +; CHECK-DAG: OpExtension "SPV_KHR_cooperative_matrix" +; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0 +; CHECK: %[[#MatTy:]] = OpTypeCooperativeMatrixKHR %[[#BFLOAT]] %[[#]] %[[#]] %[[#]] %[[#]] +; CHECK: OpCompositeConstruct %[[#MatTy]] %[[#]] + +define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) { +entry: + %addr1 = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4 + %res = alloca target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2), align 4 + %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat 1.0) + store target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) %m1, ptr %addr1, align 4 + ret void +} + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(bfloat) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll new file mode 100644 index 0000000000000..4c248fea5c7f1 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll @@ -0,0 +1,21 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpCapability BFloat16TypeKHR +; CHECK-DAG: OpCapability BFloat16DotProductKHR +; CHECK-DAG: OpExtension "SPV_KHR_bfloat16" +; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0 +; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2 +; CHECK: OpDot + +declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>) + +define spir_kernel void @test() { +entry: + %addrA = alloca <2 x bfloat> + %addrB = alloca <2 x bfloat> + %dataA = load <2 x bfloat>, ptr %addrA + %dataB = load <2 x bfloat>, ptr %addrB + %call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB) + ret void +}