Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 52 additions & 33 deletions llvm/include/llvm/CodeGenTypes/LowLevelType.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <cassert>

namespace llvm {
Expand All @@ -41,30 +42,36 @@ class LLT {
public:
/// Get a low-level scalar or aggregate "bag of bits".
static constexpr LLT scalar(unsigned SizeInBits) {
return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true,
return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/false,
ElementCount::getFixed(0), SizeInBits,
/*AddressSpace=*/0};
}

static constexpr LLT scalar_bfloat(unsigned SizeInBits) {
return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/true,
ElementCount::getFixed(0), SizeInBits,
/*AddressSpace=*/0};
}

/// Get a low-level token; just a scalar with zero bits (or no size).
static constexpr LLT token() {
return LLT{/*isPointer=*/false, /*isVector=*/false,
/*isScalar=*/true, ElementCount::getFixed(0),
/*isScalar=*/true, /*isBfloat=*/false, ElementCount::getFixed(0),
/*SizeInBits=*/0,
/*AddressSpace=*/0};
}

/// Get a low-level pointer in the given address space.
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
assert(SizeInBits > 0 && "invalid pointer size");
return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false,
return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false, /*isBfloat=*/false,
ElementCount::getFixed(0), SizeInBits, AddressSpace};
}

/// Get a low-level vector of some number of elements and element width.
static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits) {
assert(!EC.isScalar() && "invalid number of vector elements");
return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false,
return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false, /*isBfloat=*/false,
EC, ScalarSizeInBits, /*AddressSpace=*/0};
}

Expand All @@ -75,11 +82,17 @@ class LLT {
return LLT{ScalarTy.isPointer(),
/*isVector=*/true,
/*isScalar=*/false,
/*isBfloat=*/false,
EC,
ScalarTy.getSizeInBits().getFixedValue(),
ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0};
}

// Get a 16-bit brain float value.
static constexpr LLT bfloat16() {
return scalar_bfloat(16);
}

/// Get a 16-bit IEEE half value.
/// TODO: Add IEEE semantics to type - This currently returns a simple `scalar(16)`.
static constexpr LLT float16() {
Expand Down Expand Up @@ -132,14 +145,14 @@ class LLT {
return scalarOrVector(EC, LLT::scalar(static_cast<unsigned>(ScalarSize)));
}

explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar,
explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar, bool isBfloat,
ElementCount EC, uint64_t SizeInBits,
unsigned AddressSpace)
: LLT() {
init(isPointer, isVector, isScalar, EC, SizeInBits, AddressSpace);
init(isPointer, isVector, isScalar, isBfloat, EC, SizeInBits, AddressSpace);
}
explicit constexpr LLT()
: IsScalar(false), IsPointer(false), IsVector(false), RawData(0) {}
: IsScalar(false), IsPointer(false), IsVector(false), IsBfloat(false), RawData(0) {}

LLVM_ABI explicit LLT(MVT VT);

