Skip to content

Commit f8dda5c

Browse files
committed
fix comments and add unit tests
1 parent ee3802d commit f8dda5c

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,21 @@ void XeGPUDialect::initialize() {
3131
>();
3232
}
3333

34+
// Checks if the given shape can be evenly distributed based on the layout
35+
// and data factors provided by the LayoutAttr.
3436
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
3537
xegpu::LayoutAttr attr) {
3638
assert(attr && "Layout attribute is missing.");
3739

38-
auto getSubShapeOrNull =
40+
// Checks whether the given shape can be evenly distributed using the specified
41+
// layout and data attributes. If successful, it returns the work size for each
42+
// compute unit; otherwise, it returns `std::nullopt`. The work size per compute
43+
// unit is calculated as follows:
44+
// - If `data` is null: newShape[i] = shape[i] / layout[i]
45+
// - If `data` is not null: newShape[i] = data[i]
46+
// When round-robin distribution (`use_rr`) is enabled, `shape[i]` can be smaller
47+
// than `layout[i] * data[i]`, allowing multiple compute units to share the data.
48+
auto tryDistribute =
3949
[&](llvm::ArrayRef<int64_t> shape, DenseI32ArrayAttr layout,
4050
DenseI32ArrayAttr data,
4151
bool use_rr = true) -> std::optional<SmallVector<int64_t>> {
@@ -68,20 +78,20 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
6878

6979
// check the sgLayout and sgData
7080
auto maybeSgShape =
71-
getSubShapeOrNull(shape, attr.getSgLayout(), attr.getSgData());
81+
tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
7282
if (!maybeSgShape)
7383
return false;
7484
auto sgShape = maybeSgShape.value();
7585

7686
// check InstData; it has neither a layout nor needs round-robin
7787
auto maybeInstShape =
78-
getSubShapeOrNull(sgShape, nullptr, attr.getInstData(), false);
88+
tryDistribute(sgShape, nullptr, attr.getInstData(), false);
7989
if (!maybeInstShape)
8090
return false;
8191
auto instShape = maybeInstShape.value();
8292

8393
// check LaneLayout and LaneData
84-
auto maybeLaneShape = getSubShapeOrNull(instShape, attr.getLaneLayout(),
94+
auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayout(),
8595
attr.getLaneData(), false);
8696
return maybeLaneShape.has_value();
8797
}

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,27 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
2929
return
3030
}
3131

32+
// -----
33+
func.func @test_create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
34+
// expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [24, 48]>}}
35+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [24, 48]>>
36+
return
37+
}
38+
39+
// -----
40+
func.func @test_create_nd_tdesc_subgroup_2(%src: memref<128x128xf32>) {
41+
// expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [24, 48]>}}
42+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [24, 48]>>
43+
return
44+
}
45+
46+
// -----
47+
func.func @test_create_nd_tdesc_subgroup_3(%src: memref<128x128xf32>) {
48+
// expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [64, 32]>}}
49+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [64, 32]>>
50+
return
51+
}
52+
3253
// -----
3354
func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
3455
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,27 @@ gpu.func @test_create_nd_tdesc_simt_6(%src: memref<24x32xf32>) {
9595
gpu.return
9696
}
9797

98+
// CHECK: gpu.func @test_create_nd_tdesc_subgroup_1(%[[arg0:.*]]: memref<128x128xf32>) {
99+
gpu.func @test_create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
100+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64]>>
101+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64]>>
102+
gpu.return
103+
}
104+
105+
// CHECK: gpu.func @test_create_nd_tdesc_subgroup_2(%[[arg0:.*]]: memref<128x128xf32>) {
106+
gpu.func @test_create_nd_tdesc_subgroup_2(%src: memref<128x128xf32>) {
107+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16]>>
108+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16]>>
109+
gpu.return
110+
}
111+
112+
// CHECK: gpu.func @test_create_nd_tdesc_subgroup_3(%[[arg0:.*]]: memref<128x128xf32>) {
113+
gpu.func @test_create_nd_tdesc_subgroup_3(%src: memref<128x128xf32>) {
114+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
115+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
116+
gpu.return
117+
}
118+
98119
// CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
99120
gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
100121
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>

0 commit comments

Comments
 (0)