diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h index 0519f7b2e277d..f264be5181b2a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h @@ -23,8 +23,6 @@ #include #include -#define DEBUG_TYPE "xegpu-uarch" - using namespace mlir; using namespace mlir::xegpu::uArch; @@ -33,21 +31,80 @@ namespace xegpu { namespace uArch { struct Xe2Plus : public uArch { + Xe2Plus(StringRef archName, StringRef archDescription, + llvm::ArrayRef instructionRegistry, + const XeCoreInfo &xeCore) + : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {} + int getSubgroupSize() const override { return 16; } + unsigned getPackedFormatBitSize() const override { return 16; } + unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; } + +protected: XeCoreInfo xeCore; - Xe2Plus(const std::string &archName, const std::string &archDescription, - const XeCoreInfo &xeCore, - const std::map ®Info = {}, - const llvm::SmallVector &cacheInfo = {}, - const std::map> - &instrs = {}) - : uArch(archName, archDescription, regInfo, cacheInfo, instrs), - xeCore(xeCore) {} }; -// struct to represent DPAS instruction +//===----------------------------------------------------------------------===// +// uArch instructions +//===----------------------------------------------------------------------===// +struct StoreNdInstruction : public Instruction { + StoreNdInstruction() + : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {} + + // Source : + // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions + // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from + // the specified pointer + llvm::ArrayRef getSortedLaneVectorLengths() const { + const static int sortedLaneVectorLengths[] = {1, 2, 4, 8}; + return sortedLaneVectorLengths; + } +}; + +struct LoadNdInstruction : public Instruction { + LoadNdInstruction() + : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {} + + // Source : + // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions + // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to + // the specified pointer. + llvm::ArrayRef getSortedLaneVectorLengths() const { + const static int sortedLaneVectorLengths[] = {1, 2, 4, 8}; + return sortedLaneVectorLengths; + } +}; + +struct PrefetchNdInstruction : public Instruction { + PrefetchNdInstruction() + : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {} + + // Source : + // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions + llvm::ArrayRef getSortedLaneVectorLengths(int elementBitwidth) const { + const static int sortedNarrowTypesLengths[] = {1, 2, 4, 8, 16}; + const static int sortedWideTypesLengths[] = {1, 2, 4, 8}; + switch (elementBitwidth) { + case 8: + case 16: + return sortedNarrowTypesLengths; + case 32: + case 64: + return sortedWideTypesLengths; + default: + llvm_unreachable("Unsupported element bitwidth"); + } + } +}; + struct DPASInstruction : public Instruction, public MMAInstructionInterface { - DPASInstruction() - : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {} + DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB, + unsigned packedFormatBitSizeC) + : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup), + packedFormatBitSizeA(packedFormatBitSizeA), + packedFormatBitSizeB(packedFormatBitSizeB), + packedFormatBitSizeC(packedFormatBitSizeC) {} + // Source: + // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html // Override all virtuals from MatrixOpInterface virtual llvm::SmallVector, 16> @@ -67,82 +124,82 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface { std::pair CShape, std::pair DShape, Type AType, Type BType, Type CType, Type DType) override; - virtual llvm::SmallVector getSupportedM(Type type) override; - virtual llvm::SmallVector getSupportedK(Type type) override; - virtual llvm::SmallVector getSupportedN(Type type) override; + virtual llvm::SmallVector + getSupportedM(Type type) const override; + virtual llvm::SmallVector + getSupportedK(Type type) const override; + virtual llvm::SmallVector + getSupportedN(Type type) const override; + + unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; } + unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; } + unsigned getPackedFormatBitSizeC() const { return packedFormatBitSizeC; } + +protected: + const unsigned packedFormatBitSizeA; + const unsigned packedFormatBitSizeB; + const unsigned packedFormatBitSizeC; }; -struct PVCuArch : public Xe2Plus { - // Maintaines ownership of the instructions owned by PVUarch - llvm::SmallVector, 8> owned_instructions; +//===----------------------------------------------------------------------===// +// uArch instances +//===----------------------------------------------------------------------===// + +struct PVCuArch final : public Xe2Plus { + inline static const DPASInstruction dpasInst{16, 32, 32}; + inline static const StoreNdInstruction loadNdInst; + inline static const StoreNdInstruction storeNdInst; + inline static const PrefetchNdInstruction prefetchNdInst; + inline static const Instruction *const instructionRegistryArr[] = { + &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst}; PVCuArch() : Xe2Plus("pvc", // archName "Ponte Vecchio Architecture", // archDescription - XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore - {/* registerFileInfo */}, // Optional: empty - {/* cacheInfo */}, // Optional: empty - {/* instructions */} // Optional: empty - ) { - // Intialize register file info - // GRF - this->registerFileInfo.emplace( - RegisterFileType::GRF, - RegisterFileInfo( - 64 * 1024, // size in bits - {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes - {128, 256} // registers per thread per mode - )); - // Initialize cache info - // L1 cache, XeCore level - this->cacheInfo.push_back( - CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1)); - // L2 cache, XeStack level - this->cacheInfo.push_back( - CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2)); - - // Add the instructions- - auto dpas = std::make_shared(); - instructions.emplace(dpas->getInstructionKind(), dpas); - owned_instructions.push_back(dpas); + instructionRegistryArr, + XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore + ) {} + static const uArch *getInstance() { + static const PVCuArch instance; + return reinterpret_cast(&instance); } }; struct BMGuArch : public Xe2Plus { - // Maintaines ownership of the instructions owned by PVUarch - llvm::SmallVector, 8> owned_instructions; + inline static const DPASInstruction dpasInst{16, 32, 32}; + inline static const StoreNdInstruction loadNdInst; + inline static const StoreNdInstruction storeNdInst; + inline static const PrefetchNdInstruction prefetchNdInst; + inline static const Instruction *const instructionRegistryArr[] = { + &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst}; BMGuArch() : Xe2Plus("bmg", // archName "Battlemage Architecture", // archDescription - XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore - {/* registerFileInfo */}, // Optional: empty - {/* cacheInfo */}, // Optional: empty - {/* instructions */} // Optional: empty - ) { - // Intialize register file info - // GRF - this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo( - 64 * 1024, // size in bits - {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes - {128, 256} // registers per thread per mode - ); - // Initialize cache info - // L1 cache, XeCore level - this->cacheInfo.push_back( - CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1)); - // L2 cache, XeStack level - this->cacheInfo.push_back( - CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2)); - - // Add the instructions - auto dpas = std::make_shared(); - instructions.emplace(dpas->getInstructionKind(), dpas); - owned_instructions.push_back(dpas); + instructionRegistryArr, + XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore + ) {} + static const uArch *getInstance() { + static const BMGuArch instance; + return reinterpret_cast(&instance); } }; + +inline const uArch *getUArch(llvm::StringRef archName) { + if (archName.equals_insensitive("pvc")) + return PVCuArch::getInstance(); + else if (archName.equals_insensitive("bmg")) + return BMGuArch::getInstance(); + + return nullptr; +} + } // namespace uArch } // namespace xegpu } // namespace mlir +//===----------------------------------------------------------------------===// +// Instruction implementations +//===----------------------------------------------------------------------===// + inline llvm::SmallVector, 16> DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) { auto combineVectors = [](const llvm::SmallVector &a, @@ -257,12 +314,12 @@ inline bool DPASInstruction::validate(std::pair AShape, } inline llvm::SmallVector -DPASInstruction::getSupportedM(Type type) { +DPASInstruction::getSupportedM(Type type) const { return {1, 2, 3, 4, 5, 6, 7, 8}; } inline llvm::SmallVector -DPASInstruction::getSupportedK(Type type) { +DPASInstruction::getSupportedK(Type type) const { // assert if data type is not int or float type assert(type.isIntOrFloat() && "Matrix type must be int or float"); auto bitWidth = type.getIntOrFloatBitWidth(); @@ -290,7 +347,7 @@ DPASInstruction::getSupportedK(Type type) { } inline llvm::SmallVector -DPASInstruction::getSupportedN(Type type) { +DPASInstruction::getSupportedN(Type type) const { return {16}; } diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h index 955994ea5ecf5..82a5223c43651 100644 --- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h @@ -32,8 +32,11 @@ namespace uArch { // An enum class to represent the scope of an instruction enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster }; enum class InstructionKind { - DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix - // multiply-add operation + DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix + // multiply-add operation + STORE_ND, // Subgroup-level 2D block write instruction + LOAD_ND, // Subgroup-level 2D block load instruction + PREFETCH_ND // Subgroup-level 2D block prefetch instruction // @TODO: Add more instructions as needed }; @@ -46,14 +49,20 @@ struct Instruction { Instruction(InstructionKind kind, InstructionScope scope) : instKind(kind), scope(scope) {} - virtual ~Instruction() = default; + ~Instruction() = default; // Get methods - InstructionKind getInstructionKind() { return instKind; } - InstructionScope getScope() { return scope; } + InstructionKind getInstructionKind() const { return instKind; } + InstructionScope getScope() const { return scope; } static llvm::StringRef toString(InstructionKind instKind) { switch (instKind) { case InstructionKind::DPAS: return "dpas"; + case InstructionKind::STORE_ND: + return "store_nd"; + case InstructionKind::LOAD_ND: + return "load_nd"; + case InstructionKind::PREFETCH_ND: + return "prefetch_nd"; } llvm_unreachable("Unknown InstructionKind"); } @@ -66,9 +75,9 @@ struct Instruction { } protected: - InstructionKind instKind; // Specific InstructionKind (e.g., DPAS) - InstructionScope scope; // scope of the instruction (e.g., lane, subgroup, - // workgroup, cluster) + const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS) + const InstructionScope scope; // scope of the instruction (e.g., lane, + // subgroup, workgroup, cluster) // @TODO: Add more fields as needed }; @@ -129,61 +138,37 @@ struct CacheInfo { // latency, throughput, bandwidth) }; -// A struct to represent the uArch -// This struct is used to represent the microarchitecture of a target device. struct uArch { // Constructor - uArch( - const std::string &name, const std::string &description, - const std::map ®isterFileInfo = {}, - const llvm::SmallVector &cacheInfo = {}, - const std::map> - &instructions = {}) - : name(name), description(description), - registerFileInfo(registerFileInfo), cacheInfo(cacheInfo), - instructions(instructions) {} - - // Get methods - const std::string &getName() const { return name; } - - const std::string &getDescription() const { return description; } - - const std::map & - getRegisterFileInfo() const { - return registerFileInfo; - } - - const llvm::SmallVector &getCacheInfo() const { - return cacheInfo; - } - - const std::map> & - getInstructions() const { - return instructions; + uArch(StringRef name, StringRef description, + llvm::ArrayRef instructionRegistry) + : name(name), description(description) { + for (const Instruction *instr : instructionRegistry) + this->instructionRegistry[instr->getInstructionKind()] = instr; } - - // Get the name of the supported instruction names for that - // architecture. It returns the names of the instructions added to the uArch. - llvm::SmallVector getSupportedInstructionNames() const { - llvm::SmallVector instructionNames; - for (const auto &inst : instructions) { - instructionNames.push_back(Instruction::toString(inst.first)); - } - return instructionNames; + virtual ~uArch() = default; + StringRef getName() const { return name; } + StringRef getDescription() const { return description; } + virtual int getSubgroupSize() const = 0; + virtual unsigned getPackedFormatBitSize() const = 0; + virtual unsigned getPackedFormatBitSizeGatherScatter() const = 0; + + const Instruction *getInstruction(InstructionKind instKind) const { + auto it = instructionRegistry.find(instKind); + assert(it != instructionRegistry.end() && + "Instruction not found in registry"); + return it->second; } - // Checks if an instruction is supported in this uArch - bool checkSupportedInstruction(InstructionKind instr) const { - return instructions.find(instr) != instructions.end(); + bool isSupportedInstruction(InstructionKind instr) const { + return instructionRegistry.contains(instr); } protected: - std::string name; // Name of the uArch, similar to target triple - std::string description; - std::map registerFileInfo; - llvm::SmallVector cacheInfo; - std::map> - instructions; // set of instructions supported by the uArch + StringRef name; + StringRef description; + llvm::SmallDenseMap + instructionRegistry; }; // A struct to represent shared memory information @@ -251,9 +236,9 @@ struct MMAInstructionInterface { std::pair CShape, std::pair DShape, Type AType, Type BType, Type CType, Type DType) = 0; - virtual llvm::SmallVector getSupportedM(Type type) = 0; - virtual llvm::SmallVector getSupportedK(Type type) = 0; - virtual llvm::SmallVector getSupportedN(Type type) = 0; + virtual llvm::SmallVector getSupportedM(Type type) const = 0; + virtual llvm::SmallVector getSupportedK(Type type) const = 0; + virtual llvm::SmallVector getSupportedN(Type type) const = 0; virtual ~MMAInstructionInterface() = default; };