Skip to content

Commit 83d3a2e

Browse files
authored
[uArch][XeGPU] Add XeGPU uArch definition. (#153706)
The uArch infrastructure provides: - A set data structures to represent, uArch and it's necessary components (e.g., instructions, register-files, caches). - A set of utility interfaces that are common to a family of ops (e.g., mma ops, 2DBlockIO ops). The implementation of these interfaces are provided by the specific instructions. Each family of ops provides these 5 common APIs. However, some family of ops may have more utility APIs. The common 5 APIs are: - getSupportedShapes - getSupportedTypes - checkSupportedShapesAndTypes - checkSupportedTypes - validate Add support for PVC and BMG architectures. Add support for DPAS instruction.
1 parent 311d113 commit 83d3a2e

File tree

3 files changed

+563
-0
lines changed

3 files changed

+563
-0
lines changed
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file
10+
// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
11+
// This file defines the uArch details for Xe2 and its derived architectures.
12+
// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16+
#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
17+
18+
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
19+
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/IR/TypeUtilities.h"
21+
#include "llvm/ADT/SmallVector.h"
22+
#include "llvm/Support/DebugLog.h"
23+
#include <map>
24+
#include <string>
25+
26+
#define DEBUG_TYPE "xegpu-uarch"
27+
28+
using namespace mlir;
29+
using namespace mlir::xegpu::uArch;
30+
31+
namespace mlir {
32+
namespace xegpu {
33+
namespace uArch {
34+
35+
struct Xe2Plus : public uArch {
36+
XeCoreInfo xeCore;
37+
Xe2Plus(const std::string &archName, const std::string &archDescription,
38+
const XeCoreInfo &xeCore,
39+
const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
40+
const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
41+
const std::map<InstructionKind, std::shared_ptr<Instruction>>
42+
&instrs = {})
43+
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
44+
xeCore(xeCore) {}
45+
};
46+
47+
// struct to represent DPAS instruction
48+
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
49+
DPASInstruction()
50+
: Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
51+
52+
// Override all virtuals from MatrixOpInterface
53+
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
54+
getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
55+
virtual llvm::SmallVector<Type, 8>
56+
getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
57+
virtual bool
58+
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
59+
std::pair<uint32_t, uint32_t> BShape,
60+
std::pair<uint32_t, uint32_t> CShape,
61+
std::pair<uint32_t, uint32_t> DShape, Type AType,
62+
Type BType, Type CType, Type DType) override;
63+
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
64+
Type DType) override;
65+
virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
66+
std::pair<uint32_t, uint32_t> BShape,
67+
std::pair<uint32_t, uint32_t> CShape,
68+
std::pair<uint32_t, uint32_t> DShape, Type AType,
69+
Type BType, Type CType, Type DType) override;
70+
virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
71+
virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
72+
virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
73+
};
74+
75+
struct PVCuArch : public Xe2Plus {
76+
// Maintaines ownership of the instructions owned by PVUarch
77+
llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
78+
PVCuArch()
79+
: Xe2Plus("pvc", // archName
80+
"Ponte Vecchio Architecture", // archDescription
81+
XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
82+
{/* registerFileInfo */}, // Optional: empty
83+
{/* cacheInfo */}, // Optional: empty
84+
{/* instructions */} // Optional: empty
85+
) {
86+
// Intialize register file info
87+
// GRF
88+
this->registerFileInfo.emplace(
89+
RegisterFileType::GRF,
90+
RegisterFileInfo(
91+
64 * 1024, // size in bits
92+
{RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
93+
{128, 256} // registers per thread per mode
94+
));
95+
// Initialize cache info
96+
// L1 cache, XeCore level
97+
this->cacheInfo.push_back(
98+
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
99+
// L2 cache, XeStack level
100+
this->cacheInfo.push_back(
101+
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
102+
103+
// Add the instructions-
104+
auto dpas = std::make_shared<DPASInstruction>();
105+
instructions.emplace(dpas->getInstructionKind(), dpas);
106+
owned_instructions.push_back(dpas);
107+
}
108+
};
109+
110+
struct BMGuArch : public Xe2Plus {
111+
// Maintaines ownership of the instructions owned by PVUarch
112+
llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
113+
BMGuArch()
114+
: Xe2Plus("bmg", // archName
115+
"Battlemage Architecture", // archDescription
116+
XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
117+
{/* registerFileInfo */}, // Optional: empty
118+
{/* cacheInfo */}, // Optional: empty
119+
{/* instructions */} // Optional: empty
120+
) {
121+
// Intialize register file info
122+
// GRF
123+
this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
124+
64 * 1024, // size in bits
125+
{RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
126+
{128, 256} // registers per thread per mode
127+
);
128+
// Initialize cache info
129+
// L1 cache, XeCore level
130+
this->cacheInfo.push_back(
131+
CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
132+
// L2 cache, XeStack level
133+
this->cacheInfo.push_back(
134+
CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
135+
136+
// Add the instructions
137+
auto dpas = std::make_shared<DPASInstruction>();
138+
instructions.emplace(dpas->getInstructionKind(), dpas);
139+
owned_instructions.push_back(dpas);
140+
}
141+
};
142+
} // namespace uArch
143+
} // namespace xegpu
144+
} // namespace mlir
145+
146+
inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
147+
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
148+
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
149+
const llvm::SmallVector<uint32_t, 8> &b)
150+
-> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
151+
llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> result;
152+
for (unsigned x : a) {
153+
for (unsigned y : b) {
154+
result.emplace_back(x, y);
155+
}
156+
}
157+
return result;
158+
};
159+
160+
auto M = getSupportedM(dataType);
161+
auto K = getSupportedK(dataType);
162+
auto N = getSupportedN(dataType);
163+
llvm::SmallVector<std::pair<unsigned, unsigned>, 16> resultMatrix;
164+
165+
switch (matrixType) {
166+
case MMAOpndKind::MatrixA:
167+
resultMatrix = combineVectors(M, K);
168+
break;
169+
case MMAOpndKind::MatrixB:
170+
resultMatrix = combineVectors(K, N);
171+
break;
172+
case MMAOpndKind::MatrixC:
173+
resultMatrix = combineVectors(M, N);
174+
break;
175+
case MMAOpndKind::MatrixD:
176+
resultMatrix = combineVectors(M, N);
177+
break;
178+
}
179+
return resultMatrix;
180+
}
181+
182+
inline llvm::SmallVector<Type, 8>
183+
DPASInstruction::getSupportedTypes(MLIRContext &context,
184+
MMAOpndKind matrixType) {
185+
Type bf16Type = BFloat16Type::get(&context);
186+
Type f16Type = Float16Type::get(&context);
187+
Type tf32Type = FloatTF32Type::get(&context);
188+
Type f32Type = Float32Type::get(&context);
189+
190+
switch (matrixType) {
191+
case MMAOpndKind::MatrixA:
192+
return {bf16Type, f16Type, tf32Type};
193+
case MMAOpndKind::MatrixB:
194+
return {bf16Type, f16Type, tf32Type};
195+
case MMAOpndKind::MatrixC:
196+
return {bf16Type, f16Type, f32Type};
197+
case MMAOpndKind::MatrixD:
198+
return {bf16Type, f16Type, f32Type};
199+
}
200+
return {};
201+
}
202+
203+
inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
204+
Type CType, Type DType) {
205+
if (AType.isF16() || BType.isF16()) {
206+
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
207+
(!DType.isF32() && !DType.isF16())) {
208+
LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
209+
return false;
210+
}
211+
} else if (AType.isBF16() || BType.isBF16()) {
212+
if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
213+
(!DType.isF32() && !DType.isBF16())) {
214+
LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
215+
return false;
216+
}
217+
} else if (AType.isTF32() || BType.isTF32()) {
218+
if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
219+
(!DType.isF32())) {
220+
LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
221+
return false;
222+
}
223+
} else if (!(AType.isInteger(2) || AType.isInteger(4) ||
224+
AType.isInteger(8)) &&
225+
!(BType.isInteger(2) || BType.isInteger(4) ||
226+
BType.isInteger(8))) {
227+
LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
228+
return false;
229+
}
230+
231+
return true;
232+
}
233+
234+
inline bool DPASInstruction::checkSupportedShapesAndTypes(
235+
std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
236+
std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
237+
Type AType, Type BType, Type CType, Type DType) {
238+
auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
239+
auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
240+
auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
241+
auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
242+
return llvm::is_contained(supportedAShapes, AShape) &&
243+
llvm::is_contained(supportedBShapes, BShape) &&
244+
llvm::is_contained(supportedCShapes, CShape) &&
245+
llvm::is_contained(supportedDShapes, DShape) &&
246+
checkSupportedTypes(AType, BType, CType, DType);
247+
}
248+
249+
inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
250+
std::pair<uint32_t, uint32_t> BShape,
251+
std::pair<uint32_t, uint32_t> CShape,
252+
std::pair<uint32_t, uint32_t> DShape,
253+
Type AType, Type BType, Type CType,
254+
Type DType) {
255+
return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
256+
BType, CType, DType);
257+
}
258+
259+
inline llvm::SmallVector<uint32_t, 8>
260+
DPASInstruction::getSupportedM(Type type) {
261+
return {1, 2, 3, 4, 5, 6, 7, 8};
262+
}
263+
264+
inline llvm::SmallVector<uint32_t, 8>
265+
DPASInstruction::getSupportedK(Type type) {
266+
// assert if data type is not int or float type
267+
assert(type.isIntOrFloat() && "Matrix type must be int or float");
268+
auto bitWidth = type.getIntOrFloatBitWidth();
269+
uint32_t kSize = 0;
270+
switch (bitWidth) {
271+
case 2:
272+
kSize = 64;
273+
break;
274+
case 4:
275+
kSize = 64;
276+
break;
277+
case 8:
278+
kSize = 32;
279+
break;
280+
case 16:
281+
kSize = 16;
282+
break;
283+
case 32:
284+
kSize = 8;
285+
break;
286+
default:
287+
llvm_unreachable("Invalid int or float");
288+
}
289+
return {kSize};
290+
}
291+
292+
inline llvm::SmallVector<uint32_t, 8>
293+
DPASInstruction::getSupportedN(Type type) {
294+
return {16};
295+
}
296+
297+
#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H

0 commit comments

Comments
 (0)