Expand All @@ -154,6 +167,7 @@ class LLT {
constexpr bool isPointerOrPointerVector() const {
return IsPointer && isValid();
}
constexpr bool isBfloat() const { return IsBfloat; }

/// Returns the number of elements in a vector LLT. Must only be called on
/// vector types.
Expand Down Expand Up @@ -304,32 +318,35 @@ class LLT {
/// isScalar : 1
/// isPointer : 1
/// isVector : 1
/// with 61 bits remaining for Kind-specific data, packed in bitfields
/// isBfloat : 1
/// with 60 bits remaining for Kind-specific data, packed in bitfields
/// as described below. As there isn't a simple portable way to pack bits
/// into bitfields, here the different fields in the packed structure is
/// described in static const *Field variables. Each of these variables
/// is a 2-element array, with the first element describing the bitfield size
/// and the second element describing the bitfield offset.
///
/// +--------+---------+--------+----------+----------------------+
/// |isScalar|isPointer|isVector| RawData |Notes |
/// +--------+---------+--------+----------+----------------------+
/// | 0 | 0 | 0 | 0 |Invalid |
/// +--------+---------+--------+----------+----------------------+
/// | 0 | 0 | 1 | 0 |Tombstone Key |
/// +--------+---------+--------+----------+----------------------+
/// | 0 | 1 | 0 | 0 |Empty Key |
/// +--------+---------+--------+----------+----------------------+
/// | 1 | 0 | 0 | 0 |Token |
/// +--------+---------+--------+----------+----------------------+
/// | 1 | 0 | 0 | non-zero |Scalar |
/// +--------+---------+--------+----------+----------------------+
/// | 0 | 1 | 0 | non-zero |Pointer |
/// +--------+---------+--------+----------+----------------------+
/// | 0 | 0 | 1 | non-zero |Vector of non-pointer |
/// +--------+---------+--------+----------+----------------------+
/// | 0 | 1 | 1 | non-zero |Vector of pointer |
/// +--------+---------+--------+----------+----------------------+
/// +--------+---------+--------+----------+----------+----------------------+
/// |isScalar|isPointer|isVector| isBfloat | RawData |Notes |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 0 | 0 | 0 | 0 | 0 |Invalid |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 0 | 0 | 1 | 0 | 0 |Tombstone Key |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 0 | 1 | 0 | 0 | 0 |Empty Key |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 1 | 0 | 0 | 0 | 0 |Token |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 1 | 0 | 0 | 0 | non-zero |Scalar |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 1 | 0 | 0 | 1 | non-zero |Scalar (Bfloat 16) |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 0 | 1 | 0 | 0 | non-zero |Pointer |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 0 | 0 | 1 | 0 | non-zero |Vector of non-pointer |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 0 | 1 | 1 0 | non-zero |Vector of pointer |
/// +--------+---------+--------+----------+----------+----------------------+
///
/// Everything else is reserved.
typedef int BitFieldInfo[2];
Expand All @@ -340,12 +357,12 @@ class LLT {
/// valid encodings, SizeInBits/SizeOfElement must be larger than 0.
/// * Non-pointer scalar (isPointer == 0 && isVector == 0):
/// SizeInBits: 32;
static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 29};
static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 28};
/// * Pointer (isPointer == 1 && isVector == 0):
/// SizeInBits: 16;
/// AddressSpace: 24;
static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 45};
static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 21};
static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 44};
static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 20};
/// * Vector-of-non-pointer (isPointer == 0 && isVector == 1):
/// NumElements: 16;
/// SizeOfElement: 32;
Expand All @@ -361,7 +378,8 @@ class LLT {
uint64_t IsScalar : 1;
uint64_t IsPointer : 1;
uint64_t IsVector : 1;
uint64_t RawData : 61;
uint64_t IsBfloat : 1;
uint64_t RawData : 60;

static constexpr uint64_t getMask(const BitFieldInfo FieldInfo) {
const int FieldSizeInBits = FieldInfo[0];
Expand All @@ -381,14 +399,15 @@ class LLT {
return getMask(FieldInfo) & (RawData >> FieldInfo[1]);
}

constexpr void init(bool IsPointer, bool IsVector, bool IsScalar,
constexpr void init(bool IsPointer, bool IsVector, bool IsScalar, bool IsBfloat,
ElementCount EC, uint64_t SizeInBits,
unsigned AddressSpace) {
assert(SizeInBits <= std::numeric_limits<unsigned>::max() &&
"Not enough bits in LLT to represent size");
this->IsPointer = IsPointer;
this->IsVector = IsVector;
this->IsScalar = IsScalar;
this->IsBfloat = IsBfloat;
if (IsPointer) {
RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) |
maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo);
Expand All @@ -403,7 +422,7 @@ class LLT {

public:
constexpr uint64_t getUniqueRAWLLTData() const {
return ((uint64_t)RawData) << 3 | ((uint64_t)IsScalar) << 2 |
return ((uint64_t)RawData) << 4 | ((uint64_t)IsBfloat) << 3 | ((uint64_t)IsScalar) << 2 |
((uint64_t)IsPointer) << 1 | ((uint64_t)IsVector);
}
};
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/LowLevelTypeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
// concerned.
auto SizeInBits = DL.getTypeSizeInBits(&Ty);
assert(SizeInBits != 0 && "invalid zero-sized type");
if (Ty.isBFloatTy()) {
return LLT::scalar_bfloat(SizeInBits);
}
return LLT::scalar(SizeInBits);
}

Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/CodeGenTypes/LowLevelType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@ using namespace llvm;
LLT::LLT(MVT VT) {
if (VT.isVector()) {
bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector();
init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector,
init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector, /*isBfloat=*/false,
VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
/*AddressSpace=*/0);
} else if (VT.isValid() && !VT.isScalableTargetExtVT()) {
// Aggregates are no different from real scalars as far as GlobalISel is
// concerned.
init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true,
MVT ElemVT = VT.getVectorElementType();
bool isElemBfloat = (ElemVT == MVT::bf16);
init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, /*isBfloat=*/false,
ElementCount::getFixed(0), VT.getSizeInBits(), /*AddressSpace=*/0);
} else {
IsScalar = false;
IsPointer = false;
IsVector = false;
IsBfloat = false;
RawData = 0;
}
}
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,17 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpTypeFloat: {
const MachineBasicBlock *MBB = MI.getParent();
const MachineFunction *MF = MBB->getParent();
const MachineRegisterInfo &MRI = MF->getRegInfo();
const MachineOperand &MO = MI.getOperand(1);
if (MO.isReg()) {
LLT Ty = MRI.getType(MO.getReg());
if(Ty.isBfloat()) {
assert(1 && "hola, ur wrong");
}
}

unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
Expand Down
17 changes: 17 additions & 0 deletions llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
; XFAIL: *
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate

; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2

define spir_kernel void @test() {
entry:
%addr1 = alloca bfloat
ret void
}
Loading