[SPIRV] Enable bfloat16 arithmetic #166031
@llvm/pr-subscribers-backend-spir-v

Author: Alex Voicu (AlexVlx)

Patch is 29.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166031.diff

7 Files Affected:
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 85eeabf10244a..922c9e1a3cfe3 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -173,6 +173,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
- Allows generating arbitrary width integer types.
* - ``SPV_INTEL_bindless_images``
- Adds instructions to convert convert unsigned integer handles to images, samplers and sampled images.
+ * - ``SPV_INTEL_bfloat16_arithmetic``
+ - Allows the use of 16-bit bfloat16 values in arithmetic and relational operators.
* - ``SPV_INTEL_bfloat16_conversion``
- Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
* - ``SPV_INTEL_cache_controls``
@@ -226,9 +228,9 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
* - ``SPV_INTEL_fp_max_error``
- Adds the ability to specify the maximum error for floating-point operations.
* - ``SPV_INTEL_ternary_bitwise_function``
- - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
+ - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
* - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
- - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
+ - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
* - ``SPV_INTEL_int4``
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
* - ``SPV_KHR_float_controls2``
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 1fc90d0852aad..847163edcbc4b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -306,7 +306,7 @@ static bool containsBF16Type(const User &U) {
bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
return false;
// Get or create a virtual register for each value.
@@ -328,7 +328,7 @@ bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
bool IRTranslator::translateUnaryOp(unsigned Opcode, const User &U,
MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
return false;
Register Op0 = getOrCreateVReg(*U.getOperand(0));
@@ -348,7 +348,7 @@ bool IRTranslator::translateFNeg(const User &U, MachineIRBuilder &MIRBuilder) {
bool IRTranslator::translateCompare(const User &U,
MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
return false;
auto *CI = cast<CmpInst>(&U);
@@ -1569,7 +1569,7 @@ bool IRTranslator::translateBitCast(const User &U,
bool IRTranslator::translateCast(unsigned Opcode, const User &U,
MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
return false;
uint32_t Flags = 0;
@@ -2688,7 +2688,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
bool IRTranslator::translateInlineAsm(const CallBase &CB,
MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(CB))
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(CB))
return false;
const InlineAsmLowering *ALI = MF->getSubtarget().getInlineAsmLowering();
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 96f5dee21bc2a..9643db1f1bf53 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -107,6 +107,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
{"SPV_INTEL_bindless_images",
SPIRV::Extension::Extension::SPV_INTEL_bindless_images},
+ {"SPV_INTEL_bfloat16_arithmetic",
+ SPIRV::Extension::Extension::SPV_INTEL_bfloat16_arithmetic},
{"SPV_INTEL_bfloat16_conversion",
SPIRV::Extension::Extension::SPV_INTEL_bfloat16_conversion},
{"SPV_KHR_subgroup_rotate",
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index db036a55ee6c6..009d2dc1a567a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1435,6 +1435,8 @@ void addInstrRequirements(const MachineInstr &MI,
addPrintfRequirements(MI, Reqs, ST);
break;
}
+ // TODO: handle bfloat16 extended instructions when
+ // SPV_INTEL_bfloat16_arithmetic is enabled.
break;
}
case SPIRV::OpAliasDomainDeclINTEL:
@@ -2060,7 +2062,65 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
break;
}
-
+ case SPIRV::OpFAddS:
+ case SPIRV::OpFSubS:
+ case SPIRV::OpFMulS:
+ case SPIRV::OpFDivS:
+ case SPIRV::OpFRemS:
+ case SPIRV::OpFMod:
+ case SPIRV::OpFNegate:
+ case SPIRV::OpFAddV:
+ case SPIRV::OpFSubV:
+ case SPIRV::OpFMulV:
+ case SPIRV::OpFDivV:
+ case SPIRV::OpFRemV:
+ case SPIRV::OpFNegateV: {
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+ if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+ TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
+ report_fatal_error(
+ "Arithmetic instructions with bfloat16 arguments require the "
+ "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
+ Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
+ }
+ break;
+ }
+ case SPIRV::OpOrdered:
+ case SPIRV::OpUnordered:
+ case SPIRV::OpFOrdEqual:
+ case SPIRV::OpFOrdNotEqual:
+ case SPIRV::OpFOrdLessThan:
+ case SPIRV::OpFOrdLessThanEqual:
+ case SPIRV::OpFOrdGreaterThan:
+ case SPIRV::OpFOrdGreaterThanEqual:
+ case SPIRV::OpFUnordEqual:
+ case SPIRV::OpFUnordNotEqual:
+ case SPIRV::OpFUnordLessThan:
+ case SPIRV::OpFUnordLessThanEqual:
+ case SPIRV::OpFUnordGreaterThan:
+ case SPIRV::OpFUnordGreaterThanEqual: {
+ const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ SPIRVType *TypeDef = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+ TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
+ report_fatal_error(
+ "Relational instructions with bfloat16 arguments require the "
+ "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
+ Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
+ }
+ break;
+ }
default:
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 7d08b29a51a6e..263b59fbe6959 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -387,6 +387,8 @@ defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
+defm SPV_INTEL_bfloat16_arithmetic
+ : ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -570,6 +572,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
+defm BFloat16ArithmeticINTEL : CapabilityOperand<6226, 0, 0, [SPV_INTEL_bfloat16_arithmetic], []>;
defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
@@ -1919,7 +1922,7 @@ defm GenericCastToPtr : SpecConstantOpOperandsOperand<122, [], [Kernel]>;
defm PtrCastToGeneric : SpecConstantOpOperandsOperand<121, [], [Kernel]>;
defm Bitcast : SpecConstantOpOperandsOperand<124, [], []>;
defm QuantizeToF16 : SpecConstantOpOperandsOperand<116, [], [Shader]>;
-// Arithmetic
+// Arithmetic
defm SNegate : SpecConstantOpOperandsOperand<126, [], []>;
defm Not : SpecConstantOpOperandsOperand<200, [], []>;
defm IAdd : SpecConstantOpOperandsOperand<128, [], []>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
new file mode 100644
index 0000000000000..4cabddb94df25
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
@@ -0,0 +1,142 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: Arithmetic instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
+; CHECK-DAG: OpName [[NEG:%.*]] "neg"
+; CHECK-DAG: OpName [[NEGV:%.*]] "negv"
+; CHECK-DAG: OpName [[ADD:%.*]] "add"
+; CHECK-DAG: OpName [[ADDV:%.*]] "addv"
+; CHECK-DAG: OpName [[SUB:%.*]] "sub"
+; CHECK-DAG: OpName [[SUBV:%.*]] "subv"
+; CHECK-DAG: OpName [[MUL:%.*]] "mul"
+; CHECK-DAG: OpName [[MULV:%.*]] "mulv"
+; CHECK-DAG: OpName [[DIV:%.*]] "div"
+; CHECK-DAG: OpName [[DIVV:%.*]] "divv"
+; CHECK-DAG: OpName [[REM:%.*]] "rem"
+; CHECK-DAG: OpName [[REMV:%.*]] "remv"
+; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
+; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 4
+
+; CHECK-DAG: [[NEG]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOAT]] [[X]]
+define spir_func bfloat @neg(bfloat %x) {
+entry:
+ %r = fneg bfloat %x
+ ret bfloat %r
+}
+
+; CHECK-DAG: [[NEGV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOATV]] [[X]]
+define spir_func <4 x bfloat> @negv(<4 x bfloat> %x) {
+entry:
+ %r = fneg <4 x bfloat> %x
+ ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[ADD]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @add(bfloat %x, bfloat %y) {
+entry:
+ %r = fadd bfloat %x, %y
+ ret bfloat %r
+}
+
+; CHECK-DAG: [[ADDV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @addv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+ %r = fadd <4 x bfloat> %x, %y
+ ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[SUB]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @sub(bfloat %x, bfloat %y) {
+entry:
+ %r = fsub bfloat %x, %y
+ ret bfloat %r
+}
+
+; CHECK-DAG: [[SUBV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @subv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+ %r = fsub <4 x bfloat> %x, %y
+ ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[MUL]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @mul(bfloat %x, bfloat %y) {
+entry:
+ %r = fmul bfloat %x, %y
+ ret bfloat %r
+}
+
+; CHECK-DAG: [[MULV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @mulv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+ %r = fmul <4 x bfloat> %x, %y
+ ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[DIV]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @div(bfloat %x, bfloat %y) {
+entry:
+ %r = fdiv bfloat %x, %y
+ ret bfloat %r
+}
+
+; CHECK-DAG: [[DIVV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @divv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+ %r = fdiv <4 x bfloat> %x, %y
+ ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[REM]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @rem(bfloat %x, bfloat %y) {
+entry:
+ %r = frem bfloat %x, %y
+ ret bfloat %r
+}
+
+; CHECK-DAG: [[REMV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @remv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+ %r = frem <4 x bfloat> %x, %y
+ ret <4 x bfloat> %r
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll
new file mode 100644
index 0000000000000..3774791d58f87
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll
@@ -0,0 +1,376 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: Relational instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
+; CHECK-DAG: OpName [[UEQ:%.*]] "test_ueq"
+; CHECK-DAG: OpName [[OEQ:%.*]] "test_oeq"
+; CHECK-DAG: OpName [[UNE:%.*]] "test_une"
+; CHECK-DAG: OpName [[ONE:%.*]] "test_one"
+; CHECK-DAG: OpName [[ULT:%.*]] "test_ult"
+; CHECK-DAG: OpName [[OLT:%.*]] "test_olt"
+; CHECK-DAG: OpName [[ULE:%.*]] "test_ule"
+; CHECK-DAG: OpName [[OLE:%.*]] "test_ole"
+; CHECK-DAG: OpName [[UGT:%.*]] "test_ugt"
+; CHECK-DAG: OpName [[OGT:%.*]] "test_ogt"
+; CHECK-DAG: OpName [[UGE:%.*]] "test_uge"
+; CHECK-DAG: OpName [[OGE:%.*]] "test_oge"
+; CHECK-DAG: OpName [[UNO:%.*]] "test_uno"
+; CHECK-DAG: OpName [[ORD:%.*]] "test_ord"
+; CHECK-DAG: OpName [[v3UEQ:%.*]] "test_v3_ueq"
+; CHECK-DAG: OpName [[v3OEQ:%.*]] "test_v3_oeq"
+; CHECK-DAG: OpName [[v3UNE:%.*]] "test_v3_une"
+; CHECK-DAG: OpName [[v3ONE:%.*]] "test_v3_one"
+; CHECK-DAG: OpName [[v3ULT:%.*]] "test_v3_ult"
+; CHECK-DAG: OpName [[v3OLT:%.*]] "test_v3_olt"
+; CHECK-DAG: OpName [[v3ULE:%.*]] "test_v3_ule"
+; CHECK-DAG: OpName [[v3OLE:%.*]] "test_v3_ole"
+; CHECK-DAG: OpName [[v3UGT:%.*]] "test_v3_ugt"
+; CHECK-DAG: OpName [[v3OGT:%.*]] "test_v3_ogt"
+; CHECK-DAG: OpName [[v3UGE:%.*]] "test_v3_uge"
+; CHECK-DAG: OpName [[v3OGE:%.*]] "test_v3_oge"
+; CHECK-DAG: OpName [[v3UNO:%.*]] "test_v3_uno"
+; CHECK-DAG: OpName [[v3ORD:%.*]] "test_v3_ord"
+; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
+; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 3
+
+; CHECK: [[UEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ueq(bfloat %a, bfloat %b) {
+ %r = fcmp ueq bfloat %a, %b
+ ret i1 %r
+}
+
+; CHECK: [[OEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_oeq(bfloat %a, bfloat %b) {
+ %r = fcmp oeq bfloat %a, %b
+ ret i1 %r
+}
+
+; CHECK: [[UNE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_une(bfloat %a, bfloat %b) {
+ %r = fcmp une bfloat %a, %b
+ ret i1 %r
+}
+
+; CHECK: [[ONE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_one(bfloat %a, bfloat %b) {
+ %r = fcmp one bfloat %a, %b
+ ret i1 %r
+}
+
+; CHECK: [[ULT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ult(bfloat %a, bfloat %b) {
+ %r = fcmp ult bfloat %a, %b
+ ret i1 %r
+}
+
+; CHECK: [[OLT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLO...
[truncated]
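Both new case groups in addInstrRequirements follow the same pattern. They fetch the floating-point type of the instruction (operand 1, the result type, for the arithmetic ops; the type of operand 2, the first value operand, for the relational ops, whose result is a boolean), peel one OpTypeVector layer to reach the component type, and, if that is bfloat16, either require SPV_INTEL_bfloat16_arithmetic plus the BFloat16ArithmeticINTEL capability or fail. The standalone sketch below models only that control flow. TypeDesc, Requirements and checkBF16ArithmeticRequirements are hypothetical stand-ins, not LLVM's SPIRVType or requirement-handler API, and the sketch throws where the backend calls report_fatal_error.

```cpp
#include <cassert>
#include <set>
#include <stdexcept>
#include <string>

// Hypothetical, simplified model of a SPIR-V type: a scalar float (possibly
// using the bfloat16 encoding) or a vector of such scalars.
struct TypeDesc {
  enum Kind { Float, Vector } K;
  bool IsBFloat16 = false;        // only meaningful for Float
  const TypeDesc *Elem = nullptr; // only meaningful for Vector
};

// Hypothetical stand-in for the module-level requirements being collected.
struct Requirements {
  std::set<std::string> Extensions;
  std::set<std::string> Capabilities;
};

// Mirrors the shape of the new OpFAdd/OpFOrdEqual/... cases: unwrap one
// vector level, then add requirements only for bfloat16 operands.
void checkBF16ArithmeticRequirements(const TypeDesc &Ty, bool ExtensionEnabled,
                                     const std::string &What,
                                     Requirements &Reqs) {
  const TypeDesc *T = &Ty;
  if (T->K == TypeDesc::Vector) {
    assert(T->Elem && "vector type must have a component type");
    T = T->Elem;
  }
  if (T->K != TypeDesc::Float || !T->IsBFloat16)
    return; // e.g. half, float or double: nothing extra is required
  if (!ExtensionEnabled)
    throw std::runtime_error(
        What + " instructions with bfloat16 arguments require the "
               "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic");
  Reqs.Extensions.insert("SPV_INTEL_bfloat16_arithmetic");
  Reqs.Capabilities.insert("BFloat16ArithmeticINTEL");
}
```

Peeling exactly one vector level matches the patch, since an OpTypeVector component is always a scalar; the error text mirrors the CHECK-ERROR lines in the two new tests.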
jmmartinez left a comment
Minor comment, everything else seems fine.
bool IRTranslator::translateInlineAsm(const CallBase &CB,
                                      MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(CB))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(CB))
NIT: Can we add a helper canHandleBF16ForTarget() that does the MF->getTarget().getTargetTriple().isSPIRV() check?
Adding a hook in TargetMachine would be ideal, but it'd be overkill.
Done.
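A helper along the lines suggested above could look like the standalone sketch below. TripleStub, canHandleBF16ForTarget and rejectForBF16 are illustrative assumptions (the helper actually added to the PR is not shown on this page), and TripleStub is only loosely modelled on llvm::Triple::isSPIRV(). The point is that the target query lives in one place instead of being repeated in the five translate* routines touched by the patch.

```cpp
#include <cassert>
#include <string>

// Hypothetical, heavily reduced stand-in for llvm::Triple: just the
// architecture name and an isSPIRV()-style query.
struct TripleStub {
  std::string ArchName;
  bool isSPIRV() const {
    return ArchName == "spirv" || ArchName == "spirv32" ||
           ArchName == "spirv64";
  }
};

// Hypothetical helper in the spirit of the review suggestion. In IRTranslator
// it would consult MF->getTarget().getTargetTriple() instead of a stub.
static bool canHandleBF16ForTarget(const TripleStub &TT) {
  return TT.isSPIRV();
}

// The guard each translate* routine applies. Returning true here corresponds
// to the routine returning false, i.e. refusing to translate the instruction.
static bool rejectForBF16(const TripleStub &TT, bool ContainsBF16) {
  assert(!TT.ArchName.empty() && "triple must name an architecture");
  return !canHandleBF16ForTarget(TT) && ContainsBF16;
}
```

The cheap target query is evaluated before the operand scan, matching the order of the condition in the patch.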
@YixingZhang007 please take a look
AlexVlx left a comment
Quoting jmmartinez: "NIT: Can we add a helper canHandleBF16ForTarget() that does the MF->getTarget().getTargetTriple().isSPIRV() check? Adding a hook in TargetMachine would be ideal, but it'd be overkill."
We can. That said, I'm not entirely certain the check isn't vestigial at this point, or whether it should happen in IRTranslator at all rather than being left to instruction selection or the target (the LLT issue mentioned in a TODO elsewhere might have been solved). Perhaps someone better informed can advise?
entry:
  %r = frem <4 x bfloat> %x, %y
  ret <4 x bfloat> %r
}
I noticed that a test for OpFMod is missing, and I am not sure whether we need to add one as well.
That is intentional: as far as I know, LLVM has no equivalent of OpFMod (whose result takes the sign of the divisor); it only has frem (whose result takes the sign of the dividend). If desired, I can add a test that uses one of the fancy __spirv builtins.
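The difference between the two remainders can be seen with ordinary doubles. std::fmod has the same sign convention as frem and OpFRem (sign of the dividend), and a divisor-signed remainder in the style of OpFMod can be derived from it. remDividendSign and remDivisorSign are hypothetical helpers, and double stands in for bfloat16 purely for convenience.

```cpp
#include <cassert>
#include <cmath>

// Remainder with the sign of the dividend x: what std::fmod computes, and the
// sign convention of LLVM's frem / SPIR-V OpFRem.
double remDividendSign(double x, double y) {
  assert(y != 0.0 && "divisor must be nonzero");
  return std::fmod(x, y);
}

// Remainder with the sign of the divisor y, as SPIR-V OpFMod specifies.
// If a nonzero fmod result has the wrong sign, shift it by one divisor.
// (Sketch only: rounding corner cases, e.g. |r| tiny relative to |y|, are
// not handled.)
double remDivisorSign(double x, double y) {
  double r = remDividendSign(x, y);
  if (r != 0.0 && std::signbit(r) != std::signbit(y))
    r += y;
  return r;
}
```

For example, with x = -7 and y = 3 the dividend-signed remainder is -1 while the divisor-signed one is 2, which is why a plain frem test cannot stand in for OpFMod coverage.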
Thanks for the explanation! I noticed that there don't appear to be any tests for OpFMod anywhere in llvm-project, so we should probably add coverage for it. However, I am not sure whether that should go in this patch or in a separate one.
Intuitively I'd be in favour of doing it separately, as testing it thoroughly would mean casting a wider net (not just bfloat16), and it feels like a separate piece of work. That's just my 2p on the matter, though.
Yes, I also agree that it is better to do it separately since it will not only cover the case of 'bfloat16' :)
YixingZhang007 left a comment
Overall LGTM! I left a small comment regarding the testing. Thank you!
Cheers. Just to clarify: are you OK with this going in as is, or would you like me to add some OpFMod testing? I've commented on the thread about that, but I'm not strongly opposed to adding it if you really want it.
I think everything looks good, and OpFMod testing can be added in a future patch. Thank you! 👍
…tomics` extension (#166257)

This enables support for atomic RMW ops (add, sub, min and max, to be precise) with `bfloat16` operands, via the [SPV_INTEL_16bit_atomics extension](intel/llvm#20009). It's logically a successor to #166031 (I should've used a stack), but I'm putting it up for early review.

Co-authored-by: Matt Arsenault
Enable the SPV_INTEL_bfloat16_arithmetic extension, which allows arithmetic, relational and OpExtInst instructions to take bfloat16 arguments. This patch only adds support for arithmetic and relational ops; bfloat16 extended (OpExtInst) instructions are left as a TODO. The extension itself is rather fresh, but bfloat16 is ubiquitous at this point, and not supporting these ops is limiting.
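For readers new to the format: bfloat16 keeps binary32's sign bit and 8-bit exponent but only 7 explicit mantissa bits, so a bfloat16 value is exactly the upper half of a float. The sketch below shows a common software emulation (widen to float, compute, round back with round-to-nearest, ties-to-even). It illustrates the number format only and is not what the SPIR-V backend emits: with SPV_INTEL_bfloat16_arithmetic, OpFAdd and friends operate on bfloat16 operands directly. bf16ToFloat, floatToBF16 and bf16Add are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Widen: a bfloat16 is the top 16 bits of a binary32, so shift it back up.
// This direction is always exact.
float bf16ToFloat(std::uint16_t h) {
  std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Narrow with round-to-nearest, ties-to-even; NaNs are kept as quiet NaNs.
std::uint16_t floatToBF16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  if ((bits & 0x7fffffffu) > 0x7f800000u)                     // NaN input
    return static_cast<std::uint16_t>((bits >> 16) | 0x0040u); // quiet bit
  std::uint32_t lsb = (bits >> 16) & 1u; // decides which way ties go
  bits += 0x7fffu + lsb;
  return static_cast<std::uint16_t>(bits >> 16);
}

// Emulated bfloat16 addition: add in float, then round once to bfloat16.
std::uint16_t bf16Add(std::uint16_t a, std::uint16_t b) {
  return floatToBF16(bf16ToFloat(a) + bf16ToFloat(b));
}
```

For instance, 1.00390625f sits exactly halfway between the bfloat16 values 1.0 and 1.0078125, and ties-to-even rounds it down to 1.0 (bit pattern 0x3F80).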