Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 131 additions & 74 deletions mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include <map>
#include <string>

#define DEBUG_TYPE "xegpu-uarch"

using namespace mlir;
using namespace mlir::xegpu::uArch;

Expand All @@ -33,21 +31,80 @@ namespace xegpu {
namespace uArch {

// Base description shared by the uArch structs derived from it in this file.
// It reports a subgroup size of 16, a packed-format size of 16 bits (32 bits
// for gather/scatter), and keeps a copy of the XeCoreInfo it was built with.
struct Xe2Plus : public uArch {
// Builds the description from a name, a description string, a registry of
// instruction pointers, and the Xe core parameters. The first three are
// forwarded to the uArch base; xeCore is copied into the member below.
Xe2Plus(StringRef archName, StringRef archDescription,
llvm::ArrayRef<const Instruction *> instructionRegistry,
const XeCoreInfo &xeCore)
: uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
// Fixed values reported through the uArch virtual interface.
int getSubgroupSize() const override { return 16; }
unsigned getPackedFormatBitSize() const override { return 16; }
unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }

protected:
// Copy of the Xe core parameters passed to the constructor.
XeCoreInfo xeCore;
// NOTE(review): this overload forwards five arguments (register-file map,
// cache list, instruction map) to uArch, while the constructor above forwards
// three. It looks like the pre-change constructor left in place by the diff
// rendering -- confirm which signature the uArch base actually provides.
Xe2Plus(const std::string &archName, const std::string &archDescription,
const XeCoreInfo &xeCore,
const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
const std::map<InstructionKind, std::shared_ptr<Instruction>>
&instrs = {})
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
xeCore(xeCore) {}
};

// struct to represent DPAS instruction
//===----------------------------------------------------------------------===//
// uArch instructions
//===----------------------------------------------------------------------===//
// Descriptor for the STORE_ND instruction, registered at subgroup scope.
struct StoreNdInstruction : public Instruction {
  StoreNdInstruction()
      : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}

  // Source :
  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
  // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
  // the specified pointer.
  //
  // Returns the supported per-lane vector lengths in ascending order. The
  // returned ArrayRef views a function-local static table, so it stays valid
  // for the lifetime of the program.
  llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
    static constexpr int sortedLaneVectorLengths[] = {1, 2, 4, 8};
    return sortedLaneVectorLengths;
  }
};

// Descriptor for the LOAD_ND instruction, registered at subgroup scope.
struct LoadNdInstruction : public Instruction {
  LoadNdInstruction()
      : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}

  // Source :
  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
  // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
  // the specified pointer.
  //
  // Returns the supported per-lane vector lengths in ascending order. The
  // returned ArrayRef views a function-local static table, so it stays valid
  // for the lifetime of the program.
  llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
    static constexpr int sortedLaneVectorLengths[] = {1, 2, 4, 8};
    return sortedLaneVectorLengths;
  }
};

// Descriptor for the PREFETCH_ND instruction, registered at subgroup scope.
struct PrefetchNdInstruction : public Instruction {
  PrefetchNdInstruction()
      : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}

  // Source :
  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
  //
  // Returns the ascending list of per-lane vector lengths for the given
  // element bitwidth: {1, 2, 4, 8, 16} for 8- and 16-bit elements, and
  // {1, 2, 4, 8} for 32- and 64-bit elements. Any other bitwidth hits
  // llvm_unreachable.
  llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
    static constexpr int narrowLengths[] = {1, 2, 4, 8, 16};
    static constexpr int wideLengths[] = {1, 2, 4, 8};

    const bool isNarrow = elementBitwidth == 8 || elementBitwidth == 16;
    if (isNarrow)
      return narrowLengths;

    const bool isWide = elementBitwidth == 32 || elementBitwidth == 64;
    if (isWide)
      return wideLengths;

    llvm_unreachable("Unsupported element bitwidth");
  }
};

struct DPASInstruction : public Instruction, public MMAInstructionInterface {
DPASInstruction()
: Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB,
unsigned packedFormatBitSizeC)
: Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
packedFormatBitSizeA(packedFormatBitSizeA),
packedFormatBitSizeB(packedFormatBitSizeB),
packedFormatBitSizeC(packedFormatBitSizeC) {}
// Source:
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html

// Override all virtuals from MatrixOpInterface
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
Expand All @@ -67,82 +124,82 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) override;
virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
virtual llvm::SmallVector<uint32_t, 8>
getSupportedM(Type type) const override;
virtual llvm::SmallVector<uint32_t, 8>
getSupportedK(Type type) const override;
virtual llvm::SmallVector<uint32_t, 8>
getSupportedN(Type type) const override;

unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
unsigned getPackedFormatBitSizeC() const { return packedFormatBitSizeC; }

protected:
const unsigned packedFormatBitSizeA;
const unsigned packedFormatBitSizeB;
const unsigned packedFormatBitSizeC;
};

