Skip to content

Commit a99ee75

Browse files
committed
refactor
1 parent d7eaaa5 commit a99ee75

File tree

3 files changed

+31
-32
lines changed

3 files changed

+31
-32
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ class LayoutAttr;
2424
class TensorDescType;
2525
} // namespace xegpu
2626

27+
namespace xegpu {
28+
/// HW dependent constants.
29+
/// TODO: These constants should be queried from the target information.
30+
namespace targetinfo {
31+
constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
32+
/// If DPAS A or B operands have low precision element types they must be packed
33+
/// according to the following sizes.
34+
constexpr unsigned packedSizeInBitsForDefault =
35+
16; // Minimum packing size per register for DPAS A.
36+
constexpr unsigned packedSizeInBitsForDpasB =
37+
32; // Minimum packing size per register for DPAS B.
38+
} // namespace targetinfo
39+
} // namespace xegpu
40+
2741
namespace xegpu {
2842

2943
/// If tensor descriptor has a layout attribute it is used in SIMT mode.

mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,6 @@ namespace xegpu {
4646
using namespace mlir;
4747
using namespace mlir::dataflow;
4848

49-
/// HW dependent constants.
50-
/// TODO: These constants should be queried from the target information.
51-
constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
52-
/// If DPAS A or B operands have low precision element types they must be packed
53-
/// according to the following sizes.
54-
constexpr unsigned packedSizeInBitsForDefault =
55-
16; // Minimum packing size per register for DPAS A.
56-
constexpr unsigned packedSizeInBitsForDpasB =
57-
32; // Minimum packing size per register for DPAS B.
58-
5949
namespace {
6050

6151
//===----------------------------------------------------------------------===//
@@ -198,8 +188,10 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
198188
static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
199189
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
200190
if (rank == 1)
201-
return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
202-
return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
191+
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
192+
LaneData({1}));
193+
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
194+
LaneData({1, 1}));
203195
}
204196

205197
/// Helper to get the default layout for a vector type.
@@ -216,9 +208,9 @@ static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
216208
// Packing factor is determined by the element type bitwidth.
217209
int packingFactor = 1;
218210
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
219-
if (bitwidth < packedSizeInBitsForDefault)
220-
packingFactor = packedSizeInBitsForDefault / bitwidth;
221-
return LayoutInfo(LaneLayout({1, subgroupSize}),
211+
if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
212+
packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
213+
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
222214
LaneData({1, packingFactor}));
223215
}
224216

@@ -233,13 +225,14 @@ static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
233225
Type elementTy = vectorTy.getElementType();
234226
assert(elementTy.isIntOrFloat() &&
235227
"Expected int or float type in DPAS operands");
236-
LaneLayout layout({1, subgroupSize});
228+
LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
237229
// For B operand, data must be packed in minimum `packedSizeInBitsForDpasB` and
238230
// must have the VNNI format.
239-
if (operandNum == 1 &&
240-
elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
241-
LaneData data(
242-
{packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
231+
if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
232+
xegpu::targetinfo::packedSizeInBitsForDpasB) {
233+
LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB /
234+
elementTy.getIntOrFloatBitWidth(),
235+
1});
243236
return LayoutInfo(layout, data);
244237
}
245238
// Otherwise, return the default layout for the vector type.
@@ -577,7 +570,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
577570
ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
578571
if (tdescShape.size() > 1)
579572
assert(
580-
tdescShape[0] == subgroupSize &&
573+
tdescShape[0] == xegpu::targetinfo::subgroupSize &&
581574
"Expected the first dimension of 2D tensor descriptor to be equal to "
582575
"subgroup size.");
583576

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,6 @@ namespace xegpu {
5858

5959
using namespace mlir;
6060

61-
/// HW dependent constants.
62-
/// TODO: These constants should be queried from the target information.
63-
constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
64-
/// If DPAS A or B operands have low precision element types they must be packed
65-
/// according to the following sizes.
66-
constexpr unsigned packedSizeInBitsForDefault =
67-
16; // Minimum packing size per register for DPAS A.
68-
constexpr unsigned packedSizeInBitsForDpasB =
69-
32; // Minimum packing size per register for DPAS B.
7061
static const char *const resolveSIMTTypeMismatch =
7162
"resolve_simt_type_mismatch"; // Attribute name for identifying
7263
// UnrealizedConversionCastOp added to resolve
@@ -228,8 +219,9 @@ struct MoveFuncBodyToWarpExecuteOnLane0
228219
/** upperBound = **/ mlir::IntegerAttr());
229220
ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
230221
auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
231-
laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
232-
newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
222+
laneId.getLoc(), gpuFuncResultType, laneId,
223+
xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
224+
newGpuFunc.getArgumentTypes());
233225
Block &warpBodyBlock = warpOp.getBodyRegion().front();
234226
// Replace the ReturnOp of the original gpu function with a YieldOp.
235227
auto origRetunOp =

0 commit comments

Comments
 (0)