Skip to content

Commit 82737ce

Browse files
committed
Address review comments.
Move all the implementation to the .h file. Move uArchInterfaces to uArchBase.
1 parent 22dbba0 commit 82737ce

File tree

10 files changed

+213
-284
lines changed

10 files changed

+213
-284
lines changed

mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h

Lines changed: 170 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,22 @@
1515
#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
1616
#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
1717

18-
#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
18+
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
1919
#include "mlir/IR/BuiltinTypes.h"
2020
#include "mlir/IR/TypeUtilities.h"
21+
#include "llvm/Support/DebugLog.h"
2122
#include <map>
2223
#include <string>
2324
#include <vector>
2425

26+
#define DEBUG_TYPE "xegpu-uarch"
27+
28+
using namespace mlir;
29+
using namespace mlir::xegpu::uArch;
30+
2531
namespace mlir {
2632
namespace xegpu {
2733
namespace uArch {
28-
namespace Xe2Plus {
2934
struct XeCoreInfo {
3035
uint32_t num_threads;
3136
SharedMemory shared_memory;
@@ -61,30 +66,27 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
6166

6267
// Override all virtuals from MatrixOpInterface
6368
virtual std::vector<std::pair<uint32_t, uint32_t>>
64-
getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) override;
65-
virtual std::vector<mlir::Type>
66-
getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
69+
getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
70+
virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
71+
MMAOpndKind matrixType) override;
6772
virtual bool
6873
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
6974
std::pair<uint32_t, uint32_t> BShape,
7075
std::pair<uint32_t, uint32_t> CShape,
71-
std::pair<uint32_t, uint32_t> DShape,
72-
mlir::Type AType, mlir::Type BType,
73-
mlir::Type CType, mlir::Type DType) override;
74-
virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
75-
mlir::Type CType, mlir::Type DType) override;
76+
std::pair<uint32_t, uint32_t> DShape, Type AType,
77+
Type BType, Type CType, Type DType) override;
78+
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
79+
Type DType) override;
7680
virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
7781
std::pair<uint32_t, uint32_t> BShape,
7882
std::pair<uint32_t, uint32_t> CShape,
79-
std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
80-
mlir::Type BType, mlir::Type CType,
81-
mlir::Type DType) override;
82-
virtual std::vector<uint32_t> getSupportedM(mlir::Type type) override;
83-
virtual std::vector<uint32_t> getSupportedK(mlir::Type type) override;
84-
virtual std::vector<uint32_t> getSupportedN(mlir::Type type) override;
83+
std::pair<uint32_t, uint32_t> DShape, Type AType,
84+
Type BType, Type CType, Type DType) override;
85+
virtual std::vector<uint32_t> getSupportedM(Type type) override;
86+
virtual std::vector<uint32_t> getSupportedK(Type type) override;
87+
virtual std::vector<uint32_t> getSupportedN(Type type) override;
8588
};
8689

87-
namespace PVCuArch {
8890
struct PVCuArch : public Xe2Plus {
8991
// Maintaines ownership of the instructions owned by PVUarch
9092
std::vector<std::shared_ptr<Instruction>> owned_instructions;
@@ -120,9 +122,7 @@ struct PVCuArch : public Xe2Plus {
120122
owned_instructions.push_back(dpas);
121123
}
122124
};
123-
} // namespace PVCuArch
124125

125-
namespace BMGuArch {
126126
struct BMGuArch : public Xe2Plus {
127127
// Maintaines ownership of the instructions owned by PVUarch
128128
std::vector<std::shared_ptr<Instruction>> owned_instructions;
@@ -156,11 +156,159 @@ struct BMGuArch : public Xe2Plus {
156156
owned_instructions.push_back(dpas);
157157
}
158158
};
159-
} // namespace BMGuArch
160-
161-
} // namespace Xe2Plus
162159
} // namespace uArch
163160
} // namespace xegpu
164161
} // namespace mlir
165162

163+
inline std::vector<std::pair<uint32_t, uint32_t>>
164+
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
165+
auto combineVectors = [](const std::vector<uint32_t> &a,
166+
const std::vector<uint32_t> &b)
167+
-> std::vector<std::pair<uint32_t, uint32_t>> {
168+
std::vector<std::pair<uint32_t, uint32_t>> result;
169+
for (unsigned x : a) {
170+
for (unsigned y : b) {
171+
result.emplace_back(x, y);
172+
}
173+
}
174+
return result;
175+
};
176+
177+
auto M = getSupportedM(dataType);
178+
auto K = getSupportedK(dataType);
179+
auto N = getSupportedN(dataType);
180+
std::vector<std::pair<unsigned, unsigned>> resultMatrix;
181+
182+
switch (matrixType) {
183+
case MMAOpndKind::MatrixA:
184+
resultMatrix = combineVectors(M, K);
185+
break;
186+
case MMAOpndKind::MatrixB:
187+
resultMatrix = combineVectors(K, N);
188+
break;
189+
case MMAOpndKind::MatrixC:
190+
resultMatrix = combineVectors(M, N);
191+
break;
192+
case MMAOpndKind::MatrixD:
193+
resultMatrix = combineVectors(M, N);
194+
break;
195+
}
196+
return resultMatrix;
197+
}
198+
199+
inline std::vector<Type>
200+
DPASInstruction::getSupportedTypes(MLIRContext &context,
201+
MMAOpndKind matrixType) {
202+
Type bf16Type = BFloat16Type::get(&context);
203+
Type f16Type = Float16Type::get(&context);
204+
Type tf32Type = FloatTF32Type::get(&context);
205+
Type f32Type = Float32Type::get(&context);
206+
207+
switch (matrixType) {
208+
case MMAOpndKind::MatrixA:
209+
return {bf16Type, f16Type, tf32Type};
210+
break;
211+
case MMAOpndKind::MatrixB:
212+
return {bf16Type, f16Type, tf32Type};
213+
break;
214+
case MMAOpndKind::MatrixC:
215+
return {bf16Type, f16Type, f32Type};
216+
break;
217+
case MMAOpndKind::MatrixD:
218+
return {bf16Type, f16Type, f32Type};
219+
break;
220+
}
221+
}
222+
223+
inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
224+
Type CType, Type DType) {
225+
if (AType.isF16() || BType.isF16()) {
226+
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
227+
(!DType.isF32() && !DType.isF16())) {
228+
LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
229+
return false;
230+
}
231+
} else if (AType.isBF16() || BType.isBF16()) {
232+
if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
233+
(!DType.isF32() && !DType.isBF16())) {
234+
LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
235+
return false;
236+
}
237+
} else if (AType.isTF32() || BType.isTF32()) {
238+
if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
239+
(!DType.isF32())) {
240+
LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
241+
return false;
242+
}
243+
} else if (!(AType.isInteger(2) || AType.isInteger(4) ||
244+
AType.isInteger(8)) &&
245+
!(BType.isInteger(2) || BType.isInteger(4) ||
246+
BType.isInteger(8))) {
247+
LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
248+
return false;
249+
}
250+
251+
return true;
252+
}
253+
254+
inline bool DPASInstruction::checkSupportedShapesAndTypes(
255+
std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
256+
std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
257+
Type AType, Type BType, Type CType, Type DType) {
258+
auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
259+
auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
260+
auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
261+
auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
262+
return llvm::is_contained(supportedAShapes, AShape) &&
263+
llvm::is_contained(supportedBShapes, BShape) &&
264+
llvm::is_contained(supportedCShapes, CShape) &&
265+
llvm::is_contained(supportedDShapes, DShape) &&
266+
checkSupportedTypes(AType, BType, CType, DType);
267+
}
268+
269+
inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
270+
std::pair<uint32_t, uint32_t> BShape,
271+
std::pair<uint32_t, uint32_t> CShape,
272+
std::pair<uint32_t, uint32_t> DShape,
273+
Type AType, Type BType, Type CType,
274+
Type DType) {
275+
return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
276+
BType, CType, DType);
277+
}
278+
279+
inline std::vector<uint32_t> DPASInstruction::getSupportedM(Type type) {
280+
return {1, 2, 3, 4, 5, 6, 7, 8};
281+
}
282+
283+
inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
284+
// assert if data type is not int or float type
285+
assert(type.isIntOrFloat() && "Matrix type must be int or float");
286+
auto bitWidth = type.getIntOrFloatBitWidth();
287+
uint32_t kSize = 0;
288+
switch (bitWidth) {
289+
case 2:
290+
kSize = 64;
291+
break;
292+
case 4:
293+
kSize = 64;
294+
break;
295+
case 8:
296+
kSize = 32;
297+
break;
298+
case 16:
299+
kSize = 16;
300+
break;
301+
case 32:
302+
kSize = 8;
303+
break;
304+
default:
305+
llvm_unreachable("Invalid int or float");
306+
}
307+
return {kSize};
308+
}
309+
310+
inline std::vector<uint32_t> DPASInstruction::getSupportedN(Type type) {
311+
return {16};
312+
}
313+
166314
#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2H

mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,49 @@ struct SharedMemory {
199199
// @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
200200
};
201201

202+
//===----------------------------------------------------------------------===//
203+
// Interfaces
204+
//===----------------------------------------------------------------------===//
205+
enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
206+
struct MMAInstructionInterface {
207+
// Get supported Matrix shapes
208+
virtual std::vector<std::pair<uint32_t, uint32_t>>
209+
getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
210+
// @TODO: This method takes an context object as a parameter, this is to
211+
// create the Type objects from the same context. Since type objects are
212+
// uniqued in a specific context, to do things like "aType == bType" (where
213+
// aType and bType are both same type) kind of checks, the both types should
214+
// be from the same context.
215+
//
216+
// One alternative to this is to create enum to represent each types, but this
217+
// adds an extra burden to user to convert these enums to specific types. In
218+
// fact the utility that would convert enumToType() and vice versa would still
219+
// have to use the context object.
220+
//
221+
// Untill we have a better solution, we stick to passing context object to
222+
// this method.
223+
virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
224+
MMAOpndKind matrixType) = 0;
225+
virtual bool
226+
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
227+
std::pair<uint32_t, uint32_t> BShape,
228+
std::pair<uint32_t, uint32_t> CShape,
229+
std::pair<uint32_t, uint32_t> DShape, Type AType,
230+
Type BType, Type CType, Type DType) = 0;
231+
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
232+
Type DType) = 0;
233+
virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
234+
std::pair<uint32_t, uint32_t> BShape,
235+
std::pair<uint32_t, uint32_t> CShape,
236+
std::pair<uint32_t, uint32_t> DShape, Type AType,
237+
Type BType, Type CType, Type DType) = 0;
238+
virtual std::vector<uint32_t> getSupportedM(Type type) = 0;
239+
virtual std::vector<uint32_t> getSupportedK(Type type) = 0;
240+
virtual std::vector<uint32_t> getSupportedN(Type type) = 0;
241+
242+
virtual ~MMAInstructionInterface() = default;
243+
};
244+
202245
} // namespace uArch
203246
} // namespace xegpu
204247
} // namespace mlir

mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h

Lines changed: 0 additions & 74 deletions
This file was deleted.

mlir/lib/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,5 @@ add_mlir_dialect_library(MLIRXeVMDialect
128128
MLIRDialectUtils
129129
MLIRIR
130130
MLIRLLVMDialect
131-
MLIRXeGPUuArch
132131
MLIRSideEffectInterfaces
133132
)

mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
99
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
1010
#include "mlir/Dialect/Utils/StaticValueUtils.h"
11-
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
1211
#include "mlir/IR/DialectImplementation.h"
1312
#include "llvm/ADT/TypeSwitch.h"
1413
#include "llvm/Support/FileSystem.h"
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3-
add_subdirectory(uArch)
43
add_subdirectory(Utils)

mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ add_mlir_dialect_library(MLIRXeGPUDialect
1818
MLIRArithUtils
1919
MLIRDialectUtils
2020
MLIRIR
21-
MLIRXeGPUuArch
2221
MLIRViewLikeInterface
2322
MLIRVectorDialect
2423
)

0 commit comments

Comments
 (0)