Skip to content

Commit 38ff707

Browse files
committed
Address review comments.
Use LLVM data structures whenever possible.
1 parent 82737ce commit 38ff707

File tree

2 files changed

+88
-70
lines changed

2 files changed

+88
-70
lines changed

mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
1919
#include "mlir/IR/BuiltinTypes.h"
2020
#include "mlir/IR/TypeUtilities.h"
21+
#include "llvm/ADT/SmallVector.h"
2122
#include "llvm/Support/DebugLog.h"
2223
#include <map>
2324
#include <string>
24-
#include <vector>
2525

2626
#define DEBUG_TYPE "xegpu-uarch"
2727

@@ -47,28 +47,29 @@ struct XeCoreInfo {
4747

4848
struct Xe2Plus : public uArch {
4949
XeCoreInfo xe_core;
50-
Xe2Plus(
51-
const std::string &archName, const std::string &archDescription,
52-
const XeCoreInfo &xeCore,
53-
const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
54-
const std::vector<CacheInfo> &cacheInfo = {},
55-
const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
50+
Xe2Plus(const std::string &archName, const std::string &archDescription,
51+
const XeCoreInfo &xeCore,
52+
const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
53+
const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
54+
const std::map<InstructionKind, std::shared_ptr<Instruction>>
55+
&instrs = {})
5656
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
5757
xe_core(xeCore) {}
5858
};
5959

6060
// struct to represent DPAS instruction
6161
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
6262
DPASInstruction()
63-
: Instruction("dpas", // name
64-
"Dot Product Accumulate") // description
63+
: Instruction(InstructionKind::DPAS, // name
64+
"Dot Product Accumulate",
65+
InstructionScope::Subgroup) // description
6566
{}
6667

6768
// Override all virtuals from MatrixOpInterface
68-
virtual std::vector<std::pair<uint32_t, uint32_t>>
69+
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
6970
getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
70-
virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
71-
MMAOpndKind matrixType) override;
71+
virtual llvm::SmallVector<Type, 8>
72+
getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
7273
virtual bool
7374
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
7475
std::pair<uint32_t, uint32_t> BShape,
@@ -82,14 +83,14 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
8283
std::pair<uint32_t, uint32_t> CShape,
8384
std::pair<uint32_t, uint32_t> DShape, Type AType,
8485
Type BType, Type CType, Type DType) override;
85-
virtual std::vector<uint32_t> getSupportedM(Type type) override;
86-
virtual std::vector<uint32_t> getSupportedK(Type type) override;
87-
virtual std::vector<uint32_t> getSupportedN(Type type) override;
86+
virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
87+
virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
88+
virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
8889
};
8990

9091
struct PVCuArch : public Xe2Plus {
9192
// Maintaines ownership of the instructions owned by PVUarch
92-
std::vector<std::shared_ptr<Instruction>> owned_instructions;
93+
llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
9394
PVCuArch()
9495
: Xe2Plus("pvc", // archName
9596
"Ponte Vecchio Architecture", // archDescription
@@ -115,17 +116,16 @@ struct PVCuArch : public Xe2Plus {
115116
this->cacheInfo.push_back(
116117
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
117118

118-
// Add the instructions
119+
// Add the instructions-
119120
auto dpas = std::make_shared<DPASInstruction>();
120-
instructions.emplace(dpas->getName(), dpas);
121-
// instructions[dpas->name] = dpas.get();
121+
instructions.emplace(dpas->getInstructionKind(), dpas);
122122
owned_instructions.push_back(dpas);
123123
}
124124
};
125125

126126
struct BMGuArch : public Xe2Plus {
127127
// Maintaines ownership of the instructions owned by PVUarch
128-
std::vector<std::shared_ptr<Instruction>> owned_instructions;
128+
llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
129129
BMGuArch()
130130
: Xe2Plus("bmg", // archName
131131
"Battlemage Architecture", // archDescription
@@ -151,21 +151,20 @@ struct BMGuArch : public Xe2Plus {
151151

152152
// Add the instructions
153153
auto dpas = std::make_shared<DPASInstruction>();
154-
instructions.emplace(dpas->getName(), dpas);
155-
// instructions[dpas->name] = dpas.get();
154+
instructions.emplace(dpas->getInstructionKind(), dpas);
156155
owned_instructions.push_back(dpas);
157156
}
158157
};
159158
} // namespace uArch
160159
} // namespace xegpu
161160
} // namespace mlir
162161

