Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ namespace ExecutionModel {
#include "SPIRVGenTables.inc"
} // namespace ExecutionModel

namespace MatrixMultiplyAccumulate {
#define GET_MatrixMultiplyAccumulate_DECL
#include "SPIRVGenTables.inc"
} // namespace MatrixMultiplyAccumulate

namespace MemoryModel {
#define GET_MemoryModel_DECL
#include "SPIRVGenTables.inc"
Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,40 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
}
break;
}
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
const unsigned NumOps = MI->getNumOperands();
if (NumFixedOps == NumOps)
break; // No extra operands, so no flags to process

OS << ' ';

// Extract the last operand only if it exists
if (NumOps > NumFixedOps) {
const unsigned Flags = MI->getOperand(NumOps - 1).getImm();

if (Flags == 0) {
printSymbolicOperand<
OperandCategory::MatrixMultiplyAccumulateOperand>(
MI, NumOps - 1, OS);
} else {
std::string Buffer;
for (unsigned Mask = 0x1;
Mask != SPIRV::MatrixMultiplyAccumulate::
MatrixBPackedBFloat16INTEL; // Replace with
// actual last flag
Mask <<= 1) {
if (Flags & Mask) {
if (!Buffer.empty())
Buffer += '|';
Buffer += getSymbolicOperandMnemonic(
OperandCategory::MatrixMultiplyAccumulateOperand, Mask);
}
}
OS << Buffer;
}
}
break;
}
default:
printRemainingVariableOps(MI, NumFixedOps, OS);
break;
Expand Down
38 changes: 37 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
Register(0));

Register ScopeRegister =
buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
Expand Down Expand Up @@ -2266,6 +2267,38 @@ static bool generateBindlessImageINTELInst(const SPIRV::IncomingCall *Call,
return buildBindlessImageINTELInst(Call, Opcode, MIRBuilder, GR);
}

static bool
generateSubgroupMatrixMultiplyAccumulateInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
unsigned Opcode =
SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;

auto MIB = MIRBuilder.buildInstr(Opcode);
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
MIB.addDef(Call->ReturnRegister).addUse(TypeReg);

size_t size = Call->Arguments.size();

if (size > 4) {
// Add first 4 arguments normally
for (size_t i = 0; i < 4; i++) {
MIB.addUse(Call->Arguments[i]);
}
const uint32_t memop = getConstFromIntrinsic(Call->Arguments.back(), MRI);
MIB.addImm(memop);
} else {
// Add all arguments if there are ≤ 4
for (size_t i = 0; i < size; i++) {
MIB.addUse(Call->Arguments[i]);
}
}

return true;
}

static bool buildNDRange(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
Expand Down Expand Up @@ -2847,6 +2880,9 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateExtendedBitOpsInst(Call.get(), MIRBuilder, GR);
case SPIRV::BindlessINTEL:
return generateBindlessImageINTELInst(Call.get(), MIRBuilder, GR);
case SPIRV::SubgroupMatrixMultiplyAccumulate:
return generateSubgroupMatrixMultiplyAccumulateInst(Call.get(), MIRBuilder,
GR);
}
return false;
}
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def CoopMatr : BuiltinGroup;
def ICarryBorrow : BuiltinGroup;
def ExtendedBitOps : BuiltinGroup;
def BindlessINTEL : BuiltinGroup;
def SubgroupMatrixMultiplyAccumulate : BuiltinGroup;

//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
Expand Down Expand Up @@ -1128,6 +1129,9 @@ defm : DemangledNativeBuiltin<"clock_read_hilo_device", OpenCL_std, KernelClock,
defm : DemangledNativeBuiltin<"clock_read_hilo_work_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
defm : DemangledNativeBuiltin<"clock_read_hilo_sub_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;

//SPV_INTEL_subgroup_matrix_multiply_accumulate
defm : DemangledNativeBuiltin<"__spirv_SubgroupMatrixMultiplyAccumulateINTEL", OpenCL_std, SubgroupMatrixMultiplyAccumulate, 4, 8, OpSubgroupMatrixMultiplyAccumulateINTEL>;

//===----------------------------------------------------------------------===//
// Class defining an atomic instruction on floating-point numbers.
//
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_INTEL_long_composites",
SPIRV::Extension::Extension::SPV_INTEL_long_composites},
{"SPV_INTEL_fp_max_error",
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}};
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
{"SPV_INTEL_subgroup_matrix_multiply_accumulate",
SPIRV::Extension::Extension::
SPV_INTEL_subgroup_matrix_multiply_accumulate}};

bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
llvm::StringRef ArgValue,
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -956,3 +956,7 @@ def OpAliasScopeDeclINTEL: Op<5912, (outs ID:$res), (ins ID:$AliasDomain, variab
"$res = OpAliasScopeDeclINTEL $AliasDomain">;
def OpAliasScopeListDeclINTEL: Op<5913, (outs ID:$res), (ins variable_ops),
"$res = OpAliasScopeListDeclINTEL">;

//SPV_INTEL_subgroup_matrix_multiply_accumulate
def OpSubgroupMatrixMultiplyAccumulateINTEL:Op<6237, (outs ID:$res), (ins TYPE:$result_type, ID:$dim, ID:$a, ID:$b, ID:$c, variable_ops),
"$res = OpSubgroupMatrixMultiplyAccumulateINTEL $result_type $dim $a $b $c">;
13 changes: 13 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,19 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(
SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
break;
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
if (!ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
report_fatal_error("This matrix instructions require the "
"following SPIR-V extension: "
"SPV_INTEL_subgroup_matrix_multiply_accumulate",
false);
Reqs.addExtension(
SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
Reqs.addCapability(
SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
break;
}
case SPIRV::OpConvertHandleToImageINTEL:
case SPIRV::OpConvertHandleToSamplerINTEL:
case SPIRV::OpConvertHandleToSampledImageINTEL:
Expand Down
41 changes: 41 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def KernelProfilingInfoOperand : OperandCategory;
def OpcodeOperand : OperandCategory;
def CooperativeMatrixLayoutOperand : OperandCategory;
def CooperativeMatrixOperandsOperand : OperandCategory;
def MatrixMultiplyAccumulateOperand :OperandCategory;

//===----------------------------------------------------------------------===//
// Multiclass used to define Extesions enum values and at the same time
Expand Down Expand Up @@ -313,6 +314,7 @@ defm SPV_INTEL_bindless_images : ExtensionOperand<116>;
defm SPV_INTEL_long_composites : ExtensionOperand<117>;
defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<120>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -512,6 +514,7 @@ defm FunctionFloatControlINTEL : CapabilityOperand<5821, 0, 0, [SPV_INTEL_float_
defm LongCompositesINTEL : CapabilityOperand<6089, 0, 0, [SPV_INTEL_long_composites], []>;
defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_images], []>;
defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_INTEL_subgroup_matrix_multiply_accumulate], []>;
defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1741,3 +1744,41 @@ defm MatrixAAndBTF32ComponentsINTEL : CooperativeMatrixOperandsOperand<0x20, [SP
defm MatrixAAndBBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x40, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
defm MatrixCBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x80, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
defm MatrixResultBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x100, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Matrix Multiply Accumulate Operands enum values and at the same time
// SymbolicOperand entries with string mnemonics and capabilities.
//===----------------------------------------------------------------------===//
def MatrixMultiplyAccumulate : GenericEnum, Operand<i32> {
let FilterClass = "MatrixMultiplyAccumulate";
let NameField = "Name";
let ValueField = "Value";
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
}

class MatrixMultiplyAccumulate<string name, bits<32> value> {
string Name = name;
bits<32> Value = value;
}

multiclass MatrixMultiplyAccumulateOperand<bits<32> value, list<Extension> reqExtensions> {
def : MatrixMultiplyAccumulate<NAME, value>;
defm : SymbolicOperandWithRequirements< MatrixMultiplyAccumulateOperand, value, NAME, 0, 0, reqExtensions, []>;
}

defm None : MatrixMultiplyAccumulateOperand<0x0, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixASignedComponentsINTEL : MatrixMultiplyAccumulateOperand<0x1, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBSignedComponentsINTEL : MatrixMultiplyAccumulateOperand<0x2, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixCBFloat16INTEL : MatrixMultiplyAccumulateOperand<0x4, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixResultBFloat16INTEL : MatrixMultiplyAccumulateOperand<0x8, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixAPackedInt8INTEL : MatrixMultiplyAccumulateOperand<0x10, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBPackedInt8INTEL : MatrixMultiplyAccumulateOperand<0x20, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixAPackedInt4INTEL : MatrixMultiplyAccumulateOperand<0x40, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBPackedInt4INTEL : MatrixMultiplyAccumulateOperand<0x80, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixATF32INTEL : MatrixMultiplyAccumulateOperand<0x100, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBTF32INTEL : MatrixMultiplyAccumulateOperand<0x200, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperand<0x400, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;

Loading