|
15 | 15 | #ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
|
16 | 16 | #define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
|
17 | 17 |
|
18 |
| -#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h" |
| 18 | +#include "mlir/Dialect/XeGPU/uArch/uArchBase.h" |
19 | 19 | #include "mlir/IR/BuiltinTypes.h"
|
20 | 20 | #include "mlir/IR/TypeUtilities.h"
|
| 21 | +#include "llvm/Support/DebugLog.h" |
21 | 22 | #include <map>
|
22 | 23 | #include <string>
|
23 | 24 | #include <vector>
|
24 | 25 |
|
| 26 | +#define DEBUG_TYPE "xegpu-uarch" |
| 27 | + |
| 28 | +using namespace mlir; |
| 29 | +using namespace mlir::xegpu::uArch; |
| 30 | + |
25 | 31 | namespace mlir {
|
26 | 32 | namespace xegpu {
|
27 | 33 | namespace uArch {
|
28 |
| -namespace Xe2Plus { |
29 | 34 | struct XeCoreInfo {
|
30 | 35 | uint32_t num_threads;
|
31 | 36 | SharedMemory shared_memory;
|
@@ -61,30 +66,27 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
|
61 | 66 |
|
62 | 67 | // Override all virtuals from MatrixOpInterface
|
63 | 68 | virtual std::vector<std::pair<uint32_t, uint32_t>>
|
64 |
| - getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) override; |
65 |
| - virtual std::vector<mlir::Type> |
66 |
| - getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override; |
| 69 | + getSupportedShapes(Type dataType, MMAOpndKind matrixType) override; |
| 70 | + virtual std::vector<Type> getSupportedTypes(MLIRContext &context, |
| 71 | + MMAOpndKind matrixType) override; |
67 | 72 | virtual bool
|
68 | 73 | checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
|
69 | 74 | std::pair<uint32_t, uint32_t> BShape,
|
70 | 75 | std::pair<uint32_t, uint32_t> CShape,
|
71 |
| - std::pair<uint32_t, uint32_t> DShape, |
72 |
| - mlir::Type AType, mlir::Type BType, |
73 |
| - mlir::Type CType, mlir::Type DType) override; |
74 |
| - virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType, |
75 |
| - mlir::Type CType, mlir::Type DType) override; |
| 76 | + std::pair<uint32_t, uint32_t> DShape, Type AType, |
| 77 | + Type BType, Type CType, Type DType) override; |
| 78 | + virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, |
| 79 | + Type DType) override; |
76 | 80 | virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
|
77 | 81 | std::pair<uint32_t, uint32_t> BShape,
|
78 | 82 | std::pair<uint32_t, uint32_t> CShape,
|
79 |
| - std::pair<uint32_t, uint32_t> DShape, mlir::Type AType, |
80 |
| - mlir::Type BType, mlir::Type CType, |
81 |
| - mlir::Type DType) override; |
82 |
| - virtual std::vector<uint32_t> getSupportedM(mlir::Type type) override; |
83 |
| - virtual std::vector<uint32_t> getSupportedK(mlir::Type type) override; |
84 |
| - virtual std::vector<uint32_t> getSupportedN(mlir::Type type) override; |
| 83 | + std::pair<uint32_t, uint32_t> DShape, Type AType, |
| 84 | + Type BType, Type CType, Type DType) override; |
| 85 | + virtual std::vector<uint32_t> getSupportedM(Type type) override; |
| 86 | + virtual std::vector<uint32_t> getSupportedK(Type type) override; |
| 87 | + virtual std::vector<uint32_t> getSupportedN(Type type) override; |
85 | 88 | };
|
86 | 89 |
|
87 |
| -namespace PVCuArch { |
88 | 90 | struct PVCuArch : public Xe2Plus {
|
89 | 91 | // Maintaines ownership of the instructions owned by PVUarch
|
90 | 92 | std::vector<std::shared_ptr<Instruction>> owned_instructions;
|
@@ -120,9 +122,7 @@ struct PVCuArch : public Xe2Plus {
|
120 | 122 | owned_instructions.push_back(dpas);
|
121 | 123 | }
|
122 | 124 | };
|
123 |
| -} // namespace PVCuArch |
124 | 125 |
|
125 |
| -namespace BMGuArch { |
126 | 126 | struct BMGuArch : public Xe2Plus {
|
127 | 127 | // Maintaines ownership of the instructions owned by PVUarch
|
128 | 128 | std::vector<std::shared_ptr<Instruction>> owned_instructions;
|
@@ -156,11 +156,159 @@ struct BMGuArch : public Xe2Plus {
|
156 | 156 | owned_instructions.push_back(dpas);
|
157 | 157 | }
|
158 | 158 | };
|
159 |
| -} // namespace BMGuArch |
160 |
| - |
161 |
| -} // namespace Xe2Plus |
162 | 159 | } // namespace uArch
|
163 | 160 | } // namespace xegpu
|
164 | 161 | } // namespace mlir
|
165 | 162 |
|
| 163 | +inline std::vector<std::pair<uint32_t, uint32_t>> |
| 164 | +DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) { |
| 165 | + auto combineVectors = [](const std::vector<uint32_t> &a, |
| 166 | + const std::vector<uint32_t> &b) |
| 167 | + -> std::vector<std::pair<uint32_t, uint32_t>> { |
| 168 | + std::vector<std::pair<uint32_t, uint32_t>> result; |
| 169 | + for (unsigned x : a) { |
| 170 | + for (unsigned y : b) { |
| 171 | + result.emplace_back(x, y); |
| 172 | + } |
| 173 | + } |
| 174 | + return result; |
| 175 | + }; |
| 176 | + |
| 177 | + auto M = getSupportedM(dataType); |
| 178 | + auto K = getSupportedK(dataType); |
| 179 | + auto N = getSupportedN(dataType); |
| 180 | + std::vector<std::pair<unsigned, unsigned>> resultMatrix; |
| 181 | + |
| 182 | + switch (matrixType) { |
| 183 | + case MMAOpndKind::MatrixA: |
| 184 | + resultMatrix = combineVectors(M, K); |
| 185 | + break; |
| 186 | + case MMAOpndKind::MatrixB: |
| 187 | + resultMatrix = combineVectors(K, N); |
| 188 | + break; |
| 189 | + case MMAOpndKind::MatrixC: |
| 190 | + resultMatrix = combineVectors(M, N); |
| 191 | + break; |
| 192 | + case MMAOpndKind::MatrixD: |
| 193 | + resultMatrix = combineVectors(M, N); |
| 194 | + break; |
| 195 | + } |
| 196 | + return resultMatrix; |
| 197 | +} |
| 198 | + |
| 199 | +inline std::vector<Type> |
| 200 | +DPASInstruction::getSupportedTypes(MLIRContext &context, |
| 201 | + MMAOpndKind matrixType) { |
| 202 | + Type bf16Type = BFloat16Type::get(&context); |
| 203 | + Type f16Type = Float16Type::get(&context); |
| 204 | + Type tf32Type = FloatTF32Type::get(&context); |
| 205 | + Type f32Type = Float32Type::get(&context); |
| 206 | + |
| 207 | + switch (matrixType) { |
| 208 | + case MMAOpndKind::MatrixA: |
| 209 | + return {bf16Type, f16Type, tf32Type}; |
| 210 | + break; |
| 211 | + case MMAOpndKind::MatrixB: |
| 212 | + return {bf16Type, f16Type, tf32Type}; |
| 213 | + break; |
| 214 | + case MMAOpndKind::MatrixC: |
| 215 | + return {bf16Type, f16Type, f32Type}; |
| 216 | + break; |
| 217 | + case MMAOpndKind::MatrixD: |
| 218 | + return {bf16Type, f16Type, f32Type}; |
| 219 | + break; |
| 220 | + } |
| 221 | +} |
| 222 | + |
| 223 | +inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType, |
| 224 | + Type CType, Type DType) { |
| 225 | + if (AType.isF16() || BType.isF16()) { |
| 226 | + if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) || |
| 227 | + (!DType.isF32() && !DType.isF16())) { |
| 228 | + LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices."; |
| 229 | + return false; |
| 230 | + } |
| 231 | + } else if (AType.isBF16() || BType.isBF16()) { |
| 232 | + if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) || |
| 233 | + (!DType.isF32() && !DType.isBF16())) { |
| 234 | + LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices."; |
| 235 | + return false; |
| 236 | + } |
| 237 | + } else if (AType.isTF32() || BType.isTF32()) { |
| 238 | + if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) || |
| 239 | + (!DType.isF32())) { |
| 240 | + LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices."; |
| 241 | + return false; |
| 242 | + } |
| 243 | + } else if (!(AType.isInteger(2) || AType.isInteger(4) || |
| 244 | + AType.isInteger(8)) && |
| 245 | + !(BType.isInteger(2) || BType.isInteger(4) || |
| 246 | + BType.isInteger(8))) { |
| 247 | + LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices."; |
| 248 | + return false; |
| 249 | + } |
| 250 | + |
| 251 | + return true; |
| 252 | +} |
| 253 | + |
| 254 | +inline bool DPASInstruction::checkSupportedShapesAndTypes( |
| 255 | + std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape, |
| 256 | + std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape, |
| 257 | + Type AType, Type BType, Type CType, Type DType) { |
| 258 | + auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA); |
| 259 | + auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB); |
| 260 | + auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC); |
| 261 | + auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD); |
| 262 | + return llvm::is_contained(supportedAShapes, AShape) && |
| 263 | + llvm::is_contained(supportedBShapes, BShape) && |
| 264 | + llvm::is_contained(supportedCShapes, CShape) && |
| 265 | + llvm::is_contained(supportedDShapes, DShape) && |
| 266 | + checkSupportedTypes(AType, BType, CType, DType); |
| 267 | +} |
| 268 | + |
| 269 | +inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape, |
| 270 | + std::pair<uint32_t, uint32_t> BShape, |
| 271 | + std::pair<uint32_t, uint32_t> CShape, |
| 272 | + std::pair<uint32_t, uint32_t> DShape, |
| 273 | + Type AType, Type BType, Type CType, |
| 274 | + Type DType) { |
| 275 | + return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType, |
| 276 | + BType, CType, DType); |
| 277 | +} |
| 278 | + |
| 279 | +inline std::vector<uint32_t> DPASInstruction::getSupportedM(Type type) { |
| 280 | + return {1, 2, 3, 4, 5, 6, 7, 8}; |
| 281 | +} |
| 282 | + |
| 283 | +inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) { |
| 284 | + // assert if data type is not int or float type |
| 285 | + assert(type.isIntOrFloat() && "Matrix type must be int or float"); |
| 286 | + auto bitWidth = type.getIntOrFloatBitWidth(); |
| 287 | + uint32_t kSize = 0; |
| 288 | + switch (bitWidth) { |
| 289 | + case 2: |
| 290 | + kSize = 64; |
| 291 | + break; |
| 292 | + case 4: |
| 293 | + kSize = 64; |
| 294 | + break; |
| 295 | + case 8: |
| 296 | + kSize = 32; |
| 297 | + break; |
| 298 | + case 16: |
| 299 | + kSize = 16; |
| 300 | + break; |
| 301 | + case 32: |
| 302 | + kSize = 8; |
| 303 | + break; |
| 304 | + default: |
| 305 | + llvm_unreachable("Invalid int or float"); |
| 306 | + } |
| 307 | + return {kSize}; |
| 308 | +} |
| 309 | + |
| 310 | +inline std::vector<uint32_t> DPASInstruction::getSupportedN(Type type) { |
| 311 | + return {16}; |
| 312 | +} |
| 313 | + |
166 | 314 | #endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2H
|
0 commit comments