Skip to content

Commit f33b7f7

Browse files
committed
Address review comments.
Simplify the design: - Remove uArchHierarchyComponent LLVMize names. Replace String usage with enum whenever possible.
1 parent 6dace4b commit f33b7f7

File tree

4 files changed

+83
-116
lines changed

4 files changed

+83
-116
lines changed

mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h

Lines changed: 27 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,10 @@ struct Xe2Plus : public uArch {
4545
Xe2Plus(
4646
const std::string &archName, const std::string &archDescription,
4747
const XeCoreInfo &xeCore,
48-
const std::vector<uArchHierarchyComponent> &hierarchy = {},
49-
const std::map<std::string, RegisterFileInfo> &regInfo = {},
48+
const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
5049
const std::vector<CacheInfo> &cacheInfo = {},
5150
const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
52-
: uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs),
51+
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
5352
xe_core(xeCore) {}
5453
};
5554

@@ -62,9 +61,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
6261

6362
// Override all virtuals from MatrixOpInterface
6463
virtual std::vector<std::pair<uint32_t, uint32_t>>
65-
getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override;
64+
getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) override;
6665
virtual std::vector<mlir::Type>
67-
getSupportedTypes(MLIRContext &context, MMAOpndEnum matrixType) override;
66+
getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
6867
virtual bool
6968
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
7069
std::pair<uint32_t, uint32_t> BShape,
@@ -97,29 +96,22 @@ struct PVCuArch : public Xe2Plus {
9796
{/* cache_info */}, // Optional: empty
9897
{/* instructions */} // Optional: empty
9998
) {
100-
// Initialize uArchHierarchy
101-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
102-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
103-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16));
104-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4));
105-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2));
10699
// Intialize register file info
107100
// GRF
108-
this->register_file_info.emplace(
109-
"GRF",
110-
RegisterFileInfo(64 * 1024, // size in bits
111-
{"small", "large"}, // GRF modes
112-
{128, 256}, // registers per thread per mode
113-
0, // number of banks
114-
0 // bank size
115-
));
101+
this->registerFileInfo.emplace(
102+
RegisterFileType::GRF,
103+
RegisterFileInfo(
104+
64 * 1024, // size in bits
105+
{RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
106+
{128, 256} // registers per thread per mode
107+
));
116108
// Initialize cache info
117109
// L1 cache, XeCore level
118-
this->cache_info.push_back(
119-
CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1]));
120-
// L3 cache, XeStack level
121-
this->cache_info.push_back(
122-
CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3]));
110+
this->cacheInfo.push_back(
111+
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
112+
// L2 cache, XeStack level
113+
this->cacheInfo.push_back(
114+
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
123115

124116
// Add the instructions
125117
auto dpas = std::make_shared<DPASInstruction>();
@@ -140,31 +132,22 @@ struct BMGuArch : public Xe2Plus {
140132
XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
141133
{/* register_file_info */}, // Optional: empty
142134
{/* cache_info */}, // Optional: empty
143-
{/* instructions */}, // Optional: empty
144-
{/* restrictions */} // Optional: empty
135+
{/* instructions */} // Optional: empty)
145136
) {
146-
// Initialize uArchHierarchy
147-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
148-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
149-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4));
150-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5));
151-
this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1));
152137
// Intialize register file info
153138
// GRF
154-
this->register_file_info["GRF"] =
155-
RegisterFileInfo(64 * 1024, // size in bits
156-
{"small", "large"}, // GRF modes
157-
{128, 256}, // registers per thread per mode
158-
0, // number of banks
159-
0 // bank size
160-
);
139+
this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
140+
64 * 1024, // size in bits
141+
{RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
142+
{128, 256} // registers per thread per mode
143+
);
161144
// Initialize cache info
162145
// L1 cache, XeCore level
163-
this->cache_info.push_back(
164-
CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1]));
165-
// L3 cache, XeStack level
166-
this->cache_info.push_back(
167-
CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3]));
146+
this->cacheInfo.push_back(
147+
CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
148+
// L2 cache, XeStack level
149+
this->cacheInfo.push_back(
150+
CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
168151

169152
// Add the instructions
170153
auto dpas = std::make_shared<DPASInstruction>();

mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h

Lines changed: 39 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,15 @@
2727
namespace mlir {
2828
namespace xegpu {
2929
namespace uArch {
30-
// Architecture HW component hierarchy to present thread, core, socket ...
31-
struct uArchHierarchyComponent {
32-
std::string name = ""; // optional name of the hierarchy component
33-
// no. of lower hierarchy component it contains, e.g., for PVC XeCore it
34-
// contains 8 threads, so no_of_component=8
35-
uint32_t no_of_component;
36-
// Constructor
37-
uArchHierarchyComponent(const std::string &name, uint32_t no_of_component)
38-
: name(name), no_of_component(no_of_component) {}
39-
};
4030

4131
// An enum class to represent the scope of an instruction
42-
enum class InstructionScopeEnum { WorkItem, Subgroup, Workgroup, Cluster };
32+
enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
33+
34+
enum class InstructionName {
35+
DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add
36+
// operation
37+
// Add more instructions as needed
38+
};
4339

4440
// A struct to represent basic information about an instruction
4541
// This struct is used to represent the information about an instruction in the
@@ -67,69 +63,62 @@ struct Instruction {
6763
// Get methods
6864
std::string getName() { return name; }
6965
std::string getDescription() { return description; }
70-
InstructionScopeEnum getScope() { return scope; }
66+
InstructionScope getScope() { return scope; }
7167

7268
protected:
7369
std::string name;
7470
std::string description;
75-
InstructionScopeEnum scope;
71+
InstructionScope scope;
7672
};
7773

74+
enum class RegisterFileMode : uint8_t { Small, Large };
75+
enum class RegisterFileType : uint8_t { GRF, ARF };
76+
7877
// A struct to represent register file information
7978
struct RegisterFileInfo {
8079
// Constructor
8180
RegisterFileInfo() = default;
82-
RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
83-
const std::vector<uint32_t> &numRegs, uint32_t num_banks,
84-
uint32_t bank_size)
85-
: size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
86-
num_banks(num_banks), bank_size(bank_size) {}
81+
RegisterFileInfo(uint32_t size, const std::vector<RegisterFileMode> &mode,
82+
const std::vector<uint32_t> &numRegs)
83+
: size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
8784

88-
// Get methods
8985
uint32_t getSize() const { return size; }
90-
91-
const std::vector<std::string> &getModes() const { return mode; }
92-
86+
const std::vector<RegisterFileMode> &getModes() const { return mode; }
9387
const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
94-
return num_regs_per_thread_per_mode;
88+
return numRegsPerThreadPerMode;
9589
}
9690

97-
uint32_t getNumBanks() const { return num_banks; }
98-
99-
uint32_t getBankSize() const { return bank_size; }
100-
10191
protected:
102-
uint32_t size; // size per register in bits
103-
std::vector<std::string> mode; // e.g., "small", "large" GRF modes
92+
uint32_t size; // size per register in bits
93+
std::vector<RegisterFileMode> mode; // e.g., "small", "large" GRF modes
10494
std::vector<uint32_t>
105-
num_regs_per_thread_per_mode; // number of registers per thread per mode
106-
uint32_t num_banks;
107-
uint32_t bank_size;
95+
numRegsPerThreadPerMode; // number of registers per thread per mode
96+
// TODO: Add more fields as needed (e.g., num_banks, bank_size, num_ports,
97+
// port_width, bank_conflicts)
10898
};
10999

100+
enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
110101
// A struct to represent cache information
111-
112102
struct CacheInfo {
113103
// Constructor
114104
CacheInfo(uint32_t size, uint32_t line_size,
115-
const uArchHierarchyComponent &component)
116-
: size(size), line_size(line_size), component(component) {}
105+
CacheHierarchyLevel hierarchy_level)
106+
: size(size), line_size(line_size), hierarchy_level(hierarchy_level) {}
117107

118108
virtual ~CacheInfo() = default;
119109

120110
// Get methods
121111
uint32_t getSize() const { return size; }
122112
uint32_t getLineSize() const { return line_size; }
123-
const uArchHierarchyComponent &getComponent() const { return component; }
113+
CacheHierarchyLevel getHierarchyLevel() const { return hierarchy_level; }
124114

125115
protected:
126116
uint32_t size;
127117
uint32_t line_size;
128-
// At which component level the cache is shared
129-
uArchHierarchyComponent component;
130-
118+
CacheHierarchyLevel hierarchy_level;
131119
// @TODO: Add more fields as needed (e.g., associativity, num_banks,
132-
// bank_size, num_ports, port_width, bank_conflicts)
120+
// bank_size, num_ports, port_width, bank_conflicts, hierarchy_level,
121+
// latency, throughput, bandwidth)
133122
};
134123

135124
// A struct to represent the uArch
@@ -145,29 +134,26 @@ struct uArch {
145134
// Constructor
146135
uArch() = default;
147136
uArch(const std::string &name, const std::string &description,
148-
const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
149-
const std::map<std::string, RegisterFileInfo> &register_file_info = {},
137+
const std::map<RegisterFileType, RegisterFileInfo> &register_file_info =
138+
{},
150139
const std::vector<CacheInfo> &cache_info = {},
151140
const std::map<std::string, std::shared_ptr<Instruction>>
152141
&instructions = {})
153-
: name(name), description(description), uArch_hierarchy(uArch_hierarchy),
154-
register_file_info(register_file_info), cache_info(cache_info),
142+
: name(name), description(description),
143+
registerFileInfo(register_file_info), cacheInfo(cache_info),
155144
instructions(instructions) {}
156145

157146
// Get methods
158147
const std::string &getName() const { return name; }
159148

160149
const std::string &getDescription() const { return description; }
161150

162-
const std::vector<uArchHierarchyComponent> &getHierarchy() const {
163-
return uArch_hierarchy;
164-
}
165-
166-
const std::map<std::string, RegisterFileInfo> &getRegisterFileInfo() const {
167-
return register_file_info;
151+
const std::map<RegisterFileType, RegisterFileInfo> &
152+
getRegisterFileInfo() const {
153+
return registerFileInfo;
168154
}
169155

170-
const std::vector<CacheInfo> &getCacheInfo() const { return cache_info; }
156+
const std::vector<CacheInfo> &getCacheInfo() const { return cacheInfo; }
171157

172158
const std::map<std::string, std::shared_ptr<Instruction>> &
173159
getInstructions() const {
@@ -192,9 +178,8 @@ struct uArch {
192178
protected:
193179
std::string name; // Similar to target triple
194180
std::string description;
195-
std::vector<uArchHierarchyComponent> uArch_hierarchy;
196-
std::map<std::string, RegisterFileInfo> register_file_info;
197-
std::vector<CacheInfo> cache_info;
181+
std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
182+
std::vector<CacheInfo> cacheInfo;
198183
std::map<std::string, std::shared_ptr<Instruction>> instructions;
199184
};
200185

mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@ namespace mlir {
2626
namespace xegpu {
2727
namespace uArch {
2828

29-
enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD };
29+
enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
3030
struct MMAInstructionInterface {
3131
// Get supported Matrix shapes
3232
virtual std::vector<std::pair<uint32_t, uint32_t>>
33-
getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) = 0;
34-
33+
getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) = 0;
3534
// @TODO: This method takes an context object as a parameter, this is to
3635
// create the mlir::Type objects from the same context. Since type objects are
3736
// uniqued in a specific context, to do things like "aType == bType" (where
@@ -46,7 +45,7 @@ struct MMAInstructionInterface {
4645
// Untill we have a better solution, we stick to passing context object to
4746
// this method.
4847
virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
49-
MMAOpndEnum matrixType) = 0;
48+
MMAOpndKind matrixType) = 0;
5049
virtual bool
5150
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
5251
std::pair<uint32_t, uint32_t> BShape,

mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace Xe2Plus {
1616

1717
std::vector<std::pair<uint32_t, uint32_t>>
1818
DPASInstruction::getSupportedShapes(mlir::Type dataType,
19-
MMAOpndEnum matrixType) {
19+
MMAOpndKind matrixType) {
2020
auto combineVectors = [](const std::vector<uint32_t> &a,
2121
const std::vector<uint32_t> &b)
2222
-> std::vector<std::pair<uint32_t, uint32_t>> {
@@ -35,16 +35,16 @@ DPASInstruction::getSupportedShapes(mlir::Type dataType,
3535
std::vector<std::pair<unsigned, unsigned>> resultMatrix;
3636

3737
switch (matrixType) {
38-
case MMAOpndEnum::MatrixA:
38+
case MMAOpndKind::MatrixA:
3939
resultMatrix = combineVectors(M, K);
4040
break;
41-
case MMAOpndEnum::MatrixB:
41+
case MMAOpndKind::MatrixB:
4242
resultMatrix = combineVectors(K, N);
4343
break;
44-
case MMAOpndEnum::MatrixC:
44+
case MMAOpndKind::MatrixC:
4545
resultMatrix = combineVectors(M, N);
4646
break;
47-
case MMAOpndEnum::MatrixD:
47+
case MMAOpndKind::MatrixD:
4848
resultMatrix = combineVectors(M, N);
4949
break;
5050
}
@@ -53,23 +53,23 @@ DPASInstruction::getSupportedShapes(mlir::Type dataType,
5353

5454
std::vector<mlir::Type>
5555
DPASInstruction::getSupportedTypes(MLIRContext &context,
56-
MMAOpndEnum matrixType) {
56+
MMAOpndKind matrixType) {
5757
mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
5858
mlir::Type f16Type = mlir::Float16Type::get(&context);
5959
mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
6060
mlir::Type f32Type = mlir::Float32Type::get(&context);
6161

6262
switch (matrixType) {
63-
case MMAOpndEnum::MatrixA:
63+
case MMAOpndKind::MatrixA:
6464
return {bf16Type, f16Type, tf32Type};
6565
break;
66-
case MMAOpndEnum::MatrixB:
66+
case MMAOpndKind::MatrixB:
6767
return {bf16Type, f16Type, tf32Type};
6868
break;
69-
case MMAOpndEnum::MatrixC:
69+
case MMAOpndKind::MatrixC:
7070
return {bf16Type, f16Type, f32Type};
7171
break;
72-
case MMAOpndEnum::MatrixD:
72+
case MMAOpndKind::MatrixD:
7373
return {bf16Type, f16Type, f32Type};
7474
break;
7575
}
@@ -110,10 +110,10 @@ bool DPASInstruction::checkSupportedShapesAndTypes(
110110
std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
111111
std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
112112
mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
113-
auto supportedAShapes = getSupportedShapes(AType, MMAOpndEnum::MatrixA);
114-
auto supportedBShapes = getSupportedShapes(BType, MMAOpndEnum::MatrixB);
115-
auto supportedCShapes = getSupportedShapes(CType, MMAOpndEnum::MatrixC);
116-
auto supportedDShapes = getSupportedShapes(DType, MMAOpndEnum::MatrixD);
113+
auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
114+
auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
115+
auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
116+
auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
117117
return llvm::is_contained(supportedAShapes, AShape) &&
118118
llvm::is_contained(supportedBShapes, BShape) &&
119119
llvm::is_contained(supportedCShapes, CShape) &&

0 commit comments

Comments
 (0)