Skip to content

Commit 1b16552

Browse files
committed
address comments
1 parent 59de450 commit 1b16552

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,30 +39,32 @@ void XeGPUDialect::initialize() {
3939

4040
/// Generates instructions to compute offsets for a subgroup identified by
4141
/// its multidimensional indices (sgId), using the specified subgroup layout
42-
/// (sgLayout), subgroup data dimensions (sgShape), and the overall data
43-
/// dimensions (shape).
42+
/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
43+
/// dimensions (sizePerWg).
4444
static SmallVector<SmallVector<Value>>
45-
genOffsetsComputations(OpBuilder &builder, Location loc,
46-
SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
47-
ArrayRef<int64_t> sgShape, ArrayRef<int64_t> shape) {
45+
genOffsetsComputingInsts(OpBuilder &builder, Location loc,
46+
SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
47+
ArrayRef<int64_t> sizePerSg,
48+
ArrayRef<int64_t> sizePerWg) {
4849

4950
SmallVector<SmallVector<Value>> offsets;
5051

51-
// nd local offset, localOffset[i] = sgId[i] * sgShape[i]
52+
// nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
5253
SmallVector<Value> localOffsets = llvm::map_to_vector(
53-
llvm::zip(sgId, sgShape), [&](const auto &t) -> Value {
54+
llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
5455
return builder.createOrFold<index::MulOp>(
5556
loc, std::get<0>(t),
5657
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
5758
});
5859

59-
// distUnit[i] is the minimum value between shape[i] and
60-
// sgLayout[i] * sgShape[i]
60+
// distUnit[i] is the minimum value between sizePerWg[i] and
61+
// sgLayout[i] * sizePerSg[i]
6162
SmallVector<int64_t> distUnit = llvm::map_to_vector(
62-
llvm::zip_equal(shape, computeElementwiseMul(sgLayout, sgShape)),
63+
llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
6364
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
6465

65-
for (SmallVector<int64_t> unitOffs : StaticTileOffsetRange(shape, distUnit)) {
66+
for (SmallVector<int64_t> unitOffs :
67+
StaticTileOffsetRange(sizePerWg, distUnit)) {
6668
SmallVector<Value> base =
6769
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
6870
return builder.create<arith::ConstantIndexOp>(loc, d);
@@ -75,7 +77,7 @@ genOffsetsComputations(OpBuilder &builder, Location loc,
7577
});
7678

7779
SmallVector<Value> mods = llvm::map_to_vector(
78-
llvm::zip_equal(adds, shape), [&](const auto &t) -> Value {
80+
llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
7981
return builder.createOrFold<index::RemUOp>(
8082
loc, std::get<0>(t),
8183
builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t)));
@@ -300,8 +302,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
300302
SmallVector<int64_t> sgShape;
301303
if (auto maybeSgShape = getSgDataAsInt())
302304
sgShape = maybeSgShape.value();
303-
else if (auto ratio = computeShapeRatio(shape, sgLayout))
304-
sgShape = ratio.value();
305+
else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
306+
sgShape = derivedShape.value();
305307
else
306308
return failure();
307309

@@ -311,7 +313,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
311313
return failure();
312314
SmallVector<Value> sgIds = *maybeIds;
313315

314-
return genOffsetsComputations(builder, loc, sgIds, sgLayout, sgShape, shape);
316+
return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
317+
shape);
315318
}
316319

317320
//===----------------------------------------------------------------------===//
@@ -401,7 +404,8 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
401404
SmallVector<Value> sgIds =
402405
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
403406

404-
return genOffsetsComputations(builder, loc, sgIds, sgLayout, sgShape, shape);
407+
return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
408+
shape);
405409
}
406410

407411
//===----------------------------------------------------------------------===//

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
22

3+
#map = affine_map<()[s0] -> (s0 floordiv 4)>
4+
#map1 = affine_map<()[s0] -> (s0 mod 4)>
5+
36
gpu.module @test_round_robin_assignment {
47
// CHECK-LABEL: create_nd_tdesc
58
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -12,6 +15,30 @@ gpu.module @test_round_robin_assignment {
1215
gpu.return
1316
}
1417

18+
// CHECK-LABEL: create_nd_tdesc_with_shared_data
19+
// CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
20+
gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
21+
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
22+
//CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
23+
//CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
24+
//CHECK: [[C16:%.+]] = arith.constant 16 : index
25+
//CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
26+
//CHECK: [[C64:%.+]] = arith.constant 64 : index
27+
//CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
28+
//CHECK: [[C0:%.+]] = arith.constant 0 : index
29+
//CHECK: [[C0_1:%.+]] = arith.constant 0 : index
30+
//CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
31+
//CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
32+
//CHECK: [[C128:%.+]] = arith.constant 128 : index
33+
//CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
34+
//CHECK: [[C128_2:%.+]] = arith.constant 128 : index
35+
//CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C128_2]]
36+
//CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
37+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
38+
-> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
39+
gpu.return
40+
}
41+
1542
// CHECK-LABEL: load_nd_tdesc
1643
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
1744
gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {

0 commit comments

Comments
 (0)