diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h new file mode 100644 index 0000000000000..0519f7b2e277d --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h @@ -0,0 +1,297 @@ +//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs. +// This file defines the uArch details for Xe2 and its derived architectures. +// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H +#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H + +#include "mlir/Dialect/XeGPU/uArch/uArchBase.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/DebugLog.h" +#include +#include + +#define DEBUG_TYPE "xegpu-uarch" + +using namespace mlir; +using namespace mlir::xegpu::uArch; + +namespace mlir { +namespace xegpu { +namespace uArch { + +struct Xe2Plus : public uArch { + XeCoreInfo xeCore; + Xe2Plus(const std::string &archName, const std::string &archDescription, + const XeCoreInfo &xeCore, + const std::map ®Info = {}, + const llvm::SmallVector &cacheInfo = {}, + const std::map> + &instrs = {}) + : uArch(archName, archDescription, regInfo, cacheInfo, instrs), + xeCore(xeCore) {} +}; + +// struct to represent DPAS instruction +struct DPASInstruction : public Instruction, public MMAInstructionInterface { + DPASInstruction() + : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {} + + // Override all virtuals from MatrixOpInterface + virtual llvm::SmallVector, 16> + getSupportedShapes(Type dataType, MMAOpndKind matrixType) override; + virtual llvm::SmallVector + getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override; + virtual bool + checkSupportedShapesAndTypes(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, Type AType, + Type BType, Type CType, Type DType) override; + virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, + Type DType) override; + virtual bool validate(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, Type AType, + Type BType, Type CType, Type DType) override; + virtual llvm::SmallVector getSupportedM(Type type) override; + virtual llvm::SmallVector getSupportedK(Type type) override; + virtual llvm::SmallVector getSupportedN(Type type) override; +}; + +struct PVCuArch : public Xe2Plus { + // Maintaines ownership of the instructions owned by PVUarch + llvm::SmallVector, 8> owned_instructions; + PVCuArch() + : Xe2Plus("pvc", // archName + "Ponte Vecchio Architecture", // archDescription + XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore + {/* registerFileInfo */}, // Optional: empty + {/* cacheInfo */}, // Optional: empty + {/* instructions */} // Optional: empty + ) { + // Intialize register file info + // GRF + this->registerFileInfo.emplace( + RegisterFileType::GRF, + RegisterFileInfo( + 64 * 1024, // size in bits + {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes + {128, 256} // registers per thread per mode + )); + // Initialize cache info + // L1 cache, XeCore level + this->cacheInfo.push_back( + CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1)); + // L2 cache, XeStack level + this->cacheInfo.push_back( + CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2)); + + // Add the instructions- + auto dpas = std::make_shared(); + instructions.emplace(dpas->getInstructionKind(), dpas); + owned_instructions.push_back(dpas); + } +}; + +struct BMGuArch : public Xe2Plus { + // Maintaines ownership of the instructions owned by PVUarch + llvm::SmallVector, 8> owned_instructions; + BMGuArch() + : Xe2Plus("bmg", // archName + "Battlemage Architecture", // archDescription + XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore + {/* registerFileInfo */}, // Optional: empty + {/* cacheInfo */}, // Optional: empty + {/* instructions */} // Optional: empty + ) { + // Intialize register file info + // GRF + this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo( + 64 * 1024, // size in bits + {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes + {128, 256} // registers per thread per mode + ); + // Initialize cache info + // L1 cache, XeCore level + this->cacheInfo.push_back( + CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1)); + // L2 cache, XeStack level + this->cacheInfo.push_back( + CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2)); + + // Add the instructions + auto dpas = std::make_shared(); + instructions.emplace(dpas->getInstructionKind(), dpas); + owned_instructions.push_back(dpas); + } +}; +} // namespace uArch +} // namespace xegpu +} // namespace mlir + +inline llvm::SmallVector, 16> +DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) { + auto combineVectors = [](const llvm::SmallVector &a, + const llvm::SmallVector &b) + -> llvm::SmallVector, 16> { + llvm::SmallVector, 16> result; + for (unsigned x : a) { + for (unsigned y : b) { + result.emplace_back(x, y); + } + } + return result; + }; + + auto M = getSupportedM(dataType); + auto K = getSupportedK(dataType); + auto N = getSupportedN(dataType); + llvm::SmallVector, 16> resultMatrix; + + switch (matrixType) { + case MMAOpndKind::MatrixA: + resultMatrix = combineVectors(M, K); + break; + case MMAOpndKind::MatrixB: + resultMatrix = combineVectors(K, N); + break; + case MMAOpndKind::MatrixC: + resultMatrix = combineVectors(M, N); + break; + case MMAOpndKind::MatrixD: + resultMatrix = combineVectors(M, N); + break; + } + return resultMatrix; +} + +inline llvm::SmallVector +DPASInstruction::getSupportedTypes(MLIRContext &context, + MMAOpndKind matrixType) { + Type bf16Type = BFloat16Type::get(&context); + Type f16Type = Float16Type::get(&context); + Type tf32Type = FloatTF32Type::get(&context); + Type f32Type = Float32Type::get(&context); + + switch (matrixType) { + case MMAOpndKind::MatrixA: + return {bf16Type, f16Type, tf32Type}; + case MMAOpndKind::MatrixB: + return {bf16Type, f16Type, tf32Type}; + case MMAOpndKind::MatrixC: + return {bf16Type, f16Type, f32Type}; + case MMAOpndKind::MatrixD: + return {bf16Type, f16Type, f32Type}; + } + return {}; +} + +inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType, + Type CType, Type DType) { + if (AType.isF16() || BType.isF16()) { + if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) || + (!DType.isF32() && !DType.isF16())) { + LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices."; + return false; + } + } else if (AType.isBF16() || BType.isBF16()) { + if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) || + (!DType.isF32() && !DType.isBF16())) { + LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices."; + return false; + } + } else if (AType.isTF32() || BType.isTF32()) { + if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) || + (!DType.isF32())) { + LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices."; + return false; + } + } else if (!(AType.isInteger(2) || AType.isInteger(4) || + AType.isInteger(8)) && + !(BType.isInteger(2) || BType.isInteger(4) || + BType.isInteger(8))) { + LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices."; + return false; + } + + return true; +} + +inline bool DPASInstruction::checkSupportedShapesAndTypes( + std::pair AShape, std::pair BShape, + std::pair CShape, std::pair DShape, + Type AType, Type BType, Type CType, Type DType) { + auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA); + auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB); + auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC); + auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD); + return llvm::is_contained(supportedAShapes, AShape) && + llvm::is_contained(supportedBShapes, BShape) && + llvm::is_contained(supportedCShapes, CShape) && + llvm::is_contained(supportedDShapes, DShape) && + checkSupportedTypes(AType, BType, CType, DType); +} + +inline bool DPASInstruction::validate(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, + Type AType, Type BType, Type CType, + Type DType) { + return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType, + BType, CType, DType); +} + +inline llvm::SmallVector +DPASInstruction::getSupportedM(Type type) { + return {1, 2, 3, 4, 5, 6, 7, 8}; +} + +inline llvm::SmallVector +DPASInstruction::getSupportedK(Type type) { + // assert if data type is not int or float type + assert(type.isIntOrFloat() && "Matrix type must be int or float"); + auto bitWidth = type.getIntOrFloatBitWidth(); + uint32_t kSize = 0; + switch (bitWidth) { + case 2: + kSize = 64; + break; + case 4: + kSize = 64; + break; + case 8: + kSize = 32; + break; + case 16: + kSize = 16; + break; + case 32: + kSize = 8; + break; + default: + llvm_unreachable("Invalid int or float"); + } + return {kSize}; +} + +inline llvm::SmallVector +DPASInstruction::getSupportedN(Type type) { + return {16}; +} + +#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h new file mode 100644 index 0000000000000..48d2302994592 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h @@ -0,0 +1,265 @@ +//===- uArch.h --------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// Base uArch definition for different architectures. +// +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H +#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H + +#include +#include +#include +#include +#include +#include +#include + +#include "mlir/IR/Types.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace xegpu { +namespace uArch { + +// An enum class to represent the scope of an instruction +enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster }; +enum class InstructionKind { + DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix + // multiply-add operation + // @TODO: Add more instructions as needed +}; + +llvm::StringRef toString(InstructionKind name) { + switch (name) { + case InstructionKind::DPAS: + return "dpas"; + } + llvm_unreachable("Unknown InstructionKind"); +} + +std::optional parseInstructionKind(llvm::StringRef str) { + if (str.equals_insensitive("dpas")) + return InstructionKind::DPAS; + return std::nullopt; +} + +// A struct to represent basic information about an instruction. +// The primary purpose of the Instruction struct is to provide a generic way to +// represent information about an instruction and to use this information to +// generate the uArch. Specifc instruction in a uArch can inherit from this +// struct and add more fields as needed. +struct Instruction { + Instruction(InstructionKind kind, InstructionScope scope) + : instKind(kind), scope(scope) {} + + virtual ~Instruction() = default; + // Get methods + InstructionKind getInstructionKind() { return instKind; } + InstructionScope getScope() { return scope; } + +protected: + InstructionKind instKind; // Specific InstructionKind (e.g., DPAS) + InstructionScope scope; // scope of the instruction (e.g., lane, subgroup, + // workgroup, cluster) + // @TODO: Add more fields as needed +}; + +enum class RegisterFileMode : uint8_t { Small, Large }; +enum class RegisterFileType : uint8_t { GRF, ARF }; + +// A struct to represent register file information +struct RegisterFileInfo { + // Constructor + RegisterFileInfo() = default; + RegisterFileInfo(uint32_t size, + const llvm::SmallVector &mode, + const llvm::SmallVector &numRegs) + : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {} + + // Get methods + uint32_t getSize() const { return size; } + + const llvm::SmallVector &getModes() const { + return mode; + } + + const llvm::SmallVector &getNumRegsPerThreadPerMode() const { + return numRegsPerThreadPerMode; + } + +protected: + uint32_t size; // size per register in bits + llvm::SmallVector + mode; // e.g., "small", "large" GRF modes + llvm::SmallVector + numRegsPerThreadPerMode; // number of registers per thread per mode +}; + +enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 }; + +// A struct to represent cache information +struct CacheInfo { + // Constructor + CacheInfo() = default; + CacheInfo(uint32_t size, uint32_t line_size, + CacheHierarchyLevel hierarchy_level) + : size(size), line_size(line_size), hierarchy_level(hierarchy_level) {} + + virtual ~CacheInfo() = default; + + // Get methods + uint32_t getSize() const { return size; } + uint32_t getLineSize() const { return line_size; } + CacheHierarchyLevel getHierarchyLevel() const { return hierarchy_level; } + +protected: + uint32_t size; + uint32_t line_size; + CacheHierarchyLevel hierarchy_level; + // @TODO: Add more fields as needed (e.g., associativity, num_banks, + // bank_size, num_ports, port_width, bank_conflicts, hierarchy_level, + // latency, throughput, bandwidth) +}; + +// A struct to represent the uArch +// This struct is used to represent the microarchitecture of a target device. +struct uArch { + // Constructor + uArch( + const std::string &name, const std::string &description, + const std::map ®isterFileInfo = {}, + const llvm::SmallVector &cacheInfo = {}, + const std::map> + &instructions = {}) + : name(name), description(description), + registerFileInfo(registerFileInfo), cacheInfo(cacheInfo), + instructions(instructions) {} + + // Get methods + const std::string &getName() const { return name; } + + const std::string &getDescription() const { return description; } + + const std::map & + getRegisterFileInfo() const { + return registerFileInfo; + } + + const llvm::SmallVector &getCacheInfo() const { + return cacheInfo; + } + + const std::map> & + getInstructions() const { + return instructions; + } + + // Get the name of the supported instruction names for that + // architecture. It returns the names of the instructions added to the uArch. + llvm::SmallVector getSupportedInstructionNames() const { + llvm::SmallVector instructionNames; + for (const auto &inst : instructions) { + instructionNames.push_back(toString(inst.first)); + } + return instructionNames; + } + + // Checks if an instruction is supported in this uArch + bool checkSupportedInstruction(InstructionKind instr) const { + return instructions.find(instr) != instructions.end(); + } + +protected: + std::string name; // Name of the uArch, similar to target triple + std::string description; + std::map registerFileInfo; + llvm::SmallVector cacheInfo; + std::map> + instructions; // set of instructions supported by the uArch +}; + +// A struct to represent shared memory information +struct SharedMemory { + // Constructor + SharedMemory(uint32_t size, uint32_t alignment) + : size(size), alignment(alignment) {} + + // Get methods + uint32_t getSize() const { return size; } + uint32_t getAlignment() const { return alignment; } + +protected: + uint32_t size; // in bytes + uint32_t alignment; // in bytes + // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth) +}; + +struct XeCoreInfo { + uint32_t num_threads; + SharedMemory shared_memory; + uint32_t num_vector_units; + uint32_t num_matrix_units; + + XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, + uint32_t num_vector_units, uint32_t num_matrix_units) + : num_threads(num_threads), shared_memory(shared_memory), + num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) { + } +}; + +//===----------------------------------------------------------------------===// +// Interfaces +//===----------------------------------------------------------------------===// +enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD }; +struct MMAInstructionInterface { + // Get supported Matrix shapes + virtual llvm::SmallVector, 16> + getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0; + // @TODO: This method takes an context object as a parameter, this is to + // create the Type objects from the same context. Since type objects are + // uniqued in a specific context, to do things like "aType == bType" (where + // aType and bType are both same type) kind of checks, the both types should + // be from the same context. + // + // One alternative to this is to create enum to represent each types, but this + // adds an extra burden to user to convert these enums to specific types. In + // fact the utility that would convert enumToType() and vice versa would still + // have to use the context object. + // + // Untill we have a better solution, we stick to passing context object to + // this method. + virtual llvm::SmallVector + getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0; + virtual bool + checkSupportedShapesAndTypes(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, Type AType, + Type BType, Type CType, Type DType) = 0; + virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, + Type DType) = 0; + virtual bool validate(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, Type AType, + Type BType, Type CType, Type DType) = 0; + virtual llvm::SmallVector getSupportedM(Type type) = 0; + virtual llvm::SmallVector getSupportedK(Type type) = 0; + virtual llvm::SmallVector getSupportedN(Type type) = 0; + + virtual ~MMAInstructionInterface() = default; +}; + +} // namespace uArch +} // namespace xegpu +} // namespace mlir + +#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index d997296a22c20..659a98abfeda6 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h"