diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td index 3a88dae041dd1..b3848ffb0c661 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td @@ -71,4 +71,5 @@ def XeGPUBlocking: Pass<"xegpu-blocking"> { ]; } + #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h new file mode 100644 index 0000000000000..d647dddb6f8e6 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h @@ -0,0 +1,223 @@ +//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// Xe2 uArch definition. 
+/// +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_XE2_H +#define MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_XE2_H + +#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include +#include +#include + +namespace mlir { +namespace xegpu { +namespace uArch { +namespace Xe2Plus { +struct XeCoreInfo { + uint32_t num_threads; + SharedMemory shared_memory; + uint32_t num_vector_units; + uint32_t num_matrix_units; + + // Constructor + XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, + uint32_t num_vector_units, uint32_t num_matrix_units) + : num_threads(num_threads), shared_memory(shared_memory), + num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) { + } +}; + +struct Xe2Plus : public uArch { + XeCoreInfo xe_core; + Xe2Plus( + const std::string &archName, const std::string &archDescription, + const XeCoreInfo &xeCore, + const std::vector &hierarchy = {}, + const std::map ®Info = {}, + const std::vector &cacheInfo = {}, + const std::map> &instrs = {}) + : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs), + xe_core(xeCore) {} +}; + +// struct to represent DPAS instruction +struct DPASInstruction : public Instruction, public MMAOpInterface { + DPASInstruction() + : Instruction("dpas", // name + "Dot Product Accumulate") // description + {} + + // Override all virtuals from MatrixOpInterface + virtual std::vector> + getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override; + virtual std::vector + getSupportedTypes(MLIRContext &context, MMAOpndEnum matrixType) override; + virtual bool + checkSupportedShapesAndTypes(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, + mlir::Type AType, mlir::Type BType, + mlir::Type CType, mlir::Type DType) override; + virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType, + 
mlir::Type CType, mlir::Type DType) override; + virtual bool validate(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, mlir::Type AType, + mlir::Type BType, mlir::Type CType, + mlir::Type DType) override; + virtual std::vector getSupportedM(mlir::Type type) override; + virtual std::vector getSupportedK(mlir::Type type) override; + virtual std::vector getSupportedN(mlir::Type type) override; +}; + +// struct to represent Load2D/Store2D/Prefetch instruction +struct LoadStorePrefetch2DInstruction : public Instruction { + MemoryType memory_type; + // MemoryAccessType memory_access_type; + // std::vector supported_types; + std::vector supported_types_bitwidth; + std::map alignment; + std::vector> supported_tile_sizes; + uint32_t min_surface_pitch; + + // Validate Array length restriction on a given tile + bool validateArrayLenRestriction(std::vector tile, + uint32_t array_len, mlir::Type dataType) { + + Restriction, uint32_t, mlir::Type> + width_array_len_restriction( + tile, array_len, dataType, + [](std::vector tile, uint32_t array_len, + mlir::Type dataType) { + assert(tile.size() == 2); + return tile[1] * array_len * + (dataType.getIntOrFloatBitWidth() / 8) <= + 64; + }); + return width_array_len_restriction.validate(); + } + + // Validate Surface Pitch restriction on a given tile + bool validateSurfacePitchRestriction(std::vector tile, + uint32_t surfacePitch /*in bytes*/) { + Restriction, uint32_t> surface_pitch_restriction( + tile, surfacePitch, + [](std::vector tile, uint32_t surfacePitch) { + assert(tile.size() == 2); + return surfacePitch >= 64; + }); + return surface_pitch_restriction.validate(); + } +}; + +namespace PVCuArch { +struct PVCuArch : public Xe2Plus { + // Maintaines ownership of the instructions owned by PVUarch + std::vector> owned_instructions; + PVCuArch() + : Xe2Plus("pvc", // archName + "Ponte Vecchio Architecture", // archDescription + XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore + {/* 
register_file_info */}, // Optional: empty + {/* cache_info */}, // Optional: empty + {/* instructions */}, // Optional: empty + {/* restrictions */} // Optional: empty + ) { + // Initialize uArchHierarchy + this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2)); + // Intialize register file info + // GRF + this->register_file_info.emplace( + "GRF", + RegisterFileInfo(64 * 1024, // size in bits + {"small", "large"}, // GRF modes + {128, 256}, // registers per thread per mode + 0, // number of banks + 0 // bank size + )); + // Initialize cache info + // L1 cache, XeCore level + this->cache_info.push_back( + CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1])); + // L3 cache, XeStack level + this->cache_info.push_back( + CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3])); + + // Add the instructions + auto dpas = std::make_shared(); + instructions.emplace(dpas->getName(), dpas); + // instructions[dpas->name] = dpas.get(); + owned_instructions.push_back(dpas); + } +}; +} // namespace PVCuArch + +namespace BMGuArch { +struct BMGuArch : public Xe2Plus { + // Maintaines ownership of the instructions owned by PVUarch + std::vector> owned_instructions; + BMGuArch() + : Xe2Plus("bmg", // archName + "Battlemage Architecture", // archDescription + XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore + {/* register_file_info */}, // Optional: empty + {/* cache_info */}, // Optional: empty + {/* instructions */}, // Optional: empty + {/* restrictions */} // Optional: empty + ) { + // Initialize uArchHierarchy + this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8)); + 
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5)); + this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1)); + // Intialize register file info + // GRF + this->register_file_info["GRF"] = + RegisterFileInfo(64 * 1024, // size in bits + {"small", "large"}, // GRF modes + {128, 256}, // registers per thread per mode + 0, // number of banks + 0 // bank size + ); + // Initialize cache info + // L1 cache, XeCore level + this->cache_info.push_back( + CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1])); + // L3 cache, XeStack level + this->cache_info.push_back( + CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3])); + + // Add the instructions + auto dpas = std::make_shared(); + instructions.emplace(dpas->getName(), dpas); + // instructions[dpas->name] = dpas.get(); + owned_instructions.push_back(dpas); + } +}; +} // namespace BMGuArch + +} // namespace Xe2Plus +} // namespace uArch +} // namespace xegpu +} // namespace mlir + +#endif // MLIR_DIALECT_XEGPU_UTILS_INTEL_GPU_XE2_H diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h new file mode 100644 index 0000000000000..13c4fd8638ab6 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h @@ -0,0 +1,333 @@ +//===--- uArch.h ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// Base uArch definition for different architectures. 
+/// +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_XEGPU_UTILS_UARCH_BASE_H +#define MLIR_DIALECT_XEGPU_UTILS_UARCH_BASE_H + +#include +#include +#include +#include +#include +#include +#include + +#include "mlir/IR/Types.h" + +namespace mlir { +namespace xegpu { +namespace uArch { +// Restriction struct +// This struct is used to represent a restriction on the uArch +// The restriction is represented as a range of necessary parameters (template +// arguments) and a lambda function (validate()) that takes the same number of +// arguments as the number of template arguments The lambda function returns +// true if the arguments satisfy the restriction The lambda function returns +// false if the arguments do not satisfy the restriction + +// For example, a restriction that checks if the number of dimensions in a +// std::vector> is 2 can be represented as: +// std::vector> rt = +// {{1, 32}, {2, 16}}; Restriction>> r1(rt, +// [](std::vector> t) { return t.size() == 2; }); +// r1.validate() will return true if the number of dimensions in the +// std::vector> is 2 r1.validate() will return false if +// the number of dimensions in the std::vector> is not 2 + +// The primary purpose of Restriction struct is to provide a generic way to +// represent restrictions on the uArch and to validate if the uArch satisfies +// the restrictions +template +struct Restriction { + std::tuple data; + std::function func; + + Restriction(Args... args, std::function f) + : data(args...), func(f) {} + + bool validate() { return std::apply(func, data); } + std::any apply() { return std::apply(func, data); } +}; + +// Architecture HW component hierarchy to present thread, core, socket ... +struct uArchHierarchyComponent { + std::string name = ""; // optional name of the hierarchy component + // no. 
of lower hierarchy component it contains, e.g., for PVC XeCore it + // contains 8 threads, so no_of_component=8 + uint32_t no_of_component; + // Constructor + uArchHierarchyComponent(const std::string &name, uint32_t no_of_component) + : name(name), no_of_component(no_of_component) {} +}; + +// An enum class to represent the type of memory +enum class MemoryType { Shared, Local, Global, Constant, Texture, Other }; + +// An enum class to represent the scope of an instruction +enum class InstructionScopeEnum { WorkItem, Subgroup, Workgroup, Cluster }; + +// A struct to represent basic information about an instruction +// This struct is used to represent the information about an instruction in the +// uArch The information includes: +// - the name of the instruction, +// - the description of the instruction +// - the scope of the instruction, +// +// The information is represented as strings +// For example, the information about an instruction can be represented as: +// Instruction instr = {"dpas", "Dot Product Accumulate Systolic (DPAS) is a +// matrix multiply-add operation", "subgroup"}; + +// The primary purpose of Instruction struct is to provide a generic way to +// represent information about an instruction and to use this information to +// generate the uArch. 
Specifc instruction in a uArch can inherit from this +// struct and add more fields as needed + +struct Instruction { + // @TODO: Add more fields as needed + Instruction(std::string name, std::string desc) + : name(std::move(name)), description(std::move(desc)) {} + + virtual ~Instruction() = default; + // Get methods + std::string getName() { return name; } + std::string getDescription() { return description; } + InstructionScopeEnum getScope() { return scope; } + +protected: + std::string name; + std::string description; + InstructionScopeEnum scope; +}; + +// A struct to represent register file information +struct RegisterFileInfo { + // Constructor + RegisterFileInfo() = default; + RegisterFileInfo(uint32_t size, const std::vector &mode, + const std::vector &numRegs, uint32_t num_banks, + uint32_t bank_size) + : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs), + num_banks(num_banks), bank_size(bank_size) {} + + // Get methods + uint32_t getSize() const { return size; } + + const std::vector &getModes() const { return mode; } + + const std::vector &getNumRegsPerThreadPerMode() const { + return num_regs_per_thread_per_mode; + } + + uint32_t getNumBanks() const { return num_banks; } + + uint32_t getBankSize() const { return bank_size; } + +protected: + uint32_t size; // size per register in bits + std::vector mode; // e.g., "small", "large" GRF modes + std::vector + num_regs_per_thread_per_mode; // number of registers per thread per mode + uint32_t num_banks; + uint32_t bank_size; +}; + +// A struct to represent cache information + +struct CacheInfo { + // Constructor + CacheInfo(uint32_t size, uint32_t line_size, + const uArchHierarchyComponent &component) + : size(size), line_size(line_size), component(component) {} + + virtual ~CacheInfo() = default; + + // Get methods + uint32_t getSize() const { return size; } + uint32_t getLineSize() const { return line_size; } + const uArchHierarchyComponent &getComponent() const { return component; } + 
+protected: + uint32_t size; + uint32_t line_size; + // At which component level the cache is shared + uArchHierarchyComponent component; + + // @TODO: Add more fields as needed (e.g., associativity, num_banks, + // bank_size, num_ports, port_width, bank_conflicts) +}; + +// A struct to represent the uArch +// This struct is used to represent the microarchitecture of a target device +// The uArch includes: +// - the name of the uArch, +// - the description of the uArch, +// - uArch hierarchy +// - Rgister File information +// - Cache information +// - the set of instructions supported by the uArch, +struct uArch { + // Constructor + uArch() = default; + uArch(const std::string &name, const std::string &description, + const std::vector &uArch_hierarchy = {}, + const std::map ®ister_file_info = {}, + const std::vector &cache_info = {}, + const std::map> + &instructions = {}, + const std::vector *> &restrictions = {}) + : name(name), description(description), uArch_hierarchy(uArch_hierarchy), + register_file_info(register_file_info), cache_info(cache_info), + instructions(instructions) {} + + // Get methods + const std::string &getName() const { return name; } + + const std::string &getDescription() const { return description; } + + const std::vector &getHierarchy() const { + return uArch_hierarchy; + } + + const std::map &getRegisterFileInfo() const { + return register_file_info; + } + + const std::vector &getCacheInfo() const { return cache_info; } + + const std::map> & + getInstructions() const { + return instructions; + } + + /// @brief Get the name of the supported instruction names for that + /// architecture. It returns the names of the instructions added to the uArch. 
+ std::vector getSupportedInstructionNames() const { + std::vector instructionNames; + for (const auto &inst : instructions) { + instructionNames.push_back(inst.first); + } + return instructionNames; + } + + /// @brief Checks if an instruction is supported in this uArch + /// @param instructionName + /// @return true if supported, false otherwise + bool checkSupportedInstruction(const std::string &instructionName) const { + return instructions.find(instructionName) != instructions.end(); + } + +protected: + std::string name; // Similar to target triple + std::string description; + std::vector uArch_hierarchy; + std::map register_file_info; + std::vector cache_info; + std::map> instructions; +}; + +// A struct to represent shared memory information +struct SharedMemory { + // Constructor + SharedMemory(uint32_t size, uint32_t alignment) + : size(size), alignment(alignment) {} + + // Getters + uint32_t getSize() const { return size; } + uint32_t getAlignment() const { return alignment; } + +protected: + uint32_t size; // in bytes + uint32_t alignment; // in bytes + // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth) +}; + +// For future use case in Xe4+ + +// struct EUInfo { +// uint32_t num_eu_threads; +// SharedMemory shared_memory; +// }; + +// uint32_t num_simd_units; +// uint32_t num_spus; +// uint32_t num_smt; +// uint32_t num_hardware_threads; +// uint32_t num_threads_per_spu; +// uint32_t num_threads_per_simd_unit; +// uint32_t num_threads_per_hardware_thread; +// uint32_t num_threads_per_smt; +// SharedMemory shared_memory; +// }; + +// A struct to represent a GPU uArch +// This struct is used to represent the GPU microarchitecture of a target device +// struct GPUuArch : public uArch { +// uint32_t num_compute_units; +// uint32_t num_vector_units; +// uint32_t num_scalar_units; +// uint32_t num_tensor_units; +// uint32_t num_matrix_units; +// SharedMemory shared_memory; +// }; + +struct uArchMap { +public: + // Singleton instance + 
static uArchMap &instance() { + static uArchMap instance; + return instance; + } + + // Insert or update a key-value pair + void insert(const std::string &key, std::shared_ptr value) { + std::unique_lock lock(mutex_); + // map_[key] = std::move(value); // safe to overwrite + map_.emplace(key, std::move(value)); // safe to overwrite + } + + // Get a value by key (concurrent safe read) + std::shared_ptr get(const std::string &key) const { + std::shared_lock lock(mutex_); + auto it = map_.find(key); + if (it != map_.end()) + return it->second; + return nullptr; + } + + // Check if a key exists + bool contains(const std::string &key) const { + std::shared_lock lock(mutex_); + return map_.find(key) != map_.end(); + } + + // Remove a key + bool erase(const std::string &key) { + std::unique_lock lock(mutex_); + return map_.erase(key) > 0; + } + +private: + uArchMap() = default; + uArchMap(const uArchMap &) = delete; + uArchMap &operator=(const uArchMap &) = delete; + + mutable std::shared_mutex mutex_; + std::map> map_; +}; + +} // namespace uArch +} // namespace xegpu +} // namespace mlir + +#endif // MLIR_DIALECT_XEGPU_UTILS_UARCH_H diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h new file mode 100644 index 0000000000000..77207d0714d0c --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h @@ -0,0 +1,109 @@ +//===--- uArchInterfaces.h ---*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// Defines the utility interfaces that are implemented by individual +/// instructions. 
+/// +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_XEGPU_UTILS_UARCH_INTERFACES_H +#define MLIR_DIALECT_XEGPU_UTILS_UARCH_INTERFACES_H + +#include "mlir/Dialect/XeGPU/uArch/uArchBase.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include +#include +#include + +namespace mlir { +namespace xegpu { +namespace uArch { + +// Create a BlockIOOp Interface +struct BlockIO2DOpInterface { + // Get the supported shapes for the specific data type. + // Can provide load/store/prefetch ops supported shapes for a specific + // uarch + virtual std::vector> getSupportedShapes( + mlir::Type type, bool isTrnasform = false /*VNNI transform bit*/, + bool isTranspose = false /*transpose bit*/, + uint32_t transpose_bitwidth = 32 /*transpose bitwidth */) = 0; + + // Get supported types + virtual std::vector getSupportedTypes(MLIRContext &context) = 0; + // Checks if a shape is supported + virtual bool checkSupportedShapesAndTypes(std::vector shape, + mlir::Type type) = 0; + // Checks if a type is type is supported + virtual bool checkSupportedTypes(mlir::Type type) = 0; + + // Validate the BlockIO2D ops restrictions + // @param blockSize, size of the load/store/prefetch block + // @param surfaceSize, size of the load/store/prefetch surface + // @param dataType, data type of the data + // @param alignment, alignment + // @param surface_pitch, suface pitch + // @param array_len, array length + virtual bool validate(std::vector blockSize, + std::vector surfaceSize, mlir::Type dataType, + uint32_t alignment, uint32_t surface_pitch, + uint32_t array_len = 1) = 0; + virtual ~BlockIO2DOpInterface() = default; +}; + +enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD }; +struct MMAOpInterface { + // Get supported Matrix type + // @param dataType, data type of the matrix + // @param matrixType, Matrix type (Matrix A, B, C, or D) + virtual std::vector> + getSupportedShapes(mlir::Type dataType, 
MMAOpndEnum matrixType) = 0; + + // @TODO: This method takes an context object as a parameter, this is to + // create the type objects from the same context. Since type objects are + // uniqued in a specific context, to do things like "aType == bType" (where + // aType and bType are both same type) kind of checks, the both types should + // be from the same context. + // + // One alternative to this is to create enum to represent each types, but this + // adds an extra burden to user to convert these enums to specific types. In + // fact the utility that would convert enumToType() and vice versa would still + // have to use the context object. + // + // Untill we have a better solution, we stick to passing context object to + // this method. + virtual std::vector getSupportedTypes(MLIRContext &context, + MMAOpndEnum matrixType) = 0; + virtual bool + checkSupportedShapesAndTypes(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, + mlir::Type AType, mlir::Type BType, + mlir::Type CType, mlir::Type DType) = 0; + virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType, + mlir::Type CType, mlir::Type DType) = 0; + virtual bool validate(std::pair AShape, + std::pair BShape, + std::pair CShape, + std::pair DShape, mlir::Type AType, + mlir::Type BType, mlir::Type CType, + mlir::Type DType) = 0; + virtual std::vector getSupportedM(mlir::Type type) = 0; + virtual std::vector getSupportedK(mlir::Type type) = 0; + virtual std::vector getSupportedN(mlir::Type type) = 0; + + virtual ~MMAOpInterface() = default; +}; + +} // namespace uArch +} // namespace xegpu +} // namespace mlir +#endif // MLIR_DIALECT_XEGPU_UTILS_UARCH_INTERFACES_H diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt index 31167e6af908b..9079df050ab2b 100644 --- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(uArch) 
add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt index 242a97ccfdf6d..5393b9b7b1c6f 100644 --- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect MLIRArithUtils MLIRDialectUtils MLIRIR + MLIRXeGPUuArch MLIRViewLikeInterface MLIRVectorDialect ) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 642c393cbc2c8..4dd3291a3f541 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" @@ -31,6 +32,14 @@ void XeGPUDialect::initialize() { #define GET_ATTRDEF_LIST #include >(); + + // Populate the uArchMap with the supported target devices + auto pvcuArch = + std::make_shared(); + mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch); + auto bmguArch = + std::make_shared(); + mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch); } // Checks if the given shape can be evenly distributed based on the layout diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index ef7cd1424e7a4..39cfa9e708881 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -7,9 +7,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include 
"mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" @@ -575,6 +579,62 @@ LogicalResult DpasOp::verify() { if (getAcc() && getAcc().getType() != getResultType()) return emitOpError("Expecting the acc type to be the same as result."); + // @uArch: Check if the types are supported for DPAS. + Operation *op = getOperation(); + + // Use XeVM target + auto gpuModuleOp = op->getParentOfType(); + xevm::XeVMTargetAttr xevmAttr = nullptr; + if (gpuModuleOp) { + auto targetAttrs = gpuModuleOp.getTargets(); + if (targetAttrs) { + // Look up XeVM attribute + for (auto &attr : *targetAttrs) { + xevmAttr = dyn_cast(attr); + if (!xevmAttr) { + LLVM_DEBUG(llvm::dbgs() << "No target device found, skipping " + "target-specific verification\n"); + } + } + } + } + + // If target device info is not attached, skip the target-specific checks + // Potential usage of uArch in verification. + if (xevmAttr) { + auto targetDeviceNameStr = xevmAttr.getChip().str(); + auto targetDeviceArch = + mlir::xegpu::uArch::uArchMap::instance().get(targetDeviceNameStr); + if (targetDeviceArch) { + // @TODO: We should keep the name of the Instructions in one place, since + // we use the name of the instruction to find the instruction, it should + // be standardized and kept for users to access. 
+ + // One could use the find mechanism of std::map to find if an instruction + // is supported or not + // + // auto it = targetDeviceArch->instructions.find("dpas"); if (it != + // targetDeviceArch->instructions.end()) + // + // Alternatively, one could use uARch provided method to do so + if (targetDeviceArch->checkSupportedInstruction("dpas")) { + auto supportedInstructions = targetDeviceArch->getInstructions(); + std::shared_ptr instr = + supportedInstructions["dpas"]; + auto matrixOp = + std::dynamic_pointer_cast( + instr); + if (matrixOp) { + if (!matrixOp->checkSupportedTypes(getLhsType().getElementType(), + getRhsType().getElementType(), + getResultType().getElementType(), + getResultType().getElementType())) + return emitOpError("Unsupported DPAS types."); + } + } + } + } + // SIMT code: the size of the B operand has to be a multiple of 32 bits. // It skips the semantic check since lack of architecture information. // Users need to ensure the correctness. diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt index 9c178d1d85642..63acd30646764 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt @@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms MLIRTransforms MLIRGPUDialect MLIRXeGPUUtils + MLIRXeGPUuArch MLIRGPUUtils MLIRVectorTransforms ) diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt index 98e84a4420722..8fa908087c0ae 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt @@ -8,4 +8,5 @@ add_mlir_dialect_library(MLIRXeGPUUtils MLIRIR MLIRSCFTransforms MLIRXeGPUDialect - ) +) + diff --git a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt new file mode 100644 index 0000000000000..c7f691cb6dda7 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt @@ -0,0 +1,11 @@ 
+add_mlir_dialect_library(MLIRXeGPUuArch
+  IntelGpuXe2.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/uArch
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRDialectUtils
+)
+
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
new file mode 100644
index 0000000000000..76ee0589c4d3c
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
@@ -0,0 +1,216 @@
+//===--- IntelGpuXe2.cpp - Xe2 uArch DPAS description ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/Support/YAMLTraits.h"
+// NOTE(review): the standard headers below were reduced to bare "#include"
+// lines in the patch under review; these are the ones this file's code
+// actually uses -- confirm against the author's original.
+#include <cassert>
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+using namespace mlir::xegpu::uArch;
+using namespace mlir::xegpu::uArch::Xe2Plus;
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+
+/// Returns every (rows, cols) shape the Xe2 DPAS instruction accepts for
+/// operand \p matrixType with element type \p dataType: the cross product
+/// M x K for A, K x N for B, and M x N for the accumulator C and result D.
+std::vector<std::pair<unsigned, unsigned>>
+DPASInstruction::getSupportedShapes(mlir::Type dataType,
+                                    MMAOpndEnum matrixType) {
+  // Cross product of two per-dimension size lists.
+  auto combineVectors = [](const std::vector<uint32_t> &a,
+                           const std::vector<uint32_t> &b)
+      -> std::vector<std::pair<unsigned, unsigned>> {
+    std::vector<std::pair<unsigned, unsigned>> result;
+    result.reserve(a.size() * b.size());
+    for (unsigned x : a)
+      for (unsigned y : b)
+        result.emplace_back(x, y);
+    return result;
+  };
+
+  auto M = getSupportedM(dataType);
+  auto K = getSupportedK(dataType);
+  auto N = getSupportedN(dataType);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+    return combineVectors(M, K);
+  case MMAOpndEnum::MatrixB:
+    return combineVectors(K, N);
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    // C and D share the M x N result shape.
+    return combineVectors(M, N);
+  }
+  llvm_unreachable("unknown MMA operand kind");
+}
+
+/// Returns the element types DPAS accepts for operand \p matrixType on Xe2.
+std::vector<mlir::Type>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndEnum matrixType) {
+  mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
+  mlir::Type f16Type = mlir::Float16Type::get(&context);
+  mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
+  mlir::Type f32Type = mlir::Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+  case MMAOpndEnum::MatrixB:
+    return {bf16Type, f16Type, tf32Type};
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    // Accumulator/result may stay narrow (f16/bf16) or be widened to f32.
+    return {bf16Type, f16Type, f32Type};
+  }
+  // Fix: the original fell off the end of this value-returning function for
+  // an out-of-range enum value, which is undefined behavior.
+  llvm_unreachable("unknown MMA operand kind");
+}
+
+/// Verifies that the (D, C, A, B) element-type combination is one the Xe2
+/// DPAS instruction implements. \p CType may be null (no accumulator).
+/// Emits a diagnostic on llvm::errs() and returns false on mismatch.
+bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                          mlir::Type CType, mlir::Type DType) {
+  if (AType.isF16() || BType.isF16()) {
+    if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
+        (!DType.isF32() && !DType.isF16())) {
+      llvm::errs()
+          << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
+          << "Supported types are:\n"
+          << "  Dst   |  Acc   |  A   |  B  \n"
+          << "  f, hf |  f, hf |  hf  |  hf \n"
+          << "AType: " << AType << " BType: " << BType << " CType: " << CType
+          << " DType: " << DType << "\n";
+      return false;
+    }
+  } else if (AType.isBF16() || BType.isBF16()) {
+    if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
+        (!DType.isF32() && !DType.isBF16())) {
+      llvm::errs()
+          << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
+          << "Supported types are:\n"
+          << "  Dst   |  Acc   |  A   |  B  \n"
+          << "  f, bf |  f, bf |  bf  |  bf \n"
+          << "AType: " << AType << " BType: " << BType << " CType: " << CType
+          << " DType: " << DType << "\n";
+      return false;
+    }
+  } else if (AType.isTF32() || BType.isTF32()) {
+    // Fix: the accumulator test accidentally checked DType instead of CType
+    // (compare the f16/bf16 branches above), so an unsupported accumulator
+    // type slipped through whenever DType happened to be f32.
+    if (AType != BType || (CType && !CType.isF32()) || !DType.isF32()) {
+      llvm::errs()
+          << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
+          << "Supported types are:\n"
+          << "  Dst |  Acc |  A    |  B   \n"
+          << "  f   |  f   |  tf32 |  tf32\n"
+          << "AType: " << AType << " BType: " << BType << " CType: " << CType
+          << " DType: " << DType << "\n";
+      return false;
+    }
+  } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
+               AType.isInteger(8)) ||
+             !(BType.isInteger(2) || BType.isInteger(4) ||
+               BType.isInteger(8))) {
+    // Fix: joined with || (was &&) -- BOTH operands must be i2/i4/i8; the
+    // old test only fired when neither operand was a supported integer, so
+    // e.g. A=i8, B=f64 was silently accepted.
+    llvm::errs()
+        << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
+        << "Supported types are:\n"
+        << "  Dst   |  Acc  |  A                |  B               \n"
+        << "  ud, d |  ud,d |  ub,b,u4,s4,u2,s2 |  ub,b,u4,s4,u2,s2\n"
+        << "AType: " << AType << " BType: " << BType << " CType: " << CType
+        << " DType: " << DType << "\n";
+    return false;
+  }
+
+  return true;
+}
+
+/// Returns true iff all four operand shapes and all four element types are
+/// supported by Xe2 DPAS.
+bool DPASInstruction::checkSupportedShapesAndTypes(
+    std::pair<unsigned, unsigned> AShape, std::pair<unsigned, unsigned> BShape,
+    std::pair<unsigned, unsigned> CShape, std::pair<unsigned, unsigned> DShape,
+    mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
+  auto supportedAShapes = getSupportedShapes(AType, MMAOpndEnum::MatrixA);
+  auto supportedBShapes = getSupportedShapes(BType, MMAOpndEnum::MatrixB);
+  auto supportedCShapes = getSupportedShapes(CType, MMAOpndEnum::MatrixC);
+  auto supportedDShapes = getSupportedShapes(DType, MMAOpndEnum::MatrixD);
+  return llvm::is_contained(supportedAShapes, AShape) &&
+         llvm::is_contained(supportedBShapes, BShape) &&
+         llvm::is_contained(supportedCShapes, CShape) &&
+         llvm::is_contained(supportedDShapes, DShape) &&
+         checkSupportedTypes(AType, BType, CType, DType);
+}
+
+/// Entry point used by clients: shape + type validation in one call.
+bool DPASInstruction::validate(std::pair<unsigned, unsigned> AShape,
+                               std::pair<unsigned, unsigned> BShape,
+                               std::pair<unsigned, unsigned> CShape,
+                               std::pair<unsigned, unsigned> DShape,
+                               mlir::Type AType, mlir::Type BType,
+                               mlir::Type CType, mlir::Type DType) {
+  return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
+                                      BType, CType, DType);
+}
+
+/// Supported repeat counts: M may be any value in [1, 8], independent of the
+/// element type.
+std::vector<uint32_t> DPASInstruction::getSupportedM(mlir::Type type) {
+  return {1, 2, 3, 4, 5, 6, 7, 8};
+}
+
+/// Supported K (reduction) sizes; K shrinks as the element bit-width grows.
+std::vector<uint32_t> DPASInstruction::getSupportedK(mlir::Type type) {
+  assert(type.isIntOrFloat() && "Matrix type must be int or float");
+  switch (type.getIntOrFloatBitWidth()) {
+  case 2:
+  case 4:
+    return {64};
+  case 8:
+    return {32};
+  case 16:
+    return {16};
+  case 32:
+    return {8};
+  default:
+    llvm_unreachable("Invalid int or float");
+  }
+}
+
+/// Supported N values for DPAS on this architecture.
+std::vector<uint32_t> DPASInstruction::getSupportedN(mlir::Type type) {
return {16}; +} + +} // namespace Xe2Plus +} // namespace uArch +} // namespace xegpu +} // namespace mlir + +// namespace mlir { +// namespace xe_gpu { +// namespace namespace mlir { +// namespace xegpu { +// namespace PVCuArchYAML { { +// struct XeCoreInfo { +// uint32_t num_threads; +// SharedMemory shared_memory; +// uint32_t num_vector_units; +// uint32_t num_matrix_units; +// }; + +// struct Xe2Plus { +// XeCoreInfo xe_core; +// }; +// } +// } +// } diff --git a/mlir/test/Dialect/XeGPU/attach-target-device.mlir b/mlir/test/Dialect/XeGPU/attach-target-device.mlir new file mode 100644 index 0000000000000..bdd34d7537f14 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/attach-target-device.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt --xevm-attach-target='module=xevm.* O=3 chip=pvc' %s -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: module @valid_dpas +module @valid_dpas attributes {gpu.container_module} { + // CHECK: gpu.module @valid_dpas [#xevm.target] { + gpu.module @valid_dpas { + // CHECK: gpu.func @valid_dpas + gpu.func @valid_dpas(%a: memref<24x32xf16>, %b: memref<32x24xf16>) { + // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG0:.*]]{{\[}}0, 0] + // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16 + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.layout> + + // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] + // CHECK-SAME: -> vector<24x32xf16> + %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<24x32xf16, #xegpu.layout> -> vector<24x32xf16> + + // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG1:.*]]{{\[}}0, 0] + // CHECK-SAME: memref<32x24xf16> -> !xegpu.tensor_desc<32x24xf16 + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf16> -> !xegpu.tensor_desc<32x24xf16, #xegpu.layout> + + // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] + // CHECK-SAME: -> vector<32x24xf16> + %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<32x24xf16, 
#xegpu.layout> -> vector<32x24xf16> + + // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] + // CHECK-SAME: layout_result_0 = #xegpu.layout + // CHECK-SAME: : vector<24x32xf16>, vector<32x24xf16> -> vector<24x24xf16> + %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<24x32xf16>, vector<32x24xf16> -> vector<24x24xf16> + + // CHECK: gpu.return + gpu.return + } + } +} diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 83a98ab0622b7..79b3da688092e 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -643,3 +643,22 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) { #xegpu.layout> return } + + +// ----- +module @invalid_dpas attributes {gpu.container_module} { + gpu.module @invalid_dpas [#xevm.target] { + + gpu.func @invalid_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout> + %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<32x24xf32, #xegpu.layout> -> vector<32x24xf32> + // expected-error@+1 {{Unsupported DPAS types.}} + %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + gpu.return + } + } +} +