Skip to content

Commit 9b79dc9

Browse files
committed
--Added support for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate
--Added test files for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate
1 parent baab447 commit 9b79dc9

File tree

9 files changed

+403
-2
lines changed

9 files changed

+403
-2
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ namespace ExecutionModel {
5252
#include "SPIRVGenTables.inc"
5353
} // namespace ExecutionModel
5454

55+
namespace MatrixMultiplyAccumulate {
56+
#define GET_MatrixMultiplyAccumulate_DECL
57+
#include "SPIRVGenTables.inc"
58+
} // namespace MatrixMultiplyAccumulate
59+
5560
namespace MemoryModel {
5661
#define GET_MemoryModel_DECL
5762
#include "SPIRVGenTables.inc"

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,40 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
237237
}
238238
break;
239239
}
240+
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
241+
const unsigned NumOps = MI->getNumOperands();
242+
if (NumFixedOps == NumOps)
243+
break; // No extra operands, so no flags to process
244+
245+
OS << ' ';
246+
247+
// Extract the last operand only if it exists
248+
if (NumOps > NumFixedOps) {
249+
const unsigned Flags = MI->getOperand(NumOps - 1).getImm();
250+
251+
if (Flags == 0) {
252+
printSymbolicOperand<
253+
OperandCategory::MatrixMultiplyAccumulateOperand>(
254+
MI, NumOps - 1, OS);
255+
} else {
256+
std::string Buffer;
257+
for (unsigned Mask = 0x1;
258+
Mask != SPIRV::MatrixMultiplyAccumulate::
259+
MatrixBPackedBFloat16INTEL; // Replace with
260+
// actual last flag
261+
Mask <<= 1) {
262+
if (Flags & Mask) {
263+
if (!Buffer.empty())
264+
Buffer += '|';
265+
Buffer += getSymbolicOperandMnemonic(
266+
OperandCategory::MatrixMultiplyAccumulateOperand, Mask);
267+
}
268+
}
269+
OS << Buffer;
270+
}
271+
}
272+
break;
273+
}
240274
default:
241275
printRemainingVariableOps(MI, NumFixedOps, OS);
242276
break;

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
701701
MachineIRBuilder &MIRBuilder,
702702
SPIRVGlobalRegistry *GR) {
703703
if (Call->isSpirvOp())
704-
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
704+
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
705+
Register(0));
705706

706707
Register ScopeRegister =
707708
buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -2266,6 +2267,38 @@ static bool generateBindlessImageINTELInst(const SPIRV::IncomingCall *Call,
22662267
return buildBindlessImageINTELInst(Call, Opcode, MIRBuilder, GR);
22672268
}
22682269

2270+
static bool
2271+
generateSubgroupMatrixMultiplyAccumulateInst(const SPIRV::IncomingCall *Call,
2272+
MachineIRBuilder &MIRBuilder,
2273+
SPIRVGlobalRegistry *GR) {
2274+
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2275+
unsigned Opcode =
2276+
SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2277+
2278+
auto MIB = MIRBuilder.buildInstr(Opcode);
2279+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2280+
Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
2281+
MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
2282+
2283+
size_t size = Call->Arguments.size();
2284+
2285+
if (size > 4) {
2286+
// Add first 4 arguments normally
2287+
for (size_t i = 0; i < 4; i++) {
2288+
MIB.addUse(Call->Arguments[i]);
2289+
}
2290+
const uint32_t memop = getConstFromIntrinsic(Call->Arguments.back(), MRI);
2291+
MIB.addImm(memop);
2292+
} else {
2293+
// Add all arguments if there are ≤ 4
2294+
for (size_t i = 0; i < size; i++) {
2295+
MIB.addUse(Call->Arguments[i]);
2296+
}
2297+
}
2298+
2299+
return true;
2300+
}
2301+
22692302
static bool buildNDRange(const SPIRV::IncomingCall *Call,
22702303
MachineIRBuilder &MIRBuilder,
22712304
SPIRVGlobalRegistry *GR) {
@@ -2847,6 +2880,9 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
28472880
return generateExtendedBitOpsInst(Call.get(), MIRBuilder, GR);
28482881
case SPIRV::BindlessINTEL:
28492882
return generateBindlessImageINTELInst(Call.get(), MIRBuilder, GR);
2883+
case SPIRV::SubgroupMatrixMultiplyAccumulate:
2884+
return generateSubgroupMatrixMultiplyAccumulateInst(Call.get(), MIRBuilder,
2885+
GR);
28502886
}
28512887
return false;
28522888
}

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def CoopMatr : BuiltinGroup;
6767
def ICarryBorrow : BuiltinGroup;
6868
def ExtendedBitOps : BuiltinGroup;
6969
def BindlessINTEL : BuiltinGroup;
70+
def SubgroupMatrixMultiplyAccumulate : BuiltinGroup;
7071

7172
//===----------------------------------------------------------------------===//
7273
// Class defining a demangled builtin record. The information in the record
@@ -1128,6 +1129,9 @@ defm : DemangledNativeBuiltin<"clock_read_hilo_device", OpenCL_std, KernelClock,
11281129
defm : DemangledNativeBuiltin<"clock_read_hilo_work_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
11291130
defm : DemangledNativeBuiltin<"clock_read_hilo_sub_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
11301131

1132+
//SPV_INTEL_subgroup_matrix_multiply_accumulate
1133+
defm : DemangledNativeBuiltin<"__spirv_SubgroupMatrixMultiplyAccumulateINTEL", OpenCL_std, SubgroupMatrixMultiplyAccumulate, 4, 8, OpSubgroupMatrixMultiplyAccumulateINTEL>;
1134+
11311135
//===----------------------------------------------------------------------===//
11321136
// Class defining an atomic instruction on floating-point numbers.
11331137
//

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,10 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
9292
{"SPV_INTEL_long_composites",
9393
SPIRV::Extension::Extension::SPV_INTEL_long_composites},
9494
{"SPV_INTEL_fp_max_error",
95-
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}};
95+
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
96+
{"SPV_INTEL_subgroup_matrix_multiply_accumulate",
97+
SPIRV::Extension::Extension::
98+
SPV_INTEL_subgroup_matrix_multiply_accumulate}};
9699

97100
bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
98101
llvm::StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,3 +956,7 @@ def OpAliasScopeDeclINTEL: Op<5912, (outs ID:$res), (ins ID:$AliasDomain, variab
956956
"$res = OpAliasScopeDeclINTEL $AliasDomain">;
957957
def OpAliasScopeListDeclINTEL: Op<5913, (outs ID:$res), (ins variable_ops),
958958
"$res = OpAliasScopeListDeclINTEL">;
959+
960+
//SPV_INTEL_subgroup_matrix_multiply_accumulate
961+
def OpSubgroupMatrixMultiplyAccumulateINTEL:Op<6237, (outs ID:$res), (ins TYPE:$result_type, ID:$dim, ID:$a, ID:$b, ID:$c, variable_ops),
962+
"$res = OpSubgroupMatrixMultiplyAccumulateINTEL $result_type $dim $a $b $c">;

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,6 +1699,19 @@ void addInstrRequirements(const MachineInstr &MI,
16991699
Reqs.addCapability(
17001700
SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
17011701
break;
1702+
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
1703+
if (!ST.canUseExtension(
1704+
SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
1705+
report_fatal_error("This matrix instructions require the "
1706+
"following SPIR-V extension: "
1707+
"SPV_INTEL_subgroup_matrix_multiply_accumulate",
1708+
false);
1709+
Reqs.addExtension(
1710+
SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
1711+
Reqs.addCapability(
1712+
SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
1713+
break;
1714+
}
17021715
case SPIRV::OpConvertHandleToImageINTEL:
17031716
case SPIRV::OpConvertHandleToSamplerINTEL:
17041717
case SPIRV::OpConvertHandleToSampledImageINTEL:

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def KernelProfilingInfoOperand : OperandCategory;
172172
def OpcodeOperand : OperandCategory;
173173
def CooperativeMatrixLayoutOperand : OperandCategory;
174174
def CooperativeMatrixOperandsOperand : OperandCategory;
175+
def MatrixMultiplyAccumulateOperand :OperandCategory;
175176

176177
//===----------------------------------------------------------------------===//
177178
// Multiclass used to define Extesions enum values and at the same time
@@ -313,6 +314,7 @@ defm SPV_INTEL_bindless_images : ExtensionOperand<116>;
313314
defm SPV_INTEL_long_composites : ExtensionOperand<117>;
314315
defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
315316
defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
317+
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<120>;
316318

317319
//===----------------------------------------------------------------------===//
318320
// Multiclass used to define Capabilities enum values and at the same time
@@ -512,6 +514,7 @@ defm FunctionFloatControlINTEL : CapabilityOperand<5821, 0, 0, [SPV_INTEL_float_
512514
defm LongCompositesINTEL : CapabilityOperand<6089, 0, 0, [SPV_INTEL_long_composites], []>;
513515
defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_images], []>;
514516
defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
517+
defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_INTEL_subgroup_matrix_multiply_accumulate], []>;
515518
defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
516519

517520
//===----------------------------------------------------------------------===//
@@ -1741,3 +1744,41 @@ defm MatrixAAndBTF32ComponentsINTEL : CooperativeMatrixOperandsOperand<0x20, [SP
17411744
defm MatrixAAndBBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x40, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
17421745
defm MatrixCBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x80, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
17431746
defm MatrixResultBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x100, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
1747+
1748+
//===----------------------------------------------------------------------===//
1749+
// Multiclass used to define Matrix Multiply Accumulate Operands enum values and at the same time
1750+
// SymbolicOperand entries with string mnemonics and capabilities.
1751+
//===----------------------------------------------------------------------===//
1752+
def MatrixMultiplyAccumulate : GenericEnum, Operand<i32> {
1753+
let FilterClass = "MatrixMultiplyAccumulate";
1754+
let NameField = "Name";
1755+
let ValueField = "Value";
1756+
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
1757+
}
1758+
1759+
class MatrixMultiplyAccumulate<string name, bits<32> value> {
1760+
string Name = name;
1761+
bits<32> Value = value;
1762+
}
1763+
1764+
multiclass MatrixMultiplyAccumulateOperand<bits<32> value, list<Extension> reqExtensions> {
1765+
def : MatrixMultiplyAccumulate<NAME, value>;
1766+
defm : SymbolicOperandWithRequirements< MatrixMultiplyAccumulateOperand, value, NAME, 0, 0, reqExtensions, []>;
1767+
}
1768+
1769+
defm None : MatrixMultiplyAccumulateOperand<0x0, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1770+
defm MatrixASignedComponentsINTEL : MatrixMultiplyAccumulateOperand<0x1, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1771+
defm MatrixBSignedComponentsINTEL : MatrixMultiplyAccumulateOperand<0x2, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1772+
defm MatrixCBFloat16INTEL : MatrixMultiplyAccumulateOperand<0x4, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1773+
defm MatrixResultBFloat16INTEL : MatrixMultiplyAccumulateOperand<0x8, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1774+
defm MatrixAPackedInt8INTEL : MatrixMultiplyAccumulateOperand<0x10, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1775+
defm MatrixBPackedInt8INTEL : MatrixMultiplyAccumulateOperand<0x20, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1776+
defm MatrixAPackedInt4INTEL : MatrixMultiplyAccumulateOperand<0x40, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1777+
defm MatrixBPackedInt4INTEL : MatrixMultiplyAccumulateOperand<0x80, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1778+
defm MatrixATF32INTEL : MatrixMultiplyAccumulateOperand<0x100, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1779+
defm MatrixBTF32INTEL : MatrixMultiplyAccumulateOperand<0x200, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1780+
defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperand<0x400, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1781+
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1782+
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1783+
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1784+

0 commit comments

Comments
 (0)