Skip to content

Commit e834a9f

Browse files
committed
Update instructions
1 parent 1bf2b65 commit e834a9f

File tree

1 file changed

+87
-29
lines changed

1 file changed

+87
-29
lines changed

mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h

Lines changed: 87 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,26 @@ struct StoreNdInstruction : public Instruction {
5353
return B->getInstructionKind() == InstructionKind::STORE_ND;
5454
}
5555
// Source :
56-
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
56+
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
5757
// Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
5858
// the specified pointer.
59-
llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
60-
const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
61-
return sortedLaneVectorLengths;
59+
std::optional<
60+
std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
61+
getBlockWidthHeightCount(Type elemTy) const {
62+
const static int kHeight[] = {1, 2, 4, 8};
63+
const static int kWidth16[] = {16};
64+
const static int kWidth32[] = {16};
65+
const static int kCount[] = {1};
66+
const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
67+
if (elemByteSize == 1)
68+
return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
69+
llvm::ArrayRef<int>(kHeight),
70+
llvm::ArrayRef<int>(kCount));
71+
else if (elemByteSize == 2 || elemByteSize == 4)
72+
return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
73+
llvm::ArrayRef<int>(kHeight),
74+
llvm::ArrayRef<int>(kCount));
75+
return std::nullopt;
6276
}
6377
};
6478

@@ -68,13 +82,49 @@ struct LoadNdInstruction : public Instruction {
6882
static bool classof(const Instruction *B) {
6983
return B->getInstructionKind() == InstructionKind::LOAD_ND;
7084
}
85+
7186
// Source :
72-
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
87+
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
7388
// Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
7489
// the specified pointer.
75-
llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
76-
const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
77-
return sortedLaneVectorLengths;
90+
std::optional<
91+
std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
92+
getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
93+
bool upConv = false) const {
94+
static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
95+
static const int kHeightAtLeast8[] = {8, 16, 32};
96+
static const int kHeightAtLeast16[] = {16, 32};
97+
static const int kHeightAtLeast32[] = {32};
98+
99+
static const int kWidth32[] = {32};
100+
static const int kWidth16[] = {16};
101+
static const int kWidth8[] = {8};
102+
103+
static const int32_t kCount1[] = {1};
104+
static const int32_t kCount2[] = {1, 2};
105+
static const int32_t kCount4[] = {1, 2, 4};
106+
static const int32_t kCount4Only[] = {4};
107+
// (elemBytes, transform, transpose, upConvert)
108+
using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
109+
// (widths, heights, counts)
110+
using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
111+
llvm::ArrayRef<int32_t>>;
112+
static const llvm::DenseMap<Key, Value> kMap = {
113+
{{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
114+
{{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
115+
{{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
116+
{{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
117+
// Block Loads with Transform:
118+
{{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
119+
{{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
120+
// Block Loads with Transpose:
121+
{{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
122+
};
123+
const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
124+
auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
125+
if (it != kMap.end())
126+
return it->second;
127+
return std::nullopt;
78128
}
79129
};
80130

@@ -86,29 +136,39 @@ struct PrefetchNdInstruction : public Instruction {
86136
}
87137
// Source :
88138
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
89-
llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
90-
const static int sortedNarrowTypesLengths[] = {1, 2, 4, 8, 16};
91-
const static int sortedWideTypesLengths[] = {1, 2, 4, 8};
92-
switch (elementBitwidth) {
93-
case 8:
94-
case 16:
95-
return sortedNarrowTypesLengths;
96-
case 32:
97-
case 64:
98-
return sortedWideTypesLengths;
99-
default:
100-
llvm_unreachable("Unsupported element bitwidth");
101-
}
139+
std::optional<
140+
std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
141+
getBlockWidthHeightCount(Type elemTy) const {
142+
static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
143+
144+
static const int kWidth32[] = {32};
145+
static const int kWidth16[] = {16};
146+
147+
static const int32_t kCount1[] = {1};
148+
static const int32_t kCount2[] = {1, 2};
149+
// elemBytes
150+
using Key = int;
151+
// (widths, heights, counts)
152+
using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
153+
llvm::ArrayRef<int32_t>>;
154+
static const llvm::DenseMap<Key, Value> kMap = {
155+
{1, {kWidth32, kHeightAtLeast1, kCount2}},
156+
{2, {kWidth16, kHeightAtLeast1, kCount2}},
157+
{4, {kWidth16, kHeightAtLeast1, kCount1}},
158+
};
159+
const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
160+
auto it = kMap.find(elemByteSize);
161+
if (it != kMap.end())
162+
return it->second;
163+
return std::nullopt;
102164
}
103165
};
104166

105167
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
106-
DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB,
107-
unsigned packedFormatBitSizeC)
168+
DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
108169
: Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
109170
packedFormatBitSizeA(packedFormatBitSizeA),
110-
packedFormatBitSizeB(packedFormatBitSizeB),
111-
packedFormatBitSizeC(packedFormatBitSizeC) {}
171+
packedFormatBitSizeB(packedFormatBitSizeB) {}
112172
static bool classof(const Instruction *B) {
113173
return B->getInstructionKind() == InstructionKind::DPAS;
114174
}
@@ -142,12 +202,10 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
142202

143203
// Returns the packedFormatBitSizeA value this instruction was constructed
// with; the member is const, so it never changes afterwards.
unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
144204
// Returns the packedFormatBitSizeB value this instruction was constructed
// with; the member is const, so it never changes afterwards.
unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
145-
unsigned getPackedFormatBitSizeC() const { return packedFormatBitSizeC; }
146205

147206
protected:
148207
const unsigned packedFormatBitSizeA;
149208
const unsigned packedFormatBitSizeB;
150-
const unsigned packedFormatBitSizeC;
151209
};
152210

153211
//===----------------------------------------------------------------------===//
@@ -156,7 +214,7 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
156214

157215
struct PVCuArch final : public Xe2Plus {
158216
static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
159-
static const DPASInstruction dpasInst{16, 32, 32};
217+
static const DPASInstruction dpasInst{16, 32};
160218
static const StoreNdInstruction loadNdInst;
161219
static const StoreNdInstruction storeNdInst;
162220
static const PrefetchNdInstruction prefetchNdInst;
@@ -179,7 +237,7 @@ struct PVCuArch final : public Xe2Plus {
179237

180238
struct BMGuArch : public Xe2Plus {
181239
static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
182-
static const DPASInstruction dpasInst{16, 32, 32};
240+
static const DPASInstruction dpasInst{16, 32};
183241
static const StoreNdInstruction loadNdInst;
184242
static const StoreNdInstruction storeNdInst;
185243
static const PrefetchNdInstruction prefetchNdInst;

0 commit comments

Comments
 (0)