163-
inline std::vector<std::pair<uint32_t, uint32_t>>
162+
inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
164163
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
165-
auto combineVectors = [](const std::vector<uint32_t> &a,
166-
const std::vector<uint32_t> &b)
167-
-> std::vector<std::pair<uint32_t, uint32_t>> {
168-
std::vector<std::pair<uint32_t, uint32_t>> result;
164+
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
165+
const llvm::SmallVector<uint32_t, 8> &b)
166+
-> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
167+
llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> result;
169168
for (unsigned x : a) {
170169
for (unsigned y : b) {
171170
result.emplace_back(x, y);
@@ -177,7 +176,7 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
177176
auto M = getSupportedM(dataType);
178177
auto K = getSupportedK(dataType);
179178
auto N = getSupportedN(dataType);
180-
std::vector<std::pair<unsigned, unsigned>> resultMatrix;
179+
llvm::SmallVector<std::pair<unsigned, unsigned>, 16> resultMatrix;
181180

182181
switch (matrixType) {
183182
case MMAOpndKind::MatrixA:
@@ -196,7 +195,7 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
196195
return resultMatrix;
197196
}
198197

199-
inline std::vector<Type>
198+
inline llvm::SmallVector<Type, 8>
200199
DPASInstruction::getSupportedTypes(MLIRContext &context,
201200
MMAOpndKind matrixType) {
202201
Type bf16Type = BFloat16Type::get(&context);
@@ -207,17 +206,14 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
207206
switch (matrixType) {
208207
case MMAOpndKind::MatrixA:
209208
return {bf16Type, f16Type, tf32Type};
210-
break;
211209
case MMAOpndKind::MatrixB:
212210
return {bf16Type, f16Type, tf32Type};
213-
break;
214211
case MMAOpndKind::MatrixC:
215212
return {bf16Type, f16Type, f32Type};
216-
break;
217213
case MMAOpndKind::MatrixD:
218214
return {bf16Type, f16Type, f32Type};
219-
break;
220215
}
216+
return {};
221217
}
222218

223219
inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
@@ -276,11 +272,13 @@ inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
276272
BType, CType, DType);
277273
}
278274

