18
18
#include " mlir/Dialect/XeGPU/uArch/uArchBase.h"
19
19
#include " mlir/IR/BuiltinTypes.h"
20
20
#include " mlir/IR/TypeUtilities.h"
21
+ #include " llvm/ADT/SmallVector.h"
21
22
#include " llvm/Support/DebugLog.h"
22
23
#include < map>
23
24
#include < string>
24
- #include < vector>
25
25
26
26
#define DEBUG_TYPE " xegpu-uarch"
27
27
@@ -47,28 +47,29 @@ struct XeCoreInfo {
47
47
48
48
struct Xe2Plus : public uArch {
49
49
XeCoreInfo xe_core;
50
- Xe2Plus (
51
- const std::string &archName, const std::string &archDescription ,
52
- const XeCoreInfo &xeCore ,
53
- const std::map<RegisterFileType, RegisterFileInfo > ®Info = {},
54
- const std::vector<CacheInfo> &cacheInfo = {},
55
- const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
50
+ Xe2Plus (const std::string &archName, const std::string &archDescription,
51
+ const XeCoreInfo &xeCore ,
52
+ const std::map<RegisterFileType, RegisterFileInfo> ®Info = {} ,
53
+ const llvm::SmallVector<CacheInfo, 4 > &cacheInfo = {},
54
+ const std::map<InstructionKind, std::shared_ptr<Instruction>>
55
+ &instrs = {})
56
56
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
57
57
xe_core (xeCore) {}
58
58
};
59
59
60
60
// struct to represent DPAS instruction
61
61
struct DPASInstruction : public Instruction , public MMAInstructionInterface {
62
62
DPASInstruction ()
63
- : Instruction(" dpas" , // name
64
- " Dot Product Accumulate" ) // description
63
+ : Instruction(InstructionKind::DPAS, // name
64
+ " Dot Product Accumulate" ,
65
+ InstructionScope::Subgroup) // description
65
66
{}
66
67
67
68
// Override all virtuals from MatrixOpInterface
68
- virtual std::vector <std::pair<uint32_t , uint32_t >>
69
+ virtual llvm::SmallVector <std::pair<uint32_t , uint32_t >, 16 >
69
70
getSupportedShapes (Type dataType, MMAOpndKind matrixType) override ;
70
- virtual std::vector <Type> getSupportedTypes (MLIRContext &context,
71
- MMAOpndKind matrixType) override ;
71
+ virtual llvm::SmallVector <Type, 8 >
72
+ getSupportedTypes (MLIRContext &context, MMAOpndKind matrixType) override ;
72
73
virtual bool
73
74
checkSupportedShapesAndTypes (std::pair<uint32_t , uint32_t > AShape,
74
75
std::pair<uint32_t , uint32_t > BShape,
@@ -82,14 +83,14 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
82
83
std::pair<uint32_t , uint32_t > CShape,
83
84
std::pair<uint32_t , uint32_t > DShape, Type AType,
84
85
Type BType, Type CType, Type DType) override ;
85
- virtual std::vector <uint32_t > getSupportedM (Type type) override ;
86
- virtual std::vector <uint32_t > getSupportedK (Type type) override ;
87
- virtual std::vector <uint32_t > getSupportedN (Type type) override ;
86
+ virtual llvm::SmallVector <uint32_t , 8 > getSupportedM (Type type) override ;
87
+ virtual llvm::SmallVector <uint32_t , 8 > getSupportedK (Type type) override ;
88
+ virtual llvm::SmallVector <uint32_t , 8 > getSupportedN (Type type) override ;
88
89
};
89
90
90
91
struct PVCuArch : public Xe2Plus {
91
92
// Maintaines ownership of the instructions owned by PVUarch
92
- std::vector <std::shared_ptr<Instruction>> owned_instructions;
93
+ llvm::SmallVector <std::shared_ptr<Instruction>, 8 > owned_instructions;
93
94
PVCuArch ()
94
95
: Xe2Plus(" pvc" , // archName
95
96
" Ponte Vecchio Architecture" , // archDescription
@@ -115,17 +116,16 @@ struct PVCuArch : public Xe2Plus {
115
116
this ->cacheInfo .push_back (
116
117
CacheInfo (512 * 1024 , 64 , CacheHierarchyLevel::L2));
117
118
118
- // Add the instructions
119
+ // Add the instructions-
119
120
auto dpas = std::make_shared<DPASInstruction>();
120
- instructions.emplace (dpas->getName (), dpas);
121
- // instructions[dpas->name] = dpas.get();
121
+ instructions.emplace (dpas->getInstructionKind (), dpas);
122
122
owned_instructions.push_back (dpas);
123
123
}
124
124
};
125
125
126
126
struct BMGuArch : public Xe2Plus {
127
127
// Maintaines ownership of the instructions owned by PVUarch
128
- std::vector <std::shared_ptr<Instruction>> owned_instructions;
128
+ llvm::SmallVector <std::shared_ptr<Instruction>, 8 > owned_instructions;
129
129
BMGuArch ()
130
130
: Xe2Plus(" bmg" , // archName
131
131
" Battlemage Architecture" , // archDescription
@@ -151,21 +151,20 @@ struct BMGuArch : public Xe2Plus {
151
151
152
152
// Add the instructions
153
153
auto dpas = std::make_shared<DPASInstruction>();
154
- instructions.emplace (dpas->getName (), dpas);
155
- // instructions[dpas->name] = dpas.get();
154
+ instructions.emplace (dpas->getInstructionKind (), dpas);
156
155
owned_instructions.push_back (dpas);
157
156
}
158
157
};
159
158
} // namespace uArch
160
159
} // namespace xegpu
161
160
} // namespace mlir
162
161
163
- inline std::vector <std::pair<uint32_t , uint32_t >>
162
+ inline llvm::SmallVector <std::pair<uint32_t , uint32_t >, 16 >
164
163
DPASInstruction::getSupportedShapes (Type dataType, MMAOpndKind matrixType) {
165
- auto combineVectors = [](const std::vector <uint32_t > &a,
166
- const std::vector <uint32_t > &b)
167
- -> std::vector <std::pair<uint32_t , uint32_t >> {
168
- std::vector <std::pair<uint32_t , uint32_t >> result;
164
+ auto combineVectors = [](const llvm::SmallVector <uint32_t , 8 > &a,
165
+ const llvm::SmallVector <uint32_t , 8 > &b)
166
+ -> llvm::SmallVector <std::pair<uint32_t , uint32_t >, 16 > {
167
+ llvm::SmallVector <std::pair<uint32_t , uint32_t >, 16 > result;
169
168
for (unsigned x : a) {
170
169
for (unsigned y : b) {
171
170
result.emplace_back (x, y);
@@ -177,7 +176,7 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
177
176
auto M = getSupportedM (dataType);
178
177
auto K = getSupportedK (dataType);
179
178
auto N = getSupportedN (dataType);
180
- std::vector <std::pair<unsigned , unsigned >> resultMatrix;
179
+ llvm::SmallVector <std::pair<unsigned , unsigned >, 16 > resultMatrix;
181
180
182
181
switch (matrixType) {
183
182
case MMAOpndKind::MatrixA:
@@ -196,7 +195,7 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
196
195
return resultMatrix;
197
196
}
198
197
199
- inline std::vector <Type>
198
+ inline llvm::SmallVector <Type, 8 >
200
199
DPASInstruction::getSupportedTypes (MLIRContext &context,
201
200
MMAOpndKind matrixType) {
202
201
Type bf16Type = BFloat16Type::get (&context);
@@ -207,17 +206,14 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
207
206
switch (matrixType) {
208
207
case MMAOpndKind::MatrixA:
209
208
return {bf16Type, f16Type, tf32Type};
210
- break ;
211
209
case MMAOpndKind::MatrixB:
212
210
return {bf16Type, f16Type, tf32Type};
213
- break ;
214
211
case MMAOpndKind::MatrixC:
215
212
return {bf16Type, f16Type, f32Type};
216
- break ;
217
213
case MMAOpndKind::MatrixD:
218
214
return {bf16Type, f16Type, f32Type};
219
- break ;
220
215
}
216
+ return {};
221
217
}
222
218
223
219
inline bool DPASInstruction::checkSupportedTypes (Type AType, Type BType,
@@ -276,11 +272,13 @@ inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
276
272
BType, CType, DType);
277
273
}
278
274
279
- inline std::vector<uint32_t > DPASInstruction::getSupportedM (Type type) {
275
+ inline llvm::SmallVector<uint32_t , 8 >
276
+ DPASInstruction::getSupportedM (Type type) {
280
277
return {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
281
278
}
282
279
283
- inline std::vector<uint32_t > DPASInstruction::getSupportedK (Type type) {
280
+ inline llvm::SmallVector<uint32_t , 8 >
281
+ DPASInstruction::getSupportedK (Type type) {
284
282
// assert if data type is not int or float type
285
283
assert (type.isIntOrFloat () && " Matrix type must be int or float" );
286
284
auto bitWidth = type.getIntOrFloatBitWidth ();
@@ -307,8 +305,9 @@ inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
307
305
return {kSize };
308
306
}
309
307
310
- inline std::vector<uint32_t > DPASInstruction::getSupportedN (Type type) {
308
+ inline llvm::SmallVector<uint32_t , 8 >
309
+ DPASInstruction::getSupportedN (Type type) {
311
310
return {16 };
312
311
}
313
312
314
- #endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2H
313
+ #endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
0 commit comments