Skip to content

Conversation

YixingZhang007
Copy link
Contributor

@YixingZhang007 YixingZhang007 commented Aug 27, 2025

This PR introduces the support for the SPIR-V extension SPV_KHR_bfloat16. This extension extends the OpTypeFloat instruction to enable the use of bfloat16 types with cooperative matrices and dot products.

TODO:
Per the SPV_KHR_bfloat16 extension, there are a limited number of instructions that can use the bfloat16 type. For example, arithmetic instructions like FAdd or FMul can't operate on bfloat16 values. Therefore, a future patch should be added to either emit an error or fall back to FP32 for arithmetic in cases where bfloat16 must not be used.

Reference Specification:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_bfloat16.asciidoc

@llvmbot
Copy link
Member

llvmbot commented Aug 27, 2025

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-spir-v

Author: None (YixingZhang007)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/155645.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/MachineInstr.h (+2-1)
  • (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+7-2)
  • (modified) llvm/lib/CodeGen/MachineInstr.cpp (+3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+3-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+4)
diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
     NoUSWrap = 1 << 20,      // Instruction supports geps
                              // no unsigned signed wrap.
     SameSign = 1 << 21,      // Both operands have the same sign.
-    InBounds = 1 << 22       // Pointer arithmetic remains inbounds.
+    InBounds = 1 << 22,       // Pointer arithmetic remains inbounds.
                              // Implies NoUSWrap.
+    BFloat16 = 1 << 23      // Instruction with bf16 type
   };
 
 private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
-    return false;
+  // if (containsBF16Type(U))
+  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   if (isa<FPMathOperator>(CI))
     MIB->copyIRFlags(CI);
 
+  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+  if (containsBF16Type(U)) {
+    MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+  }
+
   for (const auto &Arg : enumerate(CI.args())) {
     // If this is required to be an immediate, don't materialize it in a
     // register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
+  if (I.getType()->isBFloatTy())
+    MIFlags |= MachineInstr::MIFlag::BFloat16;
+
   return MIFlags;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_KHR_float_controls2",
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
-         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+        {"SPV_KHR_bfloat16",
+         SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16)
       Reqs.addCapability(SPIRV::Capability::Float16);
+    if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+      Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+    }
     break;
   }
   case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
 defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time

Copy link

github-actions bot commented Aug 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@YixingZhang007 YixingZhang007 changed the title [SPIRV] Add bfloat support [SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 Sep 3, 2025
@YixingZhang007 YixingZhang007 force-pushed the add_spv_khr_bfloat16_extension_support branch 3 times, most recently from 6de7136 to 95aa9a3 Compare September 4, 2025 12:09
@arsenm arsenm requested a review from tgymnich September 4, 2025 12:29
@YixingZhang007
Copy link
Contributor Author

YixingZhang007 commented Sep 4, 2025

The FPVariant support in this PR will be moved to PR #156871 and committed to llvm-project separately.

@YixingZhang007 YixingZhang007 force-pushed the add_spv_khr_bfloat16_extension_support branch from 32498d0 to e0b7026 Compare September 5, 2025 17:30
@YixingZhang007 YixingZhang007 marked this pull request as ready for review September 8, 2025 14:37
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bfloat type should be encoded like OpTypeFloat 16 0, where 0 stands for FP encoding.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing out the issue 🙂 I’ve updated the PR so that the Bfloat type is encoded using the SPIR-V instruction OpTypeFloat 16 0, which is now distinct from the IEEE-754 float encoded as OpTypeFloat 16.

Copy link
Contributor

@MrSidims MrSidims left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually can keep the enum and pass it as an argument. Reason: apart of IEEE-754 and bfloat types we will be introducing 8-bit floating point types (soon) and potentially other types later. So having the enum instead of the boolean flag would help here.

Copy link
Contributor Author

@YixingZhang007 YixingZhang007 Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for the suggestion! I’ve created an enum class FPEncoding to store the values of the supported FP encodings and it is then passed as an argument to function getOpTypeFloat.
https://github.com/llvm/llvm-project/blob/8e9389ddcccbbce15e484cbdf0a89f27a3c07256/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td#L2009-L2028
The function getOpTypeFloat now has two interfaces: one that takes FPEncoding as an argument, used when the float is not an IEEE-754 type, and another without FPEncoding, used when the float is an IEEE-754 type.
https://github.com/llvm/llvm-project/blob/8e9389ddcccbbce15e484cbdf0a89f27a3c07256/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp#L197-L198
https://github.com/llvm/llvm-project/blob/8e9389ddcccbbce15e484cbdf0a89f27a3c07256/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp#L207-L209

Copy link
Contributor

@MrSidims MrSidims left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a bit of TODO (unsure if it should go in this patch or in the later patches): per SPV_KHR_bfloat extension there are limited number of instructions that can use the type. For example arithmetic instructions like FAdd or FMul can't use bfloat values, hence SPIR-V backend should either emit an error or fall back to FP32 for arithmetic (probably just be calling OpFConvert to FP32 and using the result as the new value).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan,EnvOpenCL]>;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for the suggestion! I have made the change :)

@YixingZhang007
Copy link
Contributor Author

YixingZhang007 commented Sep 9, 2025

a bit of TODO (unsure if it should go in this patch or in the later patches): per SPV_KHR_bfloat extension there are limited number of instructions that can use the type. For example arithmetic instructions like FAdd or FMul can't use bfloat values, hence SPIR-V backend should either emit an error or fall back to FP32 for arithmetic (probably just be calling OpFConvert to FP32 and using the result as the new value).

For sure! I’m okay with adding this either in this patch or in a later one. Since this PR already has quite a lot of changes (even after I move the part for supporting bfloat in the SPIR-V backend to a separate PR), I’d slightly prefer handling additional rules in a follow-up patch if there are many more to cover.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also check the bitwidth?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! I have added the check for bitwidth here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use isBFloat16Type(MI) instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sure, I have made the change. Thank you :)

@YixingZhang007 YixingZhang007 force-pushed the add_spv_khr_bfloat16_extension_support branch from afd09a6 to 389096a Compare September 15, 2025 12:12
@YixingZhang007 YixingZhang007 force-pushed the add_spv_khr_bfloat16_extension_support branch from 389096a to 8ad95c0 Compare September 15, 2025 13:14
@MrSidims MrSidims merged commit f91e0bf into llvm:main Sep 22, 2025
9 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants