Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Introduces two new storage classes that are subclasses of the CrossWorkgroup storage class that provides additional information that can enable optimization.
* - ``SPV_INTEL_variable_length_array``
- Allows to allocate local arrays whose number of elements is unknown at compile time.
* - ``SPV_INTEL_joint_matrix``
- Adds few matrix capabilities on top of SPV_KHR_cooperative_matrix extension, such as matrix prefetch, get element coordinate and checked load/store/construct instructions, tensor float 32 and bfloat type interpretations for multuply-add instruction.
* - ``SPV_KHR_bit_instructions``
- Enables bit instructions to be used by SPIR-V modules without requiring the Shader capability.
* - ``SPV_KHR_expect_assume``
Expand Down
8 changes: 6 additions & 2 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,12 @@ getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension) {

CapabilityList Capabilities;
while (Entry &&
Entry->Category == SPIRV::OperandCategory::CapabilityOperand &&
Entry->ReqExtension == Extension) {
Entry->Category == SPIRV::OperandCategory::CapabilityOperand) {
// Some capabilities' codes might go not in order.
if (Entry->ReqExtension != Extension) {
++Entry;
continue;
}
Capabilities.push_back(
static_cast<SPIRV::Capability::Capability>(Entry->Value));
++Entry;
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,16 @@ namespace Opcode {
#include "SPIRVGenTables.inc"
} // namespace Opcode

namespace CooperativeMatrixLayout {
#define GET_CooperativeMatrixLayout_DECL
#include "SPIRVGenTables.inc"
} // namespace CooperativeMatrixLayout

namespace CooperativeMatrixOperands {
#define GET_CooperativeMatrixOperands_DECL
#include "SPIRVGenTables.inc"
} // namespace CooperativeMatrixOperands

struct ExtendedBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
Expand Down
28 changes: 28 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,34 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
// are part of the variable value.
printOpConstantVarOps(MI, NumFixedOps - 1, OS);
break;
case SPIRV::OpCooperativeMatrixMulAddKHR: {
const unsigned NumOps = MI->getNumOperands();
if (NumFixedOps == NumOps)
break;

OS << ' ';
const unsigned MulAddOp = MI->getOperand(FirstVariableIndex).getImm();
if (MulAddOp == 0) {
printSymbolicOperand<
OperandCategory::CooperativeMatrixOperandsOperand>(
MI, FirstVariableIndex, OS);
} else {
std::string Buffer;
for (unsigned Mask = 0x1;
Mask != SPIRV::CooperativeMatrixOperands::
MatrixResultBFloat16ComponentsINTEL;
Mask <<= 1) {
if (MulAddOp & Mask) {
if (!Buffer.empty())
Buffer += '|';
Buffer += getSymbolicOperandMnemonic(
OperandCategory::CooperativeMatrixOperandsOperand, Mask);
}
}
OS << Buffer;
}
break;
}
default:
printRemainingVariableOps(MI, NumFixedOps, OS);
break;
Expand Down
44 changes: 39 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1969,15 +1969,49 @@ static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
unsigned Opcode =
SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR &&
Opcode != SPIRV::OpCooperativeMatrixStoreCheckedINTEL &&
Opcode != SPIRV::OpCooperativeMatrixPrefetchINTEL;
unsigned ArgSz = Call->Arguments.size();
unsigned LiteralIdx = 0;
if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
LiteralIdx = 3;
else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
LiteralIdx = 4;
switch (Opcode) {
// Memory operand is optional and is literal.
case SPIRV::OpCooperativeMatrixLoadKHR:
LiteralIdx = ArgSz > 3 ? 3 : 0;
break;
case SPIRV::OpCooperativeMatrixStoreKHR:
LiteralIdx = ArgSz > 4 ? 4 : 0;
break;
case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
LiteralIdx = ArgSz > 7 ? 7 : 0;
break;
case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
LiteralIdx = ArgSz > 8 ? 8 : 0;
break;
// Cooperative Matrix Operands operand is optional and is literal.
case SPIRV::OpCooperativeMatrixMulAddKHR:
LiteralIdx = ArgSz > 3 ? 3 : 0;
break;
};

SmallVector<uint32_t, 1> ImmArgs;
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
if (Opcode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
const uint32_t CacheLevel = getConstFromIntrinsic(Call->Arguments[3], MRI);
auto MIB = MIRBuilder.buildInstr(SPIRV::OpCooperativeMatrixPrefetchINTEL)
.addUse(Call->Arguments[0]) // pointer
.addUse(Call->Arguments[1]) // rows
.addUse(Call->Arguments[2]) // columns
.addImm(CacheLevel) // cache level
.addUse(Call->Arguments[4]); // memory layout
if (ArgSz > 5)
MIB.addUse(Call->Arguments[5]); // stride
if (ArgSz > 6) {
const uint32_t MemOp = getConstFromIntrinsic(Call->Arguments[6], MRI);
MIB.addImm(MemOp); // memory operand
}
return true;
}
if (LiteralIdx > 0)
ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,13 @@ defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, C
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 4, OpCooperativeMatrixMulAddKHR>;
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>;

// Cooperative Matrix Intel builtin records:
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixPrefetchINTEL", OpenCL_std, CoopMatr, 5, 7, OpCooperativeMatrixPrefetchINTEL>;
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadCheckedINTEL", OpenCL_std, CoopMatr, 6, 8, OpCooperativeMatrixLoadCheckedINTEL>;
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreCheckedINTEL", OpenCL_std, CoopMatr, 7, 9, OpCooperativeMatrixStoreCheckedINTEL>;
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixConstructCheckedINTEL", OpenCL_std, CoopMatr, 5, 5, OpCooperativeMatrixConstructCheckedINTEL>;
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixGetElementCoordINTEL", OpenCL_std, CoopMatr, 2, 2, OpCooperativeMatrixGetElementCoordINTEL>;

//===----------------------------------------------------------------------===//
// Class defining a work/sub group builtin that should be translated into a
// SPIR-V instruction using the defined properties.
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_INTEL_subgroups},
{"SPV_INTEL_media_block_io",
SPIRV::Extension::Extension::SPV_INTEL_media_block_io},
{"SPV_INTEL_joint_matrix",
SPIRV::Extension::Extension::SPV_INTEL_joint_matrix},
{"SPV_KHR_uniform_group_instructions",
SPIRV::Extension::Extension::SPV_KHR_uniform_group_instructions},
{"SPV_KHR_no_integer_wrap_decoration",
Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,23 @@ def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
"$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;

// SPV_INTEL_joint_matrix
def OpCooperativeMatrixLoadCheckedINTEL: Op<6193, (outs ID:$res),
(ins TYPE:$resType, ID:$pointer, ID:$xOffset, ID:$yOffset, ID:$memory_layout, ID:$height, ID:$width, variable_ops),
"$res = OpCooperativeMatrixLoadCheckedINTEL $resType $pointer $xOffset $yOffset $memory_layout $height $width">;
def OpCooperativeMatrixStoreCheckedINTEL: Op<6194, (outs),
(ins ID:$pointer, ID:$xOffset, ID:$yOffset, ID:$objectToStore, ID:$memory_layout, ID:$height, ID:$width, variable_ops),
"OpCooperativeMatrixStoreCheckedINTEL $pointer $xOffset $yOffset $objectToStore $memory_layout $height $width">;
def OpCooperativeMatrixConstructCheckedINTEL: Op<6195, (outs ID:$res),
(ins TYPE:$resType, ID:$xOffset, ID:$yOffset, ID:$height, ID:$width, ID:$value),
"$res = OpCooperativeMatrixConstructCheckedINTEL $resType $xOffset $yOffset $height $width $value">;
def OpCooperativeMatrixGetElementCoordINTEL: Op<6440, (outs ID:$res),
(ins TYPE:$resType, ID:$matrix, ID:$index),
"$res = OpCooperativeMatrixGetElementCoordINTEL $resType $matrix $index">;
def OpCooperativeMatrixPrefetchINTEL: Op<6449, (outs),
(ins ID:$pointer, ID:$rows, ID:$columns, i32imm:$cacheLevel, ID:$memory_layout, variable_ops),
"OpCooperativeMatrixPrefetchINTEL $pointer $rows $columns $cacheLevel $memory_layout">;

// SPV_EXT_arithmetic_fence
def OpArithmeticFenceEXT: Op<6145, (outs ID:$res), (ins TYPE:$type, ID:$target),
"$res = OpArithmeticFenceEXT $type $target">;
108 changes: 108 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,114 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
}
break;
case SPIRV::OpCooperativeMatrixMulAddKHR: {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
report_fatal_error("Cooperative matrix instructions require the "
"following SPIR-V extension: "
"SPV_KHR_cooperative_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
constexpr unsigned MulAddMaxSize = 6;
if (MI.getNumOperands() != MulAddMaxSize)
break;
const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
if (CoopOperands &
SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
Reqs.addCapability(
SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
}
if (CoopOperands & SPIRV::CooperativeMatrixOperands::
MatrixAAndBBFloat16ComponentsINTEL ||
CoopOperands &
SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
CoopOperands & SPIRV::CooperativeMatrixOperands::
MatrixResultBFloat16ComponentsINTEL) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
Reqs.addCapability(
SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
}
break;
}
case SPIRV::OpCooperativeMatrixLoadKHR:
case SPIRV::OpCooperativeMatrixStoreKHR:
case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
report_fatal_error("Cooperative matrix instructions require the "
"following SPIR-V extension: "
"SPV_KHR_cooperative_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);

// Check Layout operand in case if it's not a standart one and add the
// appropriate capability.
std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
{SPIRV::OpCooperativeMatrixLoadKHR, 3},
{SPIRV::OpCooperativeMatrixStoreKHR, 2},
{SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
{SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
{SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};

const auto OpCode = MI.getOpcode();
const unsigned LayoutNum = LayoutToInstMap[OpCode];
Register RegLayout = MI.getOperand(LayoutNum).getReg();
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
if (MILayout->getOpcode() == SPIRV::OpConstantI) {
const unsigned LayoutVal = MILayout->getOperand(2).getImm();
if (LayoutVal ==
static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
report_fatal_error("PackedINTEL layout require the following SPIR-V "
"extension: SPV_INTEL_joint_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
}
}

// Nothing to do.
if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
break;

if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
report_fatal_error("OpCooperativeMatrix[Load/Store]CheckedINTEL "
"instructions require the following SPIR-V extension: "
"SPV_INTEL_joint_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
break;
}
Reqs.addCapability(
SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
break;
}
case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL"
" instructions require the following SPIR-V extension:"
" SPV_INTEL_joint_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
Reqs.addCapability(
SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
break;
case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
"following SPIR-V extension: SPV_INTEL_joint_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
Reqs.addCapability(
SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
break;
case SPIRV::OpKill: {
Reqs.addCapability(SPIRV::Capability::Shader);
} break;
Expand Down
Loading
Loading