From d55de5c309c0c7fc5b34a85046f9c764bd234ed3 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 18 Aug 2025 05:06:14 -0700 Subject: [PATCH 1/7] add the field isBfloat in LLT --- llvm/lib/CodeGen/LowLevelTypeUtils.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp index 936c9fbb2fff0..78f68421f49b7 100644 --- a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp +++ b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp @@ -27,7 +27,7 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { } if (auto PTy = dyn_cast(&Ty)) { - unsigned AddrSpace = PTy->getAddressSpace(); + unsigned AddrSpace = PTy->isTokenTy(); return LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace)); } @@ -36,7 +36,7 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { // concerned. auto SizeInBits = DL.getTypeSizeInBits(&Ty); assert(SizeInBits != 0 && "invalid zero-sized type"); - return LLT::scalar(SizeInBits); + return LLT::scalar(SizeInBits, EC.isBfloat()); } if (Ty.isTokenTy()) @@ -73,6 +73,9 @@ LLT llvm::getLLTForMVT(MVT Ty) { const llvm::fltSemantics &llvm::getFltSemanticForLLT(LLT Ty) { assert(Ty.isScalar() && "Expected a scalar type."); + if(Ty.isBFloat()) { + return APFloat::BFloat(); + } switch (Ty.getSizeInBits()) { case 16: return APFloat::IEEEhalf(); From 1c4c952d6601d5a06a70014092aaa6906110b74f Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 18 Aug 2025 05:08:32 -0700 Subject: [PATCH 2/7] add support for isbfloat in llt --- llvm/include/llvm/CodeGenTypes/LowLevelType.h | 76 +++++++++++-------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h index d8e0848aff84d..a484a85e50793 100644 --- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h +++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h @@ -30,6 +30,7 @@ #include "llvm/CodeGenTypes/MachineValueType.h" #include 
"llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include namespace llvm { @@ -40,8 +41,8 @@ class raw_ostream; class LLT { public: /// Get a low-level scalar or aggregate "bag of bits". - static constexpr LLT scalar(unsigned SizeInBits) { - return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, + static constexpr LLT scalar(unsigned SizeInBits, bool isBfloat) { + return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/isBfloat, ElementCount::getFixed(0), SizeInBits, /*AddressSpace=*/0}; } @@ -49,7 +50,7 @@ class LLT { /// Get a low-level token; just a scalar with zero bits (or no size). static constexpr LLT token() { return LLT{/*isPointer=*/false, /*isVector=*/false, - /*isScalar=*/true, ElementCount::getFixed(0), + /*isScalar=*/true, /*isBfloat=*/false, ElementCount::getFixed(0), /*SizeInBits=*/0, /*AddressSpace=*/0}; } @@ -57,14 +58,14 @@ class LLT { /// Get a low-level pointer in the given address space. static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) { assert(SizeInBits > 0 && "invalid pointer size"); - return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false, + return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false, /*isBfloat=*/false, ElementCount::getFixed(0), SizeInBits, AddressSpace}; } /// Get a low-level vector of some number of elements and element width. static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits) { assert(!EC.isScalar() && "invalid number of vector elements"); - return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false, + return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false, /*isBfloat=*/false, EC, ScalarSizeInBits, /*AddressSpace=*/0}; } @@ -80,10 +81,15 @@ class LLT { ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0}; } + // Get a 16-bit brain float value. 
+ static constexpr LLT bfloat16() { + return scalar(16, true); + } + /// Get a 16-bit IEEE half value. /// TODO: Add IEEE semantics to type - This currently returns a simple `scalar(16)`. static constexpr LLT float16() { - return scalar(16); + return scalar(16, false); } /// Get a 32-bit IEEE float value. @@ -136,10 +142,10 @@ class LLT { ElementCount EC, uint64_t SizeInBits, unsigned AddressSpace) : LLT() { - init(isPointer, isVector, isScalar, EC, SizeInBits, AddressSpace); + init(isPointer, isVector, isScalar, isBfloat, E C, SizeInBits, AddressSpace); } explicit constexpr LLT() - : IsScalar(false), IsPointer(false), IsVector(false), RawData(0) {} + : IsScalar(false), IsPointer(false), IsVector(false), isBfloat(false), RawData(0) {} LLVM_ABI explicit LLT(MVT VT); @@ -154,6 +160,7 @@ class LLT { constexpr bool isPointerOrPointerVector() const { return IsPointer && isValid(); } + constexpr bool isBfloat() const { return isBfloat; } /// Returns the number of elements in a vector LLT. Must only be called on /// vector types. @@ -304,6 +311,7 @@ class LLT { /// isScalar : 1 /// isPointer : 1 /// isVector : 1 + /// isVector : 1 /// with 61 bits remaining for Kind-specific data, packed in bitfields /// as described below. As there isn't a simple portable way to pack bits /// into bitfields, here the different fields in the packed structure is @@ -311,25 +319,27 @@ class LLT { /// is a 2-element array, with the first element describing the bitfield size /// and the second element describing the bitfield offset. 
/// - /// +--------+---------+--------+----------+----------------------+ - /// |isScalar|isPointer|isVector| RawData |Notes | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 0 | 0 |Invalid | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 1 | 0 |Tombstone Key | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 0 | 0 |Empty Key | - /// +--------+---------+--------+----------+----------------------+ - /// | 1 | 0 | 0 | 0 |Token | - /// +--------+---------+--------+----------+----------------------+ - /// | 1 | 0 | 0 | non-zero |Scalar | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 0 | non-zero |Pointer | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 1 | non-zero |Vector of non-pointer | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 1 | non-zero |Vector of pointer | - /// +--------+---------+--------+----------+----------------------+ + /// +--------+---------+--------+----------+----------+----------------------+ + /// |isScalar|isPointer|isVector| isBfloat | RawData |Notes | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 0 | 0 | 0 | 0 |Invalid | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 0 | 1 | 0 | 0 |Tombstone Key | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 1 | 0 | 0 | 0 |Empty Key | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 1 | 0 | 0 | 0 | 0 |Token | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 1 | 0 | 0 | 0 | non-zero |Scalar | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 1 | 0 | 0 | 1 | non-zero |Scalar (Bfloat 16) | + /// 
+--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 1 | 0 | 0 | non-zero |Pointer | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 0 | 1 | 0 | non-zero |Vector of non-pointer | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 1 | 1 | 0 | non-zero |Vector of pointer | + /// +--------+---------+--------+----------+----------+----------------------+ /// /// Everything else is reserved. typedef int BitFieldInfo[2]; @@ -340,17 +350,17 @@ class LLT { /// valid encodings, SizeInBits/SizeOfElement must be larger than 0. /// * Non-pointer scalar (isPointer == 0 && isVector == 0): /// SizeInBits: 32; - static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 29}; + static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 28}; /// * Pointer (isPointer == 1 && isVector == 0): /// SizeInBits: 16; /// AddressSpace: 24; - static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 45}; + static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 44}; static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 21}; /// * Vector-of-non-pointer (isPointer == 0 && isVector == 1): /// NumElements: 16; /// SizeOfElement: 32; /// Scalable: 1; - static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 5}; + static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 4}; static const constexpr BitFieldInfo VectorScalableFieldInfo{1, 0}; /// * Vector-of-pointer (isPointer == 1 && isVector == 1): /// NumElements: 16; @@ -361,7 +371,8 @@ class LLT { uint64_t IsScalar : 1; uint64_t IsPointer : 1; uint64_t IsVector : 1; - uint64_t RawData : 61; + uint64_t IsBfloat : 1; + uint64_t RawData : 60; static constexpr uint64_t getMask(const BitFieldInfo FieldInfo) { const int FieldSizeInBits = FieldInfo[0]; @@ -389,6 +400,7 @@ class LLT { this->IsPointer = IsPointer; this->IsVector = IsVector; this->IsScalar = IsScalar; + this->IsScalar = IsBfloat; if 
(IsPointer) { RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) | maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo); @@ -403,7 +415,7 @@ class LLT { public: constexpr uint64_t getUniqueRAWLLTData() const { - return ((uint64_t)RawData) << 3 | ((uint64_t)IsScalar) << 2 | + return ((uint64_t)RawData) << 4 | ((uint64_t)IsBfloat) << 3 | ((uint64_t)IsScalar) << 2 | ((uint64_t)IsPointer) << 1 | ((uint64_t)IsVector); } }; From c04c8a33c861afddedbbd702686c79d1572272ad Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 18 Aug 2025 09:19:58 -0700 Subject: [PATCH 3/7] update the llt field define --- llvm/include/llvm/CodeGenTypes/LowLevelType.h | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h index a484a85e50793..2addc9383ba3d 100644 --- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h +++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h @@ -138,7 +138,7 @@ class LLT { return scalarOrVector(EC, LLT::scalar(static_cast(ScalarSize))); } - explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar, + explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar, bool isBfloat, ElementCount EC, uint64_t SizeInBits, unsigned AddressSpace) : LLT() { @@ -250,7 +250,10 @@ class LLT { } assert(getScalarSizeInBits() % Factor == 0); - return scalar(getScalarSizeInBits() / Factor); + if(isBfloat()) { + return scalar(getScalarSizeInBits() / Factor, true); + } + return scalar(getScalarSizeInBits() / Factor, false); } /// Produce a vector type that is \p Factor times bigger, preserving the @@ -311,8 +314,8 @@ class LLT { /// isScalar : 1 /// isPointer : 1 /// isVector : 1 - /// isVector : 1 - /// with 61 bits remaining for Kind-specific data, packed in bitfields + /// isBfloat : 1 + /// with 60 bits remaining for Kind-specific data, packed in bitfields /// as described below. 
As there isn't a simple portable way to pack bits /// into bitfields, here the different fields in the packed structure is /// described in static const *Field variables. Each of these variables @@ -355,12 +358,12 @@ class LLT { /// SizeInBits: 16; /// AddressSpace: 24; static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 44}; - static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 21}; + static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 20}; /// * Vector-of-non-pointer (isPointer == 0 && isVector == 1): /// NumElements: 16; /// SizeOfElement: 32; /// Scalable: 1; - static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 4}; + static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 5}; static const constexpr BitFieldInfo VectorScalableFieldInfo{1, 0}; /// * Vector-of-pointer (isPointer == 1 && isVector == 1): /// NumElements: 16; @@ -392,7 +395,7 @@ class LLT { return getMask(FieldInfo) & (RawData >> FieldInfo[1]); } - constexpr void init(bool IsPointer, bool IsVector, bool IsScalar, + constexpr void init(bool IsPointer, bool IsVector, bool IsScalar, bool IsBfloat, ElementCount EC, uint64_t SizeInBits, unsigned AddressSpace) { assert(SizeInBits <= std::numeric_limits::max() && @@ -400,7 +403,7 @@ class LLT { this->IsPointer = IsPointer; this->IsVector = IsVector; this->IsScalar = IsScalar; - this->IsScalar = IsBfloat; + this->IsBfloat = IsBfloat; if (IsPointer) { RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) | maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo); From 39c2d0c970261283de8a802f0b0a55ca31301a05 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 18 Aug 2025 12:33:45 -0700 Subject: [PATCH 4/7] solve the build issue --- llvm/include/llvm/CodeGenTypes/LowLevelType.h | 26 +++++++++++-------- llvm/lib/CodeGen/LowLevelTypeUtils.cpp | 4 +-- llvm/lib/CodeGenTypes/LowLevelType.cpp | 7 +++-- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git 
a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h index 2addc9383ba3d..0aec3d2537d9c 100644 --- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h +++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h @@ -41,8 +41,14 @@ class raw_ostream; class LLT { public: /// Get a low-level scalar or aggregate "bag of bits". - static constexpr LLT scalar(unsigned SizeInBits, bool isBfloat) { - return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/isBfloat, + static constexpr LLT scalar(unsigned SizeInBits) { + return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/false, + ElementCount::getFixed(0), SizeInBits, + /*AddressSpace=*/0}; + } + + static constexpr LLT scalar_bfloat(unsigned SizeInBits) { + return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/true, ElementCount::getFixed(0), SizeInBits, /*AddressSpace=*/0}; } @@ -76,6 +82,7 @@ class LLT { return LLT{ScalarTy.isPointer(), /*isVector=*/true, /*isScalar=*/false, + /*isBfloat=*/false, EC, ScalarTy.getSizeInBits().getFixedValue(), ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0}; @@ -83,13 +90,13 @@ class LLT { // Get a 16-bit brain float value. static constexpr LLT bfloat16() { - return scalar(16, true); + return scalar_bfloat(16); } /// Get a 16-bit IEEE half value. /// TODO: Add IEEE semantics to type - This currently returns a simple `scalar(16)`. static constexpr LLT float16() { - return scalar(16, false); + return scalar(16); } /// Get a 32-bit IEEE float value. 
@@ -142,10 +149,10 @@ class LLT { ElementCount EC, uint64_t SizeInBits, unsigned AddressSpace) : LLT() { - init(isPointer, isVector, isScalar, isBfloat, E C, SizeInBits, AddressSpace); + init(isPointer, isVector, isScalar, isBfloat, EC, SizeInBits, AddressSpace); } explicit constexpr LLT() - : IsScalar(false), IsPointer(false), IsVector(false), isBfloat(false), RawData(0) {} + : IsScalar(false), IsPointer(false), IsVector(false), IsBfloat(false), RawData(0) {} LLVM_ABI explicit LLT(MVT VT); @@ -160,7 +167,7 @@ class LLT { constexpr bool isPointerOrPointerVector() const { return IsPointer && isValid(); } - constexpr bool isBfloat() const { return isBfloat; } + constexpr bool isBfloat() const { return IsBfloat; } /// Returns the number of elements in a vector LLT. Must only be called on /// vector types. @@ -250,10 +257,7 @@ class LLT { } assert(getScalarSizeInBits() % Factor == 0); - if(isBfloat()) { - return scalar(getScalarSizeInBits() / Factor, true); - } - return scalar(getScalarSizeInBits() / Factor, false); + return scalar(getScalarSizeInBits() / Factor); } /// Produce a vector type that is \p Factor times bigger, preserving the diff --git a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp index 78f68421f49b7..c59b092496c50 100644 --- a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp +++ b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp @@ -36,7 +36,7 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { // concerned. 
auto SizeInBits = DL.getTypeSizeInBits(&Ty); assert(SizeInBits != 0 && "invalid zero-sized type"); - return LLT::scalar(SizeInBits, EC.isBfloat()); + return LLT::scalar(SizeInBits); } if (Ty.isTokenTy()) @@ -73,7 +73,7 @@ LLT llvm::getLLTForMVT(MVT Ty) { const llvm::fltSemantics &llvm::getFltSemanticForLLT(LLT Ty) { assert(Ty.isScalar() && "Expected a scalar type."); - if(Ty.isBFloat()) { + if(Ty.isBfloat()) { return APFloat::BFloat(); } switch (Ty.getSizeInBits()) { diff --git a/llvm/lib/CodeGenTypes/LowLevelType.cpp b/llvm/lib/CodeGenTypes/LowLevelType.cpp index 4785f2652b00e..a13363a337472 100644 --- a/llvm/lib/CodeGenTypes/LowLevelType.cpp +++ b/llvm/lib/CodeGenTypes/LowLevelType.cpp @@ -19,18 +19,21 @@ using namespace llvm; LLT::LLT(MVT VT) { if (VT.isVector()) { bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector(); - init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector, + init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector, /*isBfloat=*/false, VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(), /*AddressSpace=*/0); } else if (VT.isValid() && !VT.isScalableTargetExtVT()) { // Aggregates are no different from real scalars as far as GlobalISel is // concerned. 
- init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, + MVT ElemVT = VT.getVectorElementType(); + bool isElemBfloat = (ElemVT == MVT::bf16); + init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, /*isBfloat=*/isElemBfloat, ElementCount::getFixed(0), VT.getSizeInBits(), /*AddressSpace=*/0); } else { IsScalar = false; IsPointer = false; IsVector = false; + IsBfloat = false; RawData = 0; } } From e7eeecb2eeb32d5c276098c15a4a37994369a8c3 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 18 Aug 2025 13:19:31 -0700 Subject: [PATCH 5/7] time out when running test --- llvm/lib/CodeGen/LowLevelTypeUtils.cpp | 6 +++--- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp index c59b092496c50..9fe3d3e43a47f 100644 --- a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp +++ b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp @@ -36,6 +36,9 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { // concerned. 
auto SizeInBits = DL.getTypeSizeInBits(&Ty); assert(SizeInBits != 0 && "invalid zero-sized type"); + if (Ty.isBFloatTy()) { + return LLT::scalar_bfloat(SizeInBits); + } return LLT::scalar(SizeInBits); } @@ -73,9 +76,6 @@ LLT llvm::getLLTForMVT(MVT Ty) { const llvm::fltSemantics &llvm::getFltSemanticForLLT(LLT Ty) { assert(Ty.isScalar() && "Expected a scalar type."); - if(Ty.isBfloat()) { - return APFloat::BFloat(); - } switch (Ty.getSizeInBits()) { case 16: return APFloat::IEEEhalf(); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 8039cf0c432fa..51fc45446d50e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1262,6 +1262,17 @@ void addInstrRequirements(const MachineInstr &MI, break; } case SPIRV::OpTypeFloat: { + // const MachineBasicBlock *MBB = MI.getParent(); + // const MachineFunction *MF = MBB->getParent(); + // const MachineRegisterInfo &MRI = MF->getRegInfo(); + // const MachineOperand &MO = MI.getOperand(1); + // if (MO.isReg()) { + // LLT Ty = MRI.getType(MO.getReg()); + // if(!Ty.isScalar()) { + // assert(1 && "hola, ur wrong"); + // } + // } + unsigned BitWidth = MI.getOperand(1).getImm(); if (BitWidth == 64) Reqs.addCapability(SPIRV::Capability::Float64); From 23a05707090d7671e45b0aa5f4752e5af7ea05a3 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 18 Aug 2025 13:57:32 -0700 Subject: [PATCH 6/7] add the test file --- .../extensions/SPV_KHR_bfloat16/bfloat16.ll | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll new file mode 100644 index 0000000000000..e88d966e14b46 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll @@ -0,0 +1,16 @@ +; RUN: llc 
-verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %} +; XFAIL: * +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate + +; CHECK-DAG: OpCapability BFloat16TypeKHR +; CHECK-DAG: OpExtension "SPV_KHR_bfloat16" +; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0 +; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2 + +define spir_kernel void @test() { +entry: + ret void +} \ No newline at end of file From 180d470f2d483860a3f3108651c2e4319a8ccdbb Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 18 Aug 2025 16:31:59 -0700 Subject: [PATCH 7/7] the test is now not failing, but the isfloat check does not work --- llvm/lib/CodeGen/LowLevelTypeUtils.cpp | 2 +- llvm/lib/CodeGenTypes/LowLevelType.cpp | 2 +- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 20 +++++++++---------- .../extensions/SPV_KHR_bfloat16/bfloat16.ll | 3 ++- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp index 9fe3d3e43a47f..03a97eacb049e 100644 --- a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp +++ b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp @@ -27,7 +27,7 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { } if (auto PTy = dyn_cast(&Ty)) { - unsigned AddrSpace = PTy->isTokenTy(); + unsigned AddrSpace = PTy->getAddressSpace(); return LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace)); } diff --git a/llvm/lib/CodeGenTypes/LowLevelType.cpp b/llvm/lib/CodeGenTypes/LowLevelType.cpp index a13363a337472..8828135fcbb27 100644 --- a/llvm/lib/CodeGenTypes/LowLevelType.cpp +++ b/llvm/lib/CodeGenTypes/LowLevelType.cpp @@ -27,7 +27,7 @@ LLT::LLT(MVT VT) { // 
concerned. MVT ElemVT = VT.getVectorElementType(); bool isElemBfloat = (ElemVT == MVT::bf16); - init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, /*isBfloat=*/isElemBfloat, + init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, /*isBfloat=*/false, ElementCount::getFixed(0), VT.getSizeInBits(), /*AddressSpace=*/0); } else { IsScalar = false; diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 51fc45446d50e..dc00d97e2d8d2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1262,16 +1262,16 @@ void addInstrRequirements(const MachineInstr &MI, break; } case SPIRV::OpTypeFloat: { - // const MachineBasicBlock *MBB = MI.getParent(); - // const MachineFunction *MF = MBB->getParent(); - // const MachineRegisterInfo &MRI = MF->getRegInfo(); - // const MachineOperand &MO = MI.getOperand(1); - // if (MO.isReg()) { - // LLT Ty = MRI.getType(MO.getReg()); - // if(!Ty.isScalar()) { - // assert(1 && "hola, ur wrong"); - // } - // } + const MachineBasicBlock *MBB = MI.getParent(); + const MachineFunction *MF = MBB->getParent(); + const MachineRegisterInfo &MRI = MF->getRegInfo(); + const MachineOperand &MO = MI.getOperand(1); + if (MO.isReg()) { + LLT Ty = MRI.getType(MO.getReg()); + if(Ty.isBfloat()) { + assert(1 && "hola, ur wrong"); + } + } unsigned BitWidth = MI.getOperand(1).getImm(); if (BitWidth == 64) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll index e88d966e14b46..336b2b013bc60 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll @@ -12,5 +12,6 @@ define spir_kernel void @test() { entry: + %addr1 = alloca bfloat ret void -} \ No newline at end of file +}