279-
inline std::vector<uint32_t> DPASInstruction::getSupportedM(Type type) {
275+
inline llvm::SmallVector<uint32_t, 8>
276+
DPASInstruction::getSupportedM(Type type) {
280277
return {1, 2, 3, 4, 5, 6, 7, 8};
281278
}
282279

283-
inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
280+
inline llvm::SmallVector<uint32_t, 8>
281+
DPASInstruction::getSupportedK(Type type) {
284282
// assert if data type is not int or float type
285283
assert(type.isIntOrFloat() && "Matrix type must be int or float");
286284
auto bitWidth = type.getIntOrFloatBitWidth();
@@ -307,8 +305,9 @@ inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
307305
return {kSize};
308306
}
309307

310-
inline std::vector<uint32_t> DPASInstruction::getSupportedN(Type type) {
308+
inline llvm::SmallVector<uint32_t, 8>
309+
DPASInstruction::getSupportedN(Type type) {
311310
return {16};
312311
}
313312

314-
#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2H
313+
#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H

mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <tuple>
2424

2525
#include "mlir/IR/Types.h"
26+
#include "llvm/ADT/SmallVector.h"
2627

2728
namespace mlir {
2829
namespace xegpu {
@@ -31,12 +32,26 @@ namespace uArch {
3132
// An enum class to represent the scope of an instruction
3233
enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
3334

34-
enum class InstructionName {
35-
DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add
36-
// operation
35+
enum class InstructionKind {
36+
DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
37+
// multiply-add operation
3738
// Add more instructions as needed
3839
};
3940

41+
llvm::StringRef toString(InstructionKind name) {
42+
switch (name) {
43+
case InstructionKind::DPAS:
44+
return "dpas";
45+
}
46+
llvm_unreachable("Unknown InstructionKind");
47+
}
48+
49+
std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
50+
if (str.equals_insensitive("dpas"))
51+
return InstructionKind::DPAS;
52+
return std::nullopt;
53+
}
54+
4055
// A struct to represent basic information about an instruction
4156
// This struct is used to represent the information about an instruction in the
4257
// uArch The information includes:
@@ -56,17 +71,17 @@ enum class InstructionName {
5671

5772
struct Instruction {
5873
// @TODO: Add more fields as needed
59-
Instruction(std::string name, std::string desc)
60-
: name(std::move(name)), description(std::move(desc)) {}
74+
Instruction(InstructionKind kind, std::string desc, InstructionScope scope)
75+
: instKind(kind), description(std::move(desc)), scope(scope) {}
6176

6277
virtual ~Instruction() = default;
6378
// Get methods
64-
std::string getName() { return name; }
79+
InstructionKind getInstructionKind() { return instKind; }
6580
std::string getDescription() { return description; }
6681
InstructionScope getScope() { return scope; }
6782

6883
protected:
69-
std::string name;
84+
InstructionKind instKind;
7085
std::string description;
7186
InstructionScope scope;
7287
};
@@ -78,23 +93,25 @@ enum class RegisterFileType : uint8_t { GRF, ARF };
7893
struct RegisterFileInfo {
7994
// Constructor
8095
RegisterFileInfo() = default;
81-
RegisterFileInfo(uint32_t size, const std::vector<RegisterFileMode> &mode,
82-
const std::vector<uint32_t> &numRegs)
96+
RegisterFileInfo(uint32_t size,
97+
const llvm::SmallVector<RegisterFileMode, 4> &mode,
98+
const llvm::SmallVector<uint32_t, 4> &numRegs)
8399
: size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
84100

85101
uint32_t getSize() const { return size; }
86-
const std::vector<RegisterFileMode> &getModes() const { return mode; }
87-
const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
102+
const llvm::SmallVector<RegisterFileMode, 4> &getModes() const {
103+
return mode;
104+
}
105+
const llvm::SmallVector<uint32_t, 4> &getNumRegsPerThreadPerMode() const {
88106
return numRegsPerThreadPerMode;
89107
}
90108

91109
protected:
92-
uint32_t size; // size per register in bits
93-
std::vector<RegisterFileMode> mode; // e.g., "small", "large" GRF modes
94-
std::vector<uint32_t>
110+
uint32_t size; // size per register in bits
111+
llvm::SmallVector<RegisterFileMode, 4>
112+
mode; // e.g., "small", "large" GRF modes
113+
llvm::SmallVector<uint32_t, 4>
95114
numRegsPerThreadPerMode; // number of registers per thread per mode
96-
// TODO: Add more fields as needed (e.g., num_banks, bank_size, num_ports,
97-
// port_width, bank_conflicts)
98115
};
99116

100117
enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
@@ -136,8 +153,8 @@ struct uArch {
136153
uArch(const std::string &name, const std::string &description,
137154
const std::map<RegisterFileType, RegisterFileInfo> &register_file_info =
138155
{},
139-
const std::vector<CacheInfo> &cache_info = {},
140-
const std::map<std::string, std::shared_ptr<Instruction>>
156+
const llvm::SmallVector<CacheInfo, 4> &cache_info = {},
157+
const std::map<InstructionKind, std::shared_ptr<Instruction>>
141158
&instructions = {})
142159
: name(name), description(description),
143160
registerFileInfo(register_file_info), cacheInfo(cache_info),
@@ -153,34 +170,36 @@ struct uArch {
153170
return registerFileInfo;
154171
}
155172

156-
const std::vector<CacheInfo> &getCacheInfo() const { return cacheInfo; }
173+
const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
174+
return cacheInfo;
175+
}
157176

158-
const std::map<std::string, std::shared_ptr<Instruction>> &
177+
const std::map<InstructionKind, std::shared_ptr<Instruction>> &
159178
getInstructions() const {
160179
return instructions;
161180
}
162181

163182
// Get the name of the supported instruction names for that
164183
// architecture. It returns the names of the instructions added to the uArch.
165-
std::vector<std::string> getSupportedInstructionNames() const {
166-
std::vector<std::string> instructionNames;
184+
llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
185+
llvm::SmallVector<StringRef, 8> instructionNames;
167186
for (const auto &inst : instructions) {
168-
instructionNames.push_back(inst.first);
187+
instructionNames.push_back(toString(inst.first));
169188
}
170189
return instructionNames;
171190
}
172191

173192
// Checks if an instruction is supported in this uArch
174-
bool checkSupportedInstruction(const std::string &instructionName) const {
175-
return instructions.find(instructionName) != instructions.end();
193+
bool checkSupportedInstruction(InstructionKind instr) const {
194+
return instructions.find(instr) != instructions.end();
176195
}
177196

178197
protected:
179198
std::string name; // Similar to target triple
180199
std::string description;
181200
std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
182-
std::vector<CacheInfo> cacheInfo;
183-
std::map<std::string, std::shared_ptr<Instruction>> instructions;
201+
llvm::SmallVector<CacheInfo, 4> cacheInfo;
202+
std::map<InstructionKind, std::shared_ptr<Instruction>> instructions;
184203
};
185204

186205
// A struct to represent shared memory information
@@ -205,7 +224,7 @@ struct SharedMemory {
205224
enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
206225
struct MMAInstructionInterface {
207226
// Get supported Matrix shapes
208-
virtual std::vector<std::pair<uint32_t, uint32_t>>
227+
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
209228
getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
210229
// @TODO: This method takes an context object as a parameter, this is to
211230
// create the Type objects from the same context. Since type objects are
@@ -220,8 +239,8 @@ struct MMAInstructionInterface {
220239
//
221240
// Untill we have a better solution, we stick to passing context object to
222241
// this method.
223-
virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
224-
MMAOpndKind matrixType) = 0;
242+
virtual llvm::SmallVector<Type, 8>
243+
getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
225244
virtual bool
226245
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
227246
std::pair<uint32_t, uint32_t> BShape,
@@ -235,9 +254,9 @@ struct MMAInstructionInterface {
235254
std::pair<uint32_t, uint32_t> CShape,
236255
std::pair<uint32_t, uint32_t> DShape, Type AType,
237256
Type BType, Type CType, Type DType) = 0;
238-
virtual std::vector<uint32_t> getSupportedM(Type type) = 0;
239-
virtual std::vector<uint32_t> getSupportedK(Type type) = 0;
240-
virtual std::vector<uint32_t> getSupportedN(Type type) = 0;
257+
virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
258+
virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
259+
virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
241260

242261
virtual ~MMAInstructionInterface() = default;
243262
};

0 commit comments

Comments
 (0)