struct PVCuArch : public Xe2Plus {
// Maintains ownership of the instructions owned by PVCuArch
llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
//===----------------------------------------------------------------------===//
// uArch instances
//===----------------------------------------------------------------------===//

struct PVCuArch final : public Xe2Plus {
inline static const DPASInstruction dpasInst{16, 32, 32};
inline static const StoreNdInstruction loadNdInst;
inline static const StoreNdInstruction storeNdInst;
inline static const PrefetchNdInstruction prefetchNdInst;
inline static const Instruction *const instructionRegistryArr[] = {
&dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
PVCuArch()
: Xe2Plus("pvc", // archName
"Ponte Vecchio Architecture", // archDescription
XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
{/* registerFileInfo */}, // Optional: empty
{/* cacheInfo */}, // Optional: empty
{/* instructions */} // Optional: empty
) {
// Intialize register file info
// GRF
this->registerFileInfo.emplace(
RegisterFileType::GRF,
RegisterFileInfo(
64 * 1024, // size in bits
{RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
{128, 256} // registers per thread per mode
));
// Initialize cache info
// L1 cache, XeCore level
this->cacheInfo.push_back(
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
// L2 cache, XeStack level
this->cacheInfo.push_back(
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));

// Add the instructions-
auto dpas = std::make_shared<DPASInstruction>();
instructions.emplace(dpas->getInstructionKind(), dpas);
owned_instructions.push_back(dpas);
instructionRegistryArr,
XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
) {}
// Returns a pointer to the single shared PVCuArch description, typed as its
// uArch base.
static const uArch *getInstance() {
  // Function-local static: constructed on first call and reused afterwards.
  static const PVCuArch instance;
  // Derived-to-base pointer conversion is implicit and adjusts the pointer
  // correctly; reinterpret_cast would only reinterpret the address.
  return &instance;
}
};

struct BMGuArch : public Xe2Plus {
// Maintains ownership of the instructions owned by BMGuArch
llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
inline static const DPASInstruction dpasInst{16, 32, 32};
inline static const StoreNdInstruction loadNdInst;
inline static const StoreNdInstruction storeNdInst;
inline static const PrefetchNdInstruction prefetchNdInst;
inline static const Instruction *const instructionRegistryArr[] = {
&dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
BMGuArch()
: Xe2Plus("bmg", // archName
"Battlemage Architecture", // archDescription
XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
{/* registerFileInfo */}, // Optional: empty
{/* cacheInfo */}, // Optional: empty
{/* instructions */} // Optional: empty
) {
// Intialize register file info
// GRF
this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
64 * 1024, // size in bits
{RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
{128, 256} // registers per thread per mode
);
// Initialize cache info
// L1 cache, XeCore level
this->cacheInfo.push_back(
CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
// L2 cache, XeStack level
this->cacheInfo.push_back(
CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));

// Add the instructions
auto dpas = std::make_shared<DPASInstruction>();
instructions.emplace(dpas->getInstructionKind(), dpas);
owned_instructions.push_back(dpas);
instructionRegistryArr,
XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
) {}
// Returns a pointer to the single shared BMGuArch description, typed as its
// uArch base.
static const uArch *getInstance() {
  // Function-local static: constructed on first call and reused afterwards.
  static const BMGuArch instance;
  // Derived-to-base pointer conversion is implicit and adjusts the pointer
  // correctly; reinterpret_cast would only reinterpret the address.
  return &instance;
}
};

// Maps an architecture name to its uArch description. The comparison ignores
// case. "pvc" yields the PVCuArch instance and "bmg" the BMGuArch instance;
// any other name yields nullptr.
inline const uArch *getUArch(llvm::StringRef archName) {
  const bool isPvc = archName.equals_insensitive("pvc");
  if (isPvc)
    return PVCuArch::getInstance();

  const bool isBmg = archName.equals_insensitive("bmg");
  if (isBmg)
    return BMGuArch::getInstance();

  // Unknown architecture name.
  return nullptr;
}

} // namespace uArch
} // namespace xegpu
} // namespace mlir

//===----------------------------------------------------------------------===//
// Instruction implementations
//===----------------------------------------------------------------------===//

inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
Expand Down Expand Up @@ -257,12 +314,12 @@ inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
}

inline llvm::SmallVector<uint32_t, 8>
DPASInstruction::getSupportedM(Type type) {
DPASInstruction::getSupportedM(Type type) const {
return {1, 2, 3, 4, 5, 6, 7, 8};
}

inline llvm::SmallVector<uint32_t, 8>
DPASInstruction::getSupportedK(Type type) {
DPASInstruction::getSupportedK(Type type) const {
// assert if data type is not int or float type
assert(type.isIntOrFloat() && "Matrix type must be int or float");
auto bitWidth = type.getIntOrFloatBitWidth();
Expand Down Expand Up @@ -290,7 +347,7 @@ DPASInstruction::getSupportedK(Type type) {
}

inline llvm::SmallVector<uint32_t, 8>
DPASInstruction::getSupportedN(Type type) {
DPASInstruction::getSupportedN(Type type) const {
return {16};
}

Expand Down
Loading
Loading