-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR][XeGPU] Improve xegpu::uArch
design
#163986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Artem Kroviakov (akroviakov) Changes: This PR improves the xegpu::uArch design. Full diff: https://github.com/llvm/llvm-project/pull/163986.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0519f7b2e277d..f264be5181b2a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -23,8 +23,6 @@
#include <map>
#include <string>
-#define DEBUG_TYPE "xegpu-uarch"
-
using namespace mlir;
using namespace mlir::xegpu::uArch;
@@ -33,21 +31,80 @@ namespace xegpu {
namespace uArch {
struct Xe2Plus : public uArch {
+ Xe2Plus(StringRef archName, StringRef archDescription,
+ llvm::ArrayRef<const Instruction *> instructionRegistry,
+ const XeCoreInfo &xeCore)
+ : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
+ int getSubgroupSize() const override { return 16; }
+ unsigned getPackedFormatBitSize() const override { return 16; }
+ unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
+
+protected:
XeCoreInfo xeCore;
- Xe2Plus(const std::string &archName, const std::string &archDescription,
- const XeCoreInfo &xeCore,
- const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
- const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
- const std::map<InstructionKind, std::shared_ptr<Instruction>>
- &instrs = {})
- : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
- xeCore(xeCore) {}
};
-// struct to represent DPAS instruction
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
+// Subgroup-scope block store instruction (InstructionKind::STORE_ND).
+struct StoreNdInstruction : public Instruction {
+  StoreNdInstruction()
+      : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
+  // the specified pointer.
+  //
+  // Returns the per-lane vector lengths listed above, in ascending order. The
+  // returned view refers to a function-local static table, so it stays valid
+  // for the lifetime of the program.
+  llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
+    static constexpr int sortedLaneVectorLengths[] = {1, 2, 4, 8};
+    return sortedLaneVectorLengths;
+  }
+};
+
+// Subgroup-scope block load instruction (InstructionKind::LOAD_ND).
+struct LoadNdInstruction : public Instruction {
+  LoadNdInstruction()
+      : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
+  // the specified pointer.
+  //
+  // Returns the per-lane vector lengths listed above, in ascending order. The
+  // returned view refers to a function-local static table, so it stays valid
+  // for the lifetime of the program.
+  llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
+    static constexpr int sortedLaneVectorLengths[] = {1, 2, 4, 8};
+    return sortedLaneVectorLengths;
+  }
+};
+
+// Subgroup-scope block prefetch instruction (InstructionKind::PREFETCH_ND).
+struct PrefetchNdInstruction : public Instruction {
+  PrefetchNdInstruction()
+      : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
+  //
+  // Returns the ascending per-lane vector lengths for the given element bit
+  // width: {1, 2, 4, 8, 16} for 8- and 16-bit elements, {1, 2, 4, 8} for 32-
+  // and 64-bit elements. Any other width is treated as unreachable.
+  llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
+    const static int narrowLengths[] = {1, 2, 4, 8, 16};
+    const static int wideLengths[] = {1, 2, 4, 8};
+
+    const bool isNarrow = elementBitwidth == 8 || elementBitwidth == 16;
+    if (isNarrow)
+      return narrowLengths;
+
+    const bool isWide = elementBitwidth == 32 || elementBitwidth == 64;
+    if (isWide)
+      return wideLengths;
+
+    llvm_unreachable("Unsupported element bitwidth");
+  }
+};
+
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
- DPASInstruction()
- : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+ DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB,
+ unsigned packedFormatBitSizeC)
+ : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
+ packedFormatBitSizeA(packedFormatBitSizeA),
+ packedFormatBitSizeB(packedFormatBitSizeB),
+ packedFormatBitSizeC(packedFormatBitSizeC) {}
+ // Source:
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
// Override all virtuals from MatrixOpInterface
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
@@ -67,82 +124,82 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedM(Type type) const override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedK(Type type) const override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedN(Type type) const override;
+
+ unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
+ unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
+ unsigned getPackedFormatBitSizeC() const { return packedFormatBitSizeC; }
+
+protected:
+ const unsigned packedFormatBitSizeA;
+ const unsigned packedFormatBitSizeB;
+ const unsigned packedFormatBitSizeC;
};
-struct PVCuArch : public Xe2Plus {
- // Maintaines ownership of the instructions owned by PVUarch
- llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+//===----------------------------------------------------------------------===//
+// uArch instances
+//===----------------------------------------------------------------------===//
+
+struct PVCuArch final : public Xe2Plus {
+  // Instruction descriptors registered for "pvc". They are static members so
+  // that the registry array below can store plain pointers to them.
+  inline static const DPASInstruction dpasInst{16, 32, 32};
+  // Must be a LoadNdInstruction: the base-class registry is keyed by
+  // getInstructionKind(), so a second StoreNdInstruction here would leave
+  // InstructionKind::LOAD_ND unregistered.
+  inline static const LoadNdInstruction loadNdInst;
+  inline static const StoreNdInstruction storeNdInst;
+  inline static const PrefetchNdInstruction prefetchNdInst;
+  inline static const Instruction *const instructionRegistryArr[] = {
+      &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
PVCuArch()
: Xe2Plus("pvc", // archName
"Ponte Vecchio Architecture", // archDescription
- XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
- {/* registerFileInfo */}, // Optional: empty
- {/* cacheInfo */}, // Optional: empty
- {/* instructions */} // Optional: empty
- ) {
- // Intialize register file info
- // GRF
- this->registerFileInfo.emplace(
- RegisterFileType::GRF,
- RegisterFileInfo(
- 64 * 1024, // size in bits
- {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
- {128, 256} // registers per thread per mode
- ));
- // Initialize cache info
- // L1 cache, XeCore level
- this->cacheInfo.push_back(
- CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
- // L2 cache, XeStack level
- this->cacheInfo.push_back(
- CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
-
- // Add the instructions-
- auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getInstructionKind(), dpas);
- owned_instructions.push_back(dpas);
+                instructionRegistryArr,
+                XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
+        ) {}
+  // Returns the single PVC description, constructed on first use as a
+  // function-local static.
+  static const uArch *getInstance() {
+    static const PVCuArch instance;
+    // Derived-to-base pointer conversion is implicit; no cast is required.
+    return &instance;
}
};
struct BMGuArch : public Xe2Plus {
- // Maintaines ownership of the instructions owned by PVUarch
- llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+  // Instruction descriptors registered for "bmg". They are static members so
+  // that the registry array below can store plain pointers to them.
+  inline static const DPASInstruction dpasInst{16, 32, 32};
+  // Must be a LoadNdInstruction: the base-class registry is keyed by
+  // getInstructionKind(), so a second StoreNdInstruction here would leave
+  // InstructionKind::LOAD_ND unregistered.
+  inline static const LoadNdInstruction loadNdInst;
+  inline static const StoreNdInstruction storeNdInst;
+  inline static const PrefetchNdInstruction prefetchNdInst;
+  inline static const Instruction *const instructionRegistryArr[] = {
+      &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
BMGuArch()
: Xe2Plus("bmg", // archName
"Battlemage Architecture", // archDescription
- XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
- {/* registerFileInfo */}, // Optional: empty
- {/* cacheInfo */}, // Optional: empty
- {/* instructions */} // Optional: empty
- ) {
- // Intialize register file info
- // GRF
- this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
- 64 * 1024, // size in bits
- {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
- {128, 256} // registers per thread per mode
- );
- // Initialize cache info
- // L1 cache, XeCore level
- this->cacheInfo.push_back(
- CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
- // L2 cache, XeStack level
- this->cacheInfo.push_back(
- CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
-
- // Add the instructions
- auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getInstructionKind(), dpas);
- owned_instructions.push_back(dpas);
+                instructionRegistryArr,
+                XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
+        ) {}
+  // Returns the single BMG description, constructed on first use as a
+  // function-local static.
+  static const uArch *getInstance() {
+    static const BMGuArch instance;
+    // Derived-to-base pointer conversion is implicit; no cast is required.
+    return &instance;
}
};
+
+// Maps an architecture name to its uArch description. The comparison ignores
+// case; names other than "pvc" and "bmg" yield nullptr.
+inline const uArch *getUArch(llvm::StringRef archName) {
+  const bool isPvc = archName.equals_insensitive("pvc");
+  if (isPvc)
+    return PVCuArch::getInstance();
+
+  const bool isBmg = archName.equals_insensitive("bmg");
+  if (isBmg)
+    return BMGuArch::getInstance();
+
+  return nullptr;
+}
+
} // namespace uArch
} // namespace xegpu
} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Instruction implementations
+//===----------------------------------------------------------------------===//
+
inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
@@ -257,12 +314,12 @@ inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedM(Type type) {
+DPASInstruction::getSupportedM(Type type) const {
return {1, 2, 3, 4, 5, 6, 7, 8};
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedK(Type type) {
+DPASInstruction::getSupportedK(Type type) const {
// assert if data type is not int or float type
assert(type.isIntOrFloat() && "Matrix type must be int or float");
auto bitWidth = type.getIntOrFloatBitWidth();
@@ -290,7 +347,7 @@ DPASInstruction::getSupportedK(Type type) {
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedN(Type type) {
+DPASInstruction::getSupportedN(Type type) const {
return {16};
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 955994ea5ecf5..91f4b9eaeaf45 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,8 +32,11 @@ namespace uArch {
// An enum class to represent the scope of an instruction
enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
enum class InstructionKind {
- DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
- // multiply-add operation
+ DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
+ // multiply-add operation
+ STORE_ND, // Subgroup-level 2D block write instruction
+ LOAD_ND, // Subgroup-level 2D block load instruction
+ PREFETCH_ND // Subgroup-level 2D block prefetch instruction
// @TODO: Add more instructions as needed
};
@@ -48,12 +51,18 @@ struct Instruction {
virtual ~Instruction() = default;
// Get methods
- InstructionKind getInstructionKind() { return instKind; }
- InstructionScope getScope() { return scope; }
+ InstructionKind getInstructionKind() const { return instKind; }
+ InstructionScope getScope() const { return scope; }
static llvm::StringRef toString(InstructionKind instKind) {
switch (instKind) {
case InstructionKind::DPAS:
return "dpas";
+ case InstructionKind::STORE_ND:
+ return "store_nd";
+ case InstructionKind::LOAD_ND:
+ return "load_nd";
+ case InstructionKind::PREFETCH_ND:
+ return "prefetch_nd";
}
llvm_unreachable("Unknown InstructionKind");
}
@@ -66,9 +75,9 @@ struct Instruction {
}
protected:
- InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
- InstructionScope scope; // scope of the instruction (e.g., lane, subgroup,
- // workgroup, cluster)
+ const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
+ const InstructionScope scope; // scope of the instruction (e.g., lane,
+ // subgroup, workgroup, cluster)
// @TODO: Add more fields as needed
};
@@ -129,61 +138,37 @@ struct CacheInfo {
// latency, throughput, bandwidth)
};
-// A struct to represent the uArch
-// This struct is used to represent the microarchitecture of a target device.
struct uArch {
// Constructor
- uArch(
- const std::string &name, const std::string &description,
- const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
- const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
- const std::map<InstructionKind, std::shared_ptr<Instruction>>
- &instructions = {})
- : name(name), description(description),
- registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
- instructions(instructions) {}
-
- // Get methods
- const std::string &getName() const { return name; }
-
- const std::string &getDescription() const { return description; }
-
- const std::map<RegisterFileType, RegisterFileInfo> &
- getRegisterFileInfo() const {
- return registerFileInfo;
- }
-
- const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
- return cacheInfo;
- }
-
- const std::map<InstructionKind, std::shared_ptr<Instruction>> &
- getInstructions() const {
- return instructions;
+  // `name` and `description` are stored as non-owning StringRefs, so the
+  // caller must keep the underlying character data alive for as long as this
+  // object is used. Each pointer in `instructionRegistry` is stored under its
+  // own getInstructionKind(); if two entries report the same kind, the later
+  // one replaces the earlier one. The pointed-to instructions are not owned.
+  uArch(StringRef name, StringRef description,
+        llvm::ArrayRef<const Instruction *> instructionRegistry)
+      : name(name), description(description) {
+    for (const Instruction *instr : instructionRegistry)
+      this->instructionRegistry[instr->getInstructionKind()] = instr;
}
-
- // Get the name of the supported instruction names for that
- // architecture. It returns the names of the instructions added to the uArch.
- llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
- llvm::SmallVector<StringRef, 8> instructionNames;
- for (const auto &inst : instructions) {
- instructionNames.push_back(Instruction::toString(inst.first));
- }
- return instructionNames;
+  virtual ~uArch() = default;
+  StringRef getName() const { return name; }
+  StringRef getDescription() const { return description; }
+  virtual int getSubgroupSize() const = 0;
+  virtual unsigned getPackedFormatBitSize() const = 0;
+  virtual unsigned getPackedFormatBitSizeGatherScatter() const = 0;
+
+  // Returns the instruction registered for `instKind`. Requesting a kind that
+  // was never registered is a programming error and trips the assertion.
+  // lookup() returns a null pointer for a missing key, so builds without
+  // assertions get nullptr back instead of dereferencing the end iterator.
+  const Instruction *getInstruction(InstructionKind instKind) const {
+    const Instruction *inst = instructionRegistry.lookup(instKind);
+    assert(inst && "Instruction not found in registry");
+    return inst;
}
- // Checks if an instruction is supported in this uArch
- bool checkSupportedInstruction(InstructionKind instr) const {
- return instructions.find(instr) != instructions.end();
+  // Returns true when an instruction of kind `instr` has been registered.
+  bool isSupportedInstruction(InstructionKind instr) const {
+    return instructionRegistry.contains(instr);
}
protected:
- std::string name; // Name of the uArch, similar to target triple
- std::string description;
- std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
- llvm::SmallVector<CacheInfo, 4> cacheInfo;
- std::map<InstructionKind, std::shared_ptr<Instruction>>
- instructions; // set of instructions supported by the uArch
+  StringRef name;        // Non-owning view of the uArch name.
+  StringRef description; // Non-owning view of the uArch description.
+  // Instructions supported by the uArch, keyed by kind (pointers not owned).
+  llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
+      instructionRegistry;
};
// A struct to represent shared memory information
@@ -251,9 +236,9 @@ struct MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) const = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) const = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) const = 0;
virtual ~MMAInstructionInterface() = default;
};
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
4fb6241
to
416eee1
Compare
This PR focuses on improving the general xegpu::uArch design to make it easier both to understand and to use (making it faster is optional). It is part of #163801, where we apply xegpu::uArch inside XeGPU's distribution passes. I have tried to preserve the parts that we need, but I remain open to leaving the GRF and cache info in the base class.