Skip to content

Commit 4ad6d5b

Browse files
committed
--Added matrix multiply accumulate operands for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate
1 parent d66dbd6 commit 4ad6d5b

File tree

4 files changed

+103
-29
lines changed

4 files changed

+103
-29
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ namespace MemoryModel {
5757
#include "SPIRVGenTables.inc"
5858
} // namespace MemoryModel
5959

60+
namespace MatrixMultiplyAccumulate {
61+
#define GET_MatrixMultiplyAccumulate_DECL
62+
#include "SPIRVGenTables.inc"
63+
} // namespace MatrixMultiplyAccumulate
64+
6065
namespace ExecutionMode {
6166
#define GET_ExecutionMode_DECL
6267
#include "SPIRVGenTables.inc"

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,37 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
242242
}
243243
break;
244244
}
245+
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
246+
const unsigned NumOps = MI->getNumOperands();
247+
if (NumFixedOps == NumOps)
248+
break; // No extra operands, so no flags to process
249+
OS << ' ';
250+
// Extract the last operand only if it exists
251+
if (NumOps > NumFixedOps) {
252+
const unsigned Flags = MI->getOperand(NumOps - 1).getImm();
253+
if (Flags == 0) {
254+
printSymbolicOperand<
255+
OperandCategory::MatrixMultiplyAccumulateOperandsOperand>(
256+
MI, NumOps - 1, OS);
257+
} else {
258+
std::string Buffer;
259+
for (unsigned Mask = 0x1;
260+
Mask != SPIRV::MatrixMultiplyAccumulate::
261+
MatrixBPackedBFloat16INTEL; // Replace with
262+
// actual last flag
263+
Mask <<= 1) {
264+
if (Flags & Mask) {
265+
if (!Buffer.empty())
266+
Buffer += '|';
267+
Buffer += getSymbolicOperandMnemonic(
268+
OperandCategory::MatrixMultiplyAccumulateOperandsOperand, Mask);
269+
}
270+
}
271+
OS << Buffer;
272+
}
273+
}
274+
break;
275+
}
245276
default:
246277
printRemainingVariableOps(MI, NumFixedOps, OS);
247278
break;

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def KernelProfilingInfoOperand : OperandCategory;
172172
def OpcodeOperand : OperandCategory;
173173
def CooperativeMatrixLayoutOperand : OperandCategory;
174174
def CooperativeMatrixOperandsOperand : OperandCategory;
175+
def MatrixMultiplyAccumulateOperandsOperand :OperandCategory;
175176

176177
//===----------------------------------------------------------------------===//
177178
// Multiclass used to define Extesions enum values and at the same time
@@ -1746,3 +1747,40 @@ defm MatrixAAndBTF32ComponentsINTEL : CooperativeMatrixOperandsOperand<0x20, [SP
17461747
defm MatrixAAndBBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x40, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
17471748
defm MatrixCBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x80, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
17481749
defm MatrixResultBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x100, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
1750+
1751+
//===----------------------------------------------------------------------===//
1752+
// Multiclass used to define Matrix Multiply Accumulate Operands enum values and at the same time
1753+
// SymbolicOperand entries with string mnemonics and capabilities.
1754+
//===----------------------------------------------------------------------===//
1755+
def MatrixMultiplyAccumulate : GenericEnum, Operand<i32> {
1756+
let FilterClass = "MatrixMultiplyAccumulate";
1757+
let NameField = "Name";
1758+
let ValueField = "Value";
1759+
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
1760+
}
1761+
1762+
class MatrixMultiplyAccumulate<string name, bits<32> value> {
1763+
string Name = name;
1764+
bits<32> Value = value;
1765+
}
1766+
1767+
multiclass MatrixMultiplyAccumulateOperandsOperand<bits<32> value, list<Extension> reqExtensions> {
1768+
def : MatrixMultiplyAccumulate<NAME, value>;
1769+
defm : SymbolicOperandWithRequirements< MatrixMultiplyAccumulateOperandsOperand, value, NAME, 0, 0, reqExtensions, []>;
1770+
}
1771+
1772+
defm None : MatrixMultiplyAccumulateOperandsOperand<0x0, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1773+
defm MatrixASignedComponentsINTEL : MatrixMultiplyAccumulateOperandsOperand<0x1, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1774+
defm MatrixBSignedComponentsINTEL : MatrixMultiplyAccumulateOperandsOperand<0x2, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1775+
defm MatrixCBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x4, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1776+
defm MatrixResultBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x8, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1777+
defm MatrixAPackedInt8INTEL : MatrixMultiplyAccumulateOperandsOperand<0x10, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1778+
defm MatrixBPackedInt8INTEL : MatrixMultiplyAccumulateOperandsOperand<0x20, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1779+
defm MatrixAPackedInt4INTEL : MatrixMultiplyAccumulateOperandsOperand<0x40, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1780+
defm MatrixBPackedInt4INTEL : MatrixMultiplyAccumulateOperandsOperand<0x80, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1781+
defm MatrixATF32INTEL : MatrixMultiplyAccumulateOperandsOperand<0x100, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1782+
defm MatrixBTF32INTEL : MatrixMultiplyAccumulateOperandsOperand<0x200, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1783+
defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1784+
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1785+
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1786+
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;

llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -131,35 +131,35 @@
131131
; CHECK: %[[#hM2:]] = OpFunctionParameter %[[#Vec2HalfTy]]
132132
; CHECK: %[[#hM4:]] = OpFunctionParameter %[[#Vec4HalfTy]]
133133
; CHECK: %[[#hM8:]] = OpFunctionParameter %[[#Vec8HalfTy]]
134-
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#iM]] 10
135-
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#iM2]] 10
136-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#iM4]] 10
137-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#iM8]] 10
138-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#fM]] 10
139-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#fM2]] 10
140-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#fM4]] 10
141-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#fM8]] 10
142-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#iM]] 10
143-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#iM2]] 10
144-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#iM4]] 10
145-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#iM8]] 10
146-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#fM]] 10
147-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#fM2]] 10
148-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]] 10
149-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#fM8]] 10
150-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#HalfTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#hM]] 10
151-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2HalfTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#hM2]] 10
152-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4HalfTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#hM4]] 10
153-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8HalfTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#hM8]] 10
154-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int16Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#sM]] 10
155-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int16Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#sM2]] 10
156-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int16Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#sM4]] 10
157-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int16Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#sM8]] 10
158-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#fM]] %[[#fM8]] %[[#fM]] 10
159-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#fM2]] %[[#fM8]] %[[#fM2]] 10
160-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#fM4]] %[[#fM8]] %[[#fM4]] 10
161-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#fM8]] %[[#fM8]] %[[#fM8]] 10
162-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]]
134+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#iM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
135+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#iM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
136+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#iM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
137+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#iM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
138+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#fM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
139+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#fM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
140+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#fM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
141+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#fM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
142+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#iM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
143+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#iM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
144+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#iM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
145+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#iM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
146+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#fM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
147+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#fM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
148+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
149+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#fM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
150+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#HalfTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#hM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
151+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2HalfTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#hM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
152+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4HalfTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#hM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
153+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8HalfTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#hM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
154+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int16Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#sM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
155+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int16Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#sM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
156+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int16Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#sM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
157+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int16Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#sM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
158+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#fM]] %[[#fM8]] %[[#fM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
159+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#fM2]] %[[#fM8]] %[[#fM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
160+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#fM4]] %[[#fM8]] %[[#fM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
161+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#fM8]] %[[#fM8]] %[[#fM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
162+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]]
163163

164164
define spir_func void @foo(i32 %iM, <2 x i32> %iM2, <4 x i32> %iM4, <8 x i32> %iM8,
165165
i16 signext %sM, <2 x i16> %sM2, <4 x i16> %sM4, <8 x i16> %sM8,

0 commit comments

Comments
 (0)