@AlexVlx
Contributor

@AlexVlx AlexVlx commented Nov 2, 2025

Enable the SPV_INTEL_bfloat16_arithmetic extension, which allows arithmetic, relational and OpExtInst instructions to take bfloat16 arguments. This patch only adds support for arithmetic and relational ops; OpExtInst support is left for later. The extension itself is rather fresh, but bfloat16 is ubiquitous at this point and not supporting these ops is limiting.
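For readers less familiar with the format: bfloat16 keeps the sign bit, all 8 exponent bits and the top 7 mantissa bits of an IEEE-754 binary32 value. The sketch below is illustrative only; `bf16ToFloat`, `floatToBF16` and `bf16Add` are hypothetical helpers, not LLVM or SPIR-V APIs. It shows the encoding and one common emulation strategy for bfloat16 arithmetic (compute in float, then round back with round-to-nearest-even).

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Raw bfloat16 bit pattern: 1 sign bit, 8 exponent bits, 7 mantissa bits.
using BF16Bits = std::uint16_t;

// Widening is exact: the bfloat16 bits become the high half of a float.
float bf16ToFloat(BF16Bits B) {
  std::uint32_t U = static_cast<std::uint32_t>(B) << 16;
  float F;
  std::memcpy(&F, &U, sizeof F);
  return F;
}

// Narrowing with round-to-nearest-even; NaNs are kept NaN by forcing the
// quiet bit (the top bfloat16 mantissa bit).
BF16Bits floatToBF16(float F) {
  std::uint32_t U;
  std::memcpy(&U, &F, sizeof U);
  if (std::isnan(F))
    return static_cast<BF16Bits>((U >> 16) | 0x0040u);
  // Add 0x7FFF, plus 1 if the kept LSB is odd, so exact ties round to even.
  std::uint32_t Bias = 0x7FFFu + ((U >> 16) & 1u);
  return static_cast<BF16Bits>((U + Bias) >> 16);
}

// Emulated bfloat16 add: widen both operands, add in float, round back once.
BF16Bits bf16Add(BF16Bits A, BF16Bits B) {
  return floatToBF16(bf16ToFloat(A) + bf16ToFloat(B));
}
```

For example, `bf16Add(0x3F80, 0x3F80)` adds 1.0 to 1.0 and yields `0x4000` (2.0).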

@llvmbot
Member

llvmbot commented Nov 2, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Alex Voicu (AlexVlx)

Changes

Enable the SPV_INTEL_bfloat16_arithmetic extension, which allows arithmetic, relational and OpExtInst instructions to take bfloat16 arguments. This patch only adds support for arithmetic and relational ops; OpExtInst support is left for later. The extension itself is rather fresh, but bfloat16 is ubiquitous at this point and not supporting these ops is limiting.


Patch is 29.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166031.diff

7 Files Affected:

  • (modified) llvm/docs/SPIRVUsage.rst (+4-2)
  • (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+5-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+61-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+4-1)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll (+142)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll (+376)
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 85eeabf10244a..922c9e1a3cfe3 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -173,6 +173,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
      - Allows generating arbitrary width integer types.
    * - ``SPV_INTEL_bindless_images``
      - Adds instructions to convert convert unsigned integer handles to images, samplers and sampled images.
+   * - ``SPV_INTEL_bfloat16_arithmetic``
+     - Allows the use of 16-bit bfloat16 values in arithmetic and relational operators.
    * - ``SPV_INTEL_bfloat16_conversion``
      - Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
    * - ``SPV_INTEL_cache_controls``
@@ -226,9 +228,9 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
    * - ``SPV_INTEL_fp_max_error``
      - Adds the ability to specify the maximum error for floating-point operations.
    * - ``SPV_INTEL_ternary_bitwise_function``
-     - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform. 
+     - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
    * - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
-     - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix. 
+     - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
    * - ``SPV_INTEL_int4``
      - Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
    * - ``SPV_KHR_float_controls2``
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 1fc90d0852aad..847163edcbc4b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -306,7 +306,7 @@ static bool containsBF16Type(const User &U) {
 
 bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
                                      MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   // Get or create a virtual register for each value.
@@ -328,7 +328,7 @@ bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
 
 bool IRTranslator::translateUnaryOp(unsigned Opcode, const User &U,
                                     MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   Register Op0 = getOrCreateVReg(*U.getOperand(0));
@@ -348,7 +348,7 @@ bool IRTranslator::translateFNeg(const User &U, MachineIRBuilder &MIRBuilder) {
 
 bool IRTranslator::translateCompare(const User &U,
                                     MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   auto *CI = cast<CmpInst>(&U);
@@ -1569,7 +1569,7 @@ bool IRTranslator::translateBitCast(const User &U,
 
 bool IRTranslator::translateCast(unsigned Opcode, const User &U,
                                  MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   uint32_t Flags = 0;
@@ -2688,7 +2688,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
 
 bool IRTranslator::translateInlineAsm(const CallBase &CB,
                                       MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(CB))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(CB))
     return false;
 
   const InlineAsmLowering *ALI = MF->getSubtarget().getInlineAsmLowering();
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 96f5dee21bc2a..9643db1f1bf53 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -107,6 +107,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
         {"SPV_INTEL_bindless_images",
          SPIRV::Extension::Extension::SPV_INTEL_bindless_images},
+        {"SPV_INTEL_bfloat16_arithmetic",
+         SPIRV::Extension::Extension::SPV_INTEL_bfloat16_arithmetic},
         {"SPV_INTEL_bfloat16_conversion",
          SPIRV::Extension::Extension::SPV_INTEL_bfloat16_conversion},
         {"SPV_KHR_subgroup_rotate",
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index db036a55ee6c6..009d2dc1a567a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1435,6 +1435,8 @@ void addInstrRequirements(const MachineInstr &MI,
       addPrintfRequirements(MI, Reqs, ST);
       break;
     }
+    // TODO: handle bfloat16 extended instructions when
+    // SPV_INTEL_bfloat16_arithmetic is enabled.
     break;
   }
   case SPIRV::OpAliasDomainDeclINTEL:
@@ -2060,7 +2062,65 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
     break;
   }
-
+  case SPIRV::OpFAddS:
+  case SPIRV::OpFSubS:
+  case SPIRV::OpFMulS:
+  case SPIRV::OpFDivS:
+  case SPIRV::OpFRemS:
+  case SPIRV::OpFMod:
+  case SPIRV::OpFNegate:
+  case SPIRV::OpFAddV:
+  case SPIRV::OpFSubV:
+  case SPIRV::OpFMulV:
+  case SPIRV::OpFDivV:
+  case SPIRV::OpFRemV:
+  case SPIRV::OpFNegateV: {
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+    if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+      TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+    if (isBFloat16Type(TypeDef)) {
+      if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
+        report_fatal_error(
+            "Arithmetic instructions with bfloat16 arguments require the "
+            "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
+            false);
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
+      Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
+    }
+    break;
+  }
+  case SPIRV::OpOrdered:
+  case SPIRV::OpUnordered:
+  case SPIRV::OpFOrdEqual:
+  case SPIRV::OpFOrdNotEqual:
+  case SPIRV::OpFOrdLessThan:
+  case SPIRV::OpFOrdLessThanEqual:
+  case SPIRV::OpFOrdGreaterThan:
+  case SPIRV::OpFOrdGreaterThanEqual:
+  case SPIRV::OpFUnordEqual:
+  case SPIRV::OpFUnordNotEqual:
+  case SPIRV::OpFUnordLessThan:
+  case SPIRV::OpFUnordLessThanEqual:
+  case SPIRV::OpFUnordGreaterThan:
+  case SPIRV::OpFUnordGreaterThanEqual: {
+    const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
+    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+    SPIRVType *TypeDef = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+      TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+    if (isBFloat16Type(TypeDef)) {
+      if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
+        report_fatal_error(
+            "Relational instructions with bfloat16 arguments require the "
+            "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
+            false);
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
+      Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
+    }
+    break;
+  }
   default:
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 7d08b29a51a6e..263b59fbe6959 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -387,6 +387,8 @@ defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
 defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
 defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
+defm SPV_INTEL_bfloat16_arithmetic
+    : ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -570,6 +572,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
 defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
+defm BFloat16ArithmeticINTEL : CapabilityOperand<6226, 0, 0, [SPV_INTEL_bfloat16_arithmetic], []>;
 defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
 defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
@@ -1919,7 +1922,7 @@ defm GenericCastToPtr :  SpecConstantOpOperandsOperand<122, [], [Kernel]>;
 defm PtrCastToGeneric :  SpecConstantOpOperandsOperand<121, [], [Kernel]>;
 defm Bitcast :  SpecConstantOpOperandsOperand<124, [], []>;
 defm QuantizeToF16 :  SpecConstantOpOperandsOperand<116, [], [Shader]>;
-// Arithmetic 
+// Arithmetic
 defm SNegate :  SpecConstantOpOperandsOperand<126, [], []>;
 defm Not :  SpecConstantOpOperandsOperand<200, [], []>;
 defm IAdd :  SpecConstantOpOperandsOperand<128, [], []>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
new file mode 100644
index 0000000000000..4cabddb94df25
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
@@ -0,0 +1,142 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: Arithmetic instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
+; CHECK-DAG: OpName [[NEG:%.*]] "neg"
+; CHECK-DAG: OpName [[NEGV:%.*]] "negv"
+; CHECK-DAG: OpName [[ADD:%.*]] "add"
+; CHECK-DAG: OpName [[ADDV:%.*]] "addv"
+; CHECK-DAG: OpName [[SUB:%.*]] "sub"
+; CHECK-DAG: OpName [[SUBV:%.*]] "subv"
+; CHECK-DAG: OpName [[MUL:%.*]] "mul"
+; CHECK-DAG: OpName [[MULV:%.*]] "mulv"
+; CHECK-DAG: OpName [[DIV:%.*]] "div"
+; CHECK-DAG: OpName [[DIVV:%.*]] "divv"
+; CHECK-DAG: OpName [[REM:%.*]] "rem"
+; CHECK-DAG: OpName [[REMV:%.*]] "remv"
+; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
+; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 4
+
+; CHECK-DAG: [[NEG]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOAT]] [[X]]
+define spir_func bfloat @neg(bfloat %x) {
+entry:
+  %r = fneg bfloat %x
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[NEGV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOATV]] [[X]]
+define spir_func <4 x bfloat> @negv(<4 x bfloat> %x) {
+entry:
+  %r = fneg <4 x bfloat> %x
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[ADD]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @add(bfloat %x, bfloat %y) {
+entry:
+  %r = fadd bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[ADDV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @addv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fadd <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[SUB]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @sub(bfloat %x, bfloat %y) {
+entry:
+  %r = fsub bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[SUBV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @subv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fsub <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[MUL]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @mul(bfloat %x, bfloat %y) {
+entry:
+  %r = fmul bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[MULV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @mulv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fmul <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[DIV]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @div(bfloat %x, bfloat %y) {
+entry:
+  %r = fdiv bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[DIVV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @divv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fdiv <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[REM]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @rem(bfloat %x, bfloat %y) {
+entry:
+  %r = frem bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[REMV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @remv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = frem <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll
new file mode 100644
index 0000000000000..3774791d58f87
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll
@@ -0,0 +1,376 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: Relational instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
+; CHECK-DAG: OpName [[UEQ:%.*]] "test_ueq"
+; CHECK-DAG: OpName [[OEQ:%.*]] "test_oeq"
+; CHECK-DAG: OpName [[UNE:%.*]] "test_une"
+; CHECK-DAG: OpName [[ONE:%.*]] "test_one"
+; CHECK-DAG: OpName [[ULT:%.*]] "test_ult"
+; CHECK-DAG: OpName [[OLT:%.*]] "test_olt"
+; CHECK-DAG: OpName [[ULE:%.*]] "test_ule"
+; CHECK-DAG: OpName [[OLE:%.*]] "test_ole"
+; CHECK-DAG: OpName [[UGT:%.*]] "test_ugt"
+; CHECK-DAG: OpName [[OGT:%.*]] "test_ogt"
+; CHECK-DAG: OpName [[UGE:%.*]] "test_uge"
+; CHECK-DAG: OpName [[OGE:%.*]] "test_oge"
+; CHECK-DAG: OpName [[UNO:%.*]] "test_uno"
+; CHECK-DAG: OpName [[ORD:%.*]] "test_ord"
+; CHECK-DAG: OpName [[v3UEQ:%.*]] "test_v3_ueq"
+; CHECK-DAG: OpName [[v3OEQ:%.*]] "test_v3_oeq"
+; CHECK-DAG: OpName [[v3UNE:%.*]] "test_v3_une"
+; CHECK-DAG: OpName [[v3ONE:%.*]] "test_v3_one"
+; CHECK-DAG: OpName [[v3ULT:%.*]] "test_v3_ult"
+; CHECK-DAG: OpName [[v3OLT:%.*]] "test_v3_olt"
+; CHECK-DAG: OpName [[v3ULE:%.*]] "test_v3_ule"
+; CHECK-DAG: OpName [[v3OLE:%.*]] "test_v3_ole"
+; CHECK-DAG: OpName [[v3UGT:%.*]] "test_v3_ugt"
+; CHECK-DAG: OpName [[v3OGT:%.*]] "test_v3_ogt"
+; CHECK-DAG: OpName [[v3UGE:%.*]] "test_v3_uge"
+; CHECK-DAG: OpName [[v3OGE:%.*]] "test_v3_oge"
+; CHECK-DAG: OpName [[v3UNO:%.*]] "test_v3_uno"
+; CHECK-DAG: OpName [[v3ORD:%.*]] "test_v3_ord"
+; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
+; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 3
+
+; CHECK:      [[UEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ueq(bfloat %a, bfloat %b) {
+  %r = fcmp ueq bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[OEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_oeq(bfloat %a, bfloat %b) {
+  %r = fcmp oeq bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[UNE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_une(bfloat %a, bfloat %b) {
+  %r = fcmp une bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[ONE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_one(bfloat %a, bfloat %b) {
+  %r = fcmp one bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[ULT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ult(bfloat %a, bfloat %b) {
+  %r = fcmp ult bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[OLT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLO...
[truncated]
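The core of the SPIRVModuleAnalysis.cpp change can be modelled outside LLVM roughly as follows. This is a simplified sketch: `SpirvType`, `Requirements` and `addBF16ArithmeticRequirements` are hypothetical names, and a thrown exception stands in for `report_fatal_error`. As the diff reads, the arithmetic cases inspect the result type (operand 1) while the relational cases inspect the type of the first compared operand (operand 2); in both, a vector type is unwrapped to its element type before checking for bfloat16.

```cpp
#include <cassert>
#include <set>
#include <stdexcept>
#include <string>

// Hypothetical, simplified stand-in for a SPIR-V type definition.
struct SpirvType {
  enum Op { TypeFloat, TypeVector } Opcode;
  bool IsBFloat16;              // only meaningful for TypeFloat
  const SpirvType *ElementType; // only meaningful for TypeVector
};

// Hypothetical stand-in for the module's extension/capability requirements.
struct Requirements {
  std::set<std::string> Extensions;
  std::set<std::string> Capabilities;
};

// Unwrap a vector type to its component type; scalars are returned as-is.
static const SpirvType *scalarOf(const SpirvType *T) {
  return T->Opcode == SpirvType::TypeVector ? T->ElementType : T;
}

// Kind is "Arithmetic" or "Relational" and only affects the diagnostic text.
static void addBF16ArithmeticRequirements(const SpirvType *Ty,
                                          bool ExtensionEnabled,
                                          const std::string &Kind,
                                          Requirements &Reqs) {
  if (!scalarOf(Ty)->IsBFloat16)
    return; // Not bfloat16: nothing extra is required.
  if (!ExtensionEnabled)
    throw std::runtime_error(
        Kind + " instructions with bfloat16 arguments require the following "
               "SPIR-V extension: SPV_INTEL_bfloat16_arithmetic");
  Reqs.Extensions.insert("SPV_INTEL_bfloat16_arithmetic");
  Reqs.Capabilities.insert("BFloat16ArithmeticINTEL");
}
```

The early return keeps non-bfloat16 instructions free of any extra requirement, which matches the diff's behaviour of only touching `Reqs` inside the `isBFloat16Type` branch.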

Contributor

@jmmartinez jmmartinez left a comment

Minor comment, everything else seems fine.

 bool IRTranslator::translateInlineAsm(const CallBase &CB,
                                       MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(CB))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(CB))
Contributor

NIT: Can we add a helper canHandleBF16ForTarget() that does the MF->getTarget().getTargetTriple().isSPIRV() check?

Adding a hook in TargetMachine would be ideal, but it'd be overkill.
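A minimal sketch of what such a helper might look like, using a stand-in for `llvm::Triple` so the snippet compiles on its own. `TripleStub`, its `ArchName` field and `shouldBailOutOnBF16` are hypothetical, and this is not necessarily how the patch ended up implementing `canHandleBF16ForTarget()`.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for llvm::Triple; only the architecture check is modelled.
struct TripleStub {
  std::string ArchName;
  bool isSPIRV() const {
    return ArchName == "spirv" || ArchName == "spirv32" ||
           ArchName == "spirv64";
  }
};

// Hypothetical helper in the spirit of the suggested canHandleBF16ForTarget():
// true when the target is expected to lower bfloat16 operations itself.
static bool canHandleBF16ForTarget(const TripleStub &TT) {
  return TT.isSPIRV();
}

// The condition each translate* entry point would share instead of repeating
// the triple check inline: bail out only for bfloat16 on other targets.
static bool shouldBailOutOnBF16(const TripleStub &TT, bool ContainsBF16) {
  return ContainsBF16 && !canHandleBF16ForTarget(TT);
}
```

Centralising the predicate means that if more targets gain bfloat16 support, only `canHandleBF16ForTarget` needs to change.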

Contributor Author

Done.

@MrSidims
Contributor

MrSidims commented Nov 3, 2025

@YixingZhang007 please take a look

Contributor Author

@AlexVlx AlexVlx left a comment

> NIT: Can we add a helper canHandleBF16ForTarget() that does the MF->getTarget().getTargetTriple().isSPIRV() check?
>
> Adding a hook in TargetMachine would be ideal, but it'd be overkill.

We can. That said, I'm not entirely certain that the check isn't vestigial at this point, or that it should happen in IRTranslator at all rather than being left to instruction selection or the target (the LTT issue mentioned in a TODO elsewhere might have been solved). Perhaps someone more informed can advise?

entry:
  %r = frem <4 x bfloat> %x, %y
  ret <4 x bfloat> %r
}
Contributor

@YixingZhang007 YixingZhang007 Nov 3, 2025

I noticed that a test for OpFMod is missing, and I'm not sure whether we need to add one as well.

Contributor Author

That is intentional: as far as I know, LLVM doesn't have an equivalent of OpFMod (whose result takes the sign of the divisor); it only has frem (whose result takes the sign of the dividend). If desired, I can add a test that uses one of the fancy __spirv builtins.
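The sign-convention difference can be illustrated in plain C++, where `std::fmod` behaves like LLVM's `frem` (the result takes the sign of the dividend). `opFModLike` below is a hypothetical sketch of OpFMod-style semantics (the result takes the sign of the divisor) built on top of `std::fmod`; it deliberately ignores infinity/NaN edge cases and the extra rounding that `R + Y` can introduce.

```cpp
#include <cassert>
#include <cmath>

// frem-like: the result has the sign of the dividend X.
float fremLike(float X, float Y) { return std::fmod(X, Y); }

// OpFMod-like (sketch): the result has the sign of the divisor Y.
float opFModLike(float X, float Y) {
  float R = std::fmod(X, Y);
  // A nonzero remainder whose sign disagrees with Y is shifted by one Y.
  if (R != 0.0f && std::signbit(R) != std::signbit(Y))
    R += Y;
  return R;
}
```

For example, `fremLike(-5.5f, 2.0f)` gives -1.5, while `opFModLike(-5.5f, 2.0f)` gives 0.5.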

Contributor

Thanks for the explanation! There don't appear to be any tests for OpFMod anywhere in llvm-project, so we should probably add test coverage for it. However, I'm not sure whether that belongs in this patch or in a separate one.

Contributor Author

Intuitively, I'd be in favour of doing it separately: testing it thoroughly would mean casting a wider net (not just bfloat16), and it feels like a separate piece of work. That's just my 2p on the matter, though.

Contributor

Yes, I agree that it's better to do it separately, since it would cover more than just bfloat16 :)

Contributor

@YixingZhang007 YixingZhang007 left a comment

Overall LGTM! I left a small comment regarding the testing. Thank you!

@AlexVlx
Contributor Author

AlexVlx commented Nov 4, 2025

> Overall LGTM! I left a small comment regarding the testing. Thank you!

Cheers! Just to clarify: are you OK with this going in as is, or would you like me to add some OpFMod testing? I've commented on that thread, but I'm not strongly opposed to adding it if you really want it.

@YixingZhang007
Contributor

> > Overall LGTM! I left a small comment regarding the testing. Thank you!
>
> Cheers! Just to clarify: are you OK with this going in as is, or would you like me to add some OpFMod testing? I've commented on that thread, but I'm not strongly opposed to adding it if you really want it.

I think everything looks good, and OpFMod testing can be added in a future patch. Thank you! 👍

@AlexVlx AlexVlx merged commit 2286118 into llvm:main Nov 4, 2025
12 checks passed
@AlexVlx AlexVlx deleted the spirv_be_staging_3 branch November 4, 2025 16:10
AlexVlx added a commit that referenced this pull request Nov 9, 2025
…tomics` extension (#166257)

This enables support for atomic RMW ops (add, sub, min and max to be
precise) with `bfloat16` operands, via the [SPV_INTEL_16bit_atomics
extension](intel/llvm#20009). It's logically a
successor to #166031 (I should've used a stack), but I'm putting it up
for early review.

---------

Co-authored-by: Matt Arsenault