Skip to content

Commit 006813f

Browse files
committed
--Added Matrix multiply accumulate operands for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate
1 parent d66dbd6 commit 006813f

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ namespace MemoryModel {
5757
#include "SPIRVGenTables.inc"
5858
} // namespace MemoryModel
5959

60+
namespace MatrixMultiplyAccumulate {
61+
#define GET_MatrixMultiplyAccumulate_DECL
62+
#include "SPIRVGenTables.inc"
63+
} // namespace MatrixMultiplyAccumulate
64+
6065
namespace ExecutionMode {
6166
#define GET_ExecutionMode_DECL
6267
#include "SPIRVGenTables.inc"

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,37 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
242242
}
243243
break;
244244
}
245+
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
246+
const unsigned NumOps = MI->getNumOperands();
247+
if (NumFixedOps == NumOps)
248+
break; // No extra operands, so no flags to process
249+
OS << ' ';
250+
// Extract the last operand only if it exists
251+
if (NumOps > NumFixedOps) {
252+
const unsigned Flags = MI->getOperand(NumOps - 1).getImm();
253+
if (Flags == 0) {
254+
printSymbolicOperand<
255+
OperandCategory::MatrixMultiplyAccumulateOperandsOperand>(
256+
MI, NumOps - 1, OS);
257+
} else {
258+
std::string Buffer;
259+
for (unsigned Mask = 0x1;
260+
Mask != SPIRV::MatrixMultiplyAccumulate::
261+
MatrixBPackedBFloat16INTEL; // Replace with
262+
// actual last flag
263+
Mask <<= 1) {
264+
if (Flags & Mask) {
265+
if (!Buffer.empty())
266+
Buffer += '|';
267+
Buffer += getSymbolicOperandMnemonic(
268+
OperandCategory::MatrixMultiplyAccumulateOperandsOperand, Mask);
269+
}
270+
}
271+
OS << Buffer;
272+
}
273+
}
274+
break;
275+
}
245276
default:
246277
printRemainingVariableOps(MI, NumFixedOps, OS);
247278
break;

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def KernelProfilingInfoOperand : OperandCategory;
172172
def OpcodeOperand : OperandCategory;
173173
def CooperativeMatrixLayoutOperand : OperandCategory;
174174
def CooperativeMatrixOperandsOperand : OperandCategory;
175+
def MatrixMultiplyAccumulateOperandsOperand :OperandCategory;
175176

176177
//===----------------------------------------------------------------------===//
177178
// Multiclass used to define Extesions enum values and at the same time
@@ -1746,3 +1747,40 @@ defm MatrixAAndBTF32ComponentsINTEL : CooperativeMatrixOperandsOperand<0x20, [SP
17461747
defm MatrixAAndBBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x40, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
17471748
defm MatrixCBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x80, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
17481749
defm MatrixResultBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x100, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
1750+
1751+
//===----------------------------------------------------------------------===//
1752+
// Multiclass used to define Matrix Multiply Accumulate Operands enum values and at the same time
1753+
// SymbolicOperand entries with string mnemonics and capabilities.
1754+
//===----------------------------------------------------------------------===//
1755+
def MatrixMultiplyAccumulate : GenericEnum, Operand<i32> {
1756+
let FilterClass = "MatrixMultiplyAccumulate";
1757+
let NameField = "Name";
1758+
let ValueField = "Value";
1759+
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
1760+
}
1761+
1762+
class MatrixMultiplyAccumulate<string name, bits<32> value> {
1763+
string Name = name;
1764+
bits<32> Value = value;
1765+
}
1766+
1767+
multiclass MatrixMultiplyAccumulateOperandsOperand<bits<32> value, list<Extension> reqExtensions> {
1768+
def : MatrixMultiplyAccumulate<NAME, value>;
1769+
defm : SymbolicOperandWithRequirements< MatrixMultiplyAccumulateOperandsOperand, value, NAME, 0, 0, reqExtensions, []>;
1770+
}
1771+
1772+
defm None : MatrixMultiplyAccumulateOperandsOperand<0x0, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1773+
defm MatrixASignedComponentsINTEL : MatrixMultiplyAccumulateOperandsOperand<0x1, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1774+
defm MatrixBSignedComponentsINTEL : MatrixMultiplyAccumulateOperandsOperand<0x2, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1775+
defm MatrixCBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x4, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1776+
defm MatrixResultBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x8, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1777+
defm MatrixAPackedInt8INTEL : MatrixMultiplyAccumulateOperandsOperand<0x10, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1778+
defm MatrixBPackedInt8INTEL : MatrixMultiplyAccumulateOperandsOperand<0x20, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1779+
defm MatrixAPackedInt4INTEL : MatrixMultiplyAccumulateOperandsOperand<0x40, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1780+
defm MatrixBPackedInt4INTEL : MatrixMultiplyAccumulateOperandsOperand<0x80, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1781+
defm MatrixATF32INTEL : MatrixMultiplyAccumulateOperandsOperand<0x100, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1782+
defm MatrixBTF32INTEL : MatrixMultiplyAccumulateOperandsOperand<0x200, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1783+
defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1784+
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1785+
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1786+
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;

0 commit comments

Comments
 (0)