Skip to content

Commit 32a2531

Browse files
committed
Address feedback
1 parent e834a9f commit 32a2531

File tree

2 files changed

+64
-57
lines changed

2 files changed

+64
-57
lines changed

mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h

Lines changed: 53 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,7 @@ struct Xe2Plus : public uArch {
3636
const XeCoreInfo &xeCore)
3737
: uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
3838
int getSubgroupSize() const override { return 16; }
39-
unsigned getPackedFormatBitSize() const override { return 16; }
40-
unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
39+
unsigned getGeneralPackedFormatBitSize() const override { return 32; }
4140

4241
protected:
4342
XeCoreInfo xeCore;
@@ -46,16 +45,15 @@ struct Xe2Plus : public uArch {
4645
//===----------------------------------------------------------------------===//
4746
// uArch instructions
4847
//===----------------------------------------------------------------------===//
49-
struct StoreNdInstruction : public Instruction {
50-
StoreNdInstruction()
51-
: Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
48+
struct Subgroup2DBlockStoreInstruction : public Instruction {
49+
Subgroup2DBlockStoreInstruction()
50+
: Instruction(InstructionKind::Subgroup2DBlockStore,
51+
InstructionScope::Subgroup) {}
5252
static bool classof(const Instruction *B) {
53-
return B->getInstructionKind() == InstructionKind::STORE_ND;
53+
return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
5454
}
5555
// Source :
5656
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
57-
// Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
58-
// the specified pointer
5957
std::optional<
6058
std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
6159
getBlockWidthHeightCount(Type elemTy) const {
@@ -74,19 +72,20 @@ struct StoreNdInstruction : public Instruction {
7472
llvm::ArrayRef<int>(kCount));
7573
return std::nullopt;
7674
}
75+
76+
int32_t getPackedFormatBitSize() const { return 16; }
7777
};
7878

79-
struct LoadNdInstruction : public Instruction {
80-
LoadNdInstruction()
81-
: Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
79+
struct Subgroup2DBlockLoadInstruction : public Instruction {
80+
Subgroup2DBlockLoadInstruction()
81+
: Instruction(InstructionKind::Subgroup2DBlockLoad,
82+
InstructionScope::Subgroup) {}
8283
static bool classof(const Instruction *B) {
83-
return B->getInstructionKind() == InstructionKind::LOAD_ND;
84+
return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
8485
}
8586

8687
// Source :
8788
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
88-
// Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
89-
// the specified pointer.
9089
std::optional<
9190
std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
9291
getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
@@ -126,13 +125,16 @@ struct LoadNdInstruction : public Instruction {
126125
return it->second;
127126
return std::nullopt;
128127
}
128+
129+
int32_t getPackedFormatBitSize() const { return 16; }
129130
};
130131

131-
struct PrefetchNdInstruction : public Instruction {
132-
PrefetchNdInstruction()
133-
: Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
132+
struct Subgroup2DBlockPrefetchInstruction : public Instruction {
133+
Subgroup2DBlockPrefetchInstruction()
134+
: Instruction(InstructionKind::Subgroup2DBlockPrefetch,
135+
InstructionScope::Subgroup) {}
134136
static bool classof(const Instruction *B) {
135-
return B->getInstructionKind() == InstructionKind::PREFETCH_ND;
137+
return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
136138
}
137139
// Source :
138140
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
@@ -162,15 +164,20 @@ struct PrefetchNdInstruction : public Instruction {
162164
return it->second;
163165
return std::nullopt;
164166
}
167+
int32_t getPackedFormatBitSize() const { return 16; }
165168
};
166169

167-
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
168-
DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
169-
: Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
170+
struct SubgroupMatrixMultiplyAcc : public Instruction,
171+
public MMAInstructionInterface {
172+
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA,
173+
unsigned packedFormatBitSizeB)
174+
: Instruction(InstructionKind::SubgroupMatrixMultiplyAcc,
175+
InstructionScope::Subgroup),
170176
packedFormatBitSizeA(packedFormatBitSizeA),
171177
packedFormatBitSizeB(packedFormatBitSizeB) {}
172178
static bool classof(const Instruction *B) {
173-
return B->getInstructionKind() == InstructionKind::DPAS;
179+
return B->getInstructionKind() ==
180+
InstructionKind::SubgroupMatrixMultiplyAcc;
174181
}
175182
// Source:
176183
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
@@ -214,10 +221,10 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
214221

215222
struct PVCuArch final : public Xe2Plus {
216223
static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
217-
static const DPASInstruction dpasInst{16, 32};
218-
static const StoreNdInstruction loadNdInst;
219-
static const StoreNdInstruction storeNdInst;
220-
static const PrefetchNdInstruction prefetchNdInst;
224+
static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
225+
static const Subgroup2DBlockLoadInstruction loadNdInst;
226+
static const Subgroup2DBlockStoreInstruction storeNdInst;
227+
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
221228
static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
222229
&prefetchNdInst};
223230
return arr;
@@ -237,10 +244,10 @@ struct PVCuArch final : public Xe2Plus {
237244

238245
struct BMGuArch : public Xe2Plus {
239246
static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
240-
static const DPASInstruction dpasInst{16, 32};
241-
static const StoreNdInstruction loadNdInst;
242-
static const StoreNdInstruction storeNdInst;
243-
static const PrefetchNdInstruction prefetchNdInst;
247+
static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
248+
static const Subgroup2DBlockLoadInstruction loadNdInst;
249+
static const Subgroup2DBlockStoreInstruction storeNdInst;
250+
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
244251
static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
245252
&prefetchNdInst};
246253
return arr;
@@ -276,7 +283,8 @@ inline const uArch *getUArch(llvm::StringRef archName) {
276283
//===----------------------------------------------------------------------===//
277284

278285
inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
279-
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
286+
SubgroupMatrixMultiplyAcc::getSupportedShapes(Type dataType,
287+
MMAOpndKind matrixType) {
280288
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
281289
const llvm::SmallVector<uint32_t, 8> &b)
282290
-> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
@@ -312,8 +320,8 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
312320
}
313321

314322
inline llvm::SmallVector<Type, 8>
315-
DPASInstruction::getSupportedTypes(MLIRContext &context,
316-
MMAOpndKind matrixType) {
323+
SubgroupMatrixMultiplyAcc::getSupportedTypes(MLIRContext &context,
324+
MMAOpndKind matrixType) {
317325
Type bf16Type = BFloat16Type::get(&context);
318326
Type f16Type = Float16Type::get(&context);
319327
Type tf32Type = FloatTF32Type::get(&context);
@@ -332,8 +340,10 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
332340
return {};
333341
}
334342

335-
inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
336-
Type CType, Type DType) {
343+
inline bool SubgroupMatrixMultiplyAcc::checkSupportedTypes(Type AType,
344+
Type BType,
345+
Type CType,
346+
Type DType) {
337347
if (AType.isF16() || BType.isF16()) {
338348
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
339349
(!DType.isF32() && !DType.isF16())) {
@@ -363,7 +373,7 @@ inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
363373
return true;
364374
}
365375

366-
inline bool DPASInstruction::checkSupportedShapesAndTypes(
376+
inline bool SubgroupMatrixMultiplyAcc::checkSupportedShapesAndTypes(
367377
std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
368378
std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
369379
Type AType, Type BType, Type CType, Type DType) {
@@ -378,23 +388,21 @@ inline bool DPASInstruction::checkSupportedShapesAndTypes(
378388
checkSupportedTypes(AType, BType, CType, DType);
379389
}
380390

381-
inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
382-
std::pair<uint32_t, uint32_t> BShape,
383-
std::pair<uint32_t, uint32_t> CShape,
384-
std::pair<uint32_t, uint32_t> DShape,
385-
Type AType, Type BType, Type CType,
386-
Type DType) {
391+
inline bool SubgroupMatrixMultiplyAcc::validate(
392+
std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
393+
std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
394+
Type AType, Type BType, Type CType, Type DType) {
387395
return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
388396
BType, CType, DType);
389397
}
390398

391399
inline llvm::SmallVector<uint32_t, 8>
392-
DPASInstruction::getSupportedM(Type type) const {
400+
SubgroupMatrixMultiplyAcc::getSupportedM(Type type) const {
393401
return {1, 2, 3, 4, 5, 6, 7, 8};
394402
}
395403

396404
inline llvm::SmallVector<uint32_t, 8>
397-
DPASInstruction::getSupportedK(Type type) const {
405+
SubgroupMatrixMultiplyAcc::getSupportedK(Type type) const {
398406
// assert if data type is not int or float type
399407
assert(type.isIntOrFloat() && "Matrix type must be int or float");
400408
auto bitWidth = type.getIntOrFloatBitWidth();
@@ -422,7 +430,7 @@ DPASInstruction::getSupportedK(Type type) const {
422430
}
423431

424432
inline llvm::SmallVector<uint32_t, 8>
425-
DPASInstruction::getSupportedN(Type type) const {
433+
SubgroupMatrixMultiplyAcc::getSupportedN(Type type) const {
426434
return {16};
427435
}
428436

mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -32,11 +32,11 @@ namespace uArch {
3232
// An enum class to represent the scope of an instruction
3333
enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
3434
enum class InstructionKind {
35-
DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
36-
// multiply-add operation
37-
STORE_ND, // Subgroup-level 2D block write instruction
38-
LOAD_ND, // Subgroup-level 2D block load instruction
39-
PREFETCH_ND // Subgroup-level 2D block prefetch instruction
35+
SubgroupMatrixMultiplyAcc, // Dot Product Accumulate Systolic (DPAS) is a
36+
// matrix multiply-add operation
37+
Subgroup2DBlockStore, // Subgroup-level 2D block write instruction
38+
Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
39+
Subgroup2DBlockPrefetch // Subgroup-level 2D block prefetch instruction
4040
// @TODO: Add more instructions as needed
4141
};
4242

@@ -55,13 +55,13 @@ struct Instruction {
5555
InstructionScope getScope() const { return scope; }
5656
static llvm::StringRef toString(InstructionKind instKind) {
5757
switch (instKind) {
58-
case InstructionKind::DPAS:
58+
case InstructionKind::SubgroupMatrixMultiplyAcc:
5959
return "dpas";
60-
case InstructionKind::STORE_ND:
60+
case InstructionKind::Subgroup2DBlockStore:
6161
return "store_nd";
62-
case InstructionKind::LOAD_ND:
62+
case InstructionKind::Subgroup2DBlockLoad:
6363
return "load_nd";
64-
case InstructionKind::PREFETCH_ND:
64+
case InstructionKind::Subgroup2DBlockPrefetch:
6565
return "prefetch_nd";
6666
}
6767
llvm_unreachable("Unknown InstructionKind");
@@ -70,7 +70,7 @@ struct Instruction {
7070
static std::optional<InstructionKind>
7171
parseInstructionKind(llvm::StringRef str) {
7272
if (str.equals_insensitive("dpas"))
73-
return InstructionKind::DPAS;
73+
return InstructionKind::SubgroupMatrixMultiplyAcc;
7474
return std::nullopt;
7575
}
7676

@@ -150,8 +150,7 @@ struct uArch {
150150
StringRef getName() const { return name; }
151151
StringRef getDescription() const { return description; }
152152
virtual int getSubgroupSize() const = 0;
153-
virtual unsigned getPackedFormatBitSize() const = 0;
154-
virtual unsigned getPackedFormatBitSizeGatherScatter() const = 0;
153+
virtual unsigned getGeneralPackedFormatBitSize() const = 0;
155154

156155
const Instruction *getInstruction(InstructionKind instKind) const {
157156
auto it = instructionRegistry.find(instKind);

0 commit comments

Comments
 (0)