Commit bd6da1f
[mlir][xegpu] Add more tests in XeGPU subgroup distribution. (#162543)
This PR adds tests covering some useful corner cases: (1) more tests for `vector.shape_cast` distribution, and (2) tests for the `MoveFuncBodyToWarpOp` pattern, which could not be exercised in isolation before this change.
1 parent a47cb9b commit bd6da1f
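For context, the `MoveFuncBodyToWarpOp` pattern (renamed in this commit from `MoveFuncBodyToWarpExecuteOnLane0`) wraps the body of a `gpu.func` inside a `gpu.warp_execute_on_lane_0` region so that the SIMT distribution patterns can later distribute that region lane by lane. A minimal before/after sketch, distilled from the new `@gemm` test below; the function name and body are illustrative, not verbatim pass output:

// Before: a plain gpu.func body.
gpu.func @example(%arg0: memref<8x16xf16>) {
  %0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
  gpu.return
}

// After: the body is moved into a warp_execute_on_lane_0 region whose
// entry block receives the forwarded function arguments.
gpu.func @example(%arg0: memref<8x16xf16>) {
  %laneid = gpu.lane_id
  gpu.warp_execute_on_lane_0(%laneid)[16] args(%arg0 : memref<8x16xf16>) {
  ^bb0(%arg1: memref<8x16xf16>):
    %0 = xegpu.create_nd_tdesc %arg1 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
    gpu.yield
  }
  gpu.return
}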

File tree: 5 files changed, +186 −5 lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
@@ -64,6 +64,10 @@ void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
+/// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
+void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns);
+/// Appends patterns for XeGPU workgroup to subgroup distribution into
+/// `patterns`.
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
 
 /// Collect a set of patterns to unroll xegpu operations to a smaller shapes.

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 7 additions & 3 deletions
@@ -195,8 +195,7 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
 /// }
 /// return %0
 /// }
-struct MoveFuncBodyToWarpExecuteOnLane0
-    : public OpRewritePattern<gpu::GPUFuncOp> {
+struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
   using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                 PatternRewriter &rewriter) const override {
@@ -1447,6 +1446,11 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                /*pattern benefit=*/highPatternBenefit);
 }
 
+void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
+}
+
 void XeGPUSubgroupDistributePass::runOnOperation() {
   // Step 1: Attach layouts to op operands.
   // TODO: Following assumptions are made:
@@ -1473,7 +1477,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // gpu.warp_execute_on_lane_0 operation.
   {
     RewritePatternSet patterns(&getContext());
-    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
 
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       signalPassFailure();
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+// RUN: mlir-opt -test-xegpu-move-func-to-warp-op -split-input-file --allow-unregistered-dialect %s | FileCheck %s
+
+gpu.module @test {
+  gpu.func @empty() {
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: gpu.func @empty() {
+// CHECK-NEXT: gpu.return
+// CHECK-NEXT: }
+
+// -----
+gpu.module @test {
+  gpu.func @gemm(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+    %1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+    %2 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+    %3 = xegpu.load_nd %1[%c0, %c0] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %5 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %4, %5[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: gpu.func @gemm(
+// CHECK: %[[ARG0:[a-zA-Z0-9]+]]: memref<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<16x16xf16>,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<8x16xf32>) {
+// CHECK: %[[LANEID:.*]] = gpu.lane_id
+// CHECK-NEXT: gpu.warp_execute_on_lane_0(%[[LANEID]])[16]
+// CHECK-SAME: args(%[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<8x16xf16>, memref<16x16xf16>, memref<8x16xf32>) {
+// CHECK: ^bb0(%[[ARG3:[a-zA-Z0-9]+]]: memref<8x16xf16>, %[[ARG4:[a-zA-Z0-9]+]]: memref<16x16xf16>,
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: memref<8x16xf32>):
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG3]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG4]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T1]][{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[T2]][{{.*}}] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: %[[T5:.*]] = xegpu.dpas %[[T3]], %[[T4]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG5]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T5]], %[[T6]][%{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: gpu.return
+
+// -----
+gpu.module @test {
+  gpu.func @already_in_warp_op() {
+    %laneid = gpu.lane_id
+    gpu.warp_execute_on_lane_0(%laneid)[16] {
+      "some_op"() : () -> ()
+      gpu.yield
+    }
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: gpu.func @already_in_warp_op() {
+// CHECK: %[[LANEID:.*]] = gpu.lane_id
+// CHECK: gpu.warp_execute_on_lane_0(%[[LANEID]])[16] {
+// CHECK: "some_op"() : () -> ()
+// CHECK: }
+// CHECK: gpu.return

mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir

Lines changed: 81 additions & 2 deletions
@@ -530,7 +530,7 @@ gpu.module @xevm_module{
 // CHECK-NEXT: }
 // CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
 gpu.module @xevm_module{
-  gpu.func @vector_transpose(%arg0: memref<2x16xf32>, %laneid: index) {
+  gpu.func @vector_transpose(%laneid: index) {
     %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
       %cst = "some_op"()
         {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
@@ -556,7 +556,7 @@ gpu.module @xevm_module{
 // CHECK: }
 // CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
 gpu.module @xevm_module{
-  gpu.func @vector_bitcast(%arg0: memref<4x16xi16>, %laneid: index) {
+  gpu.func @vector_bitcast(%laneid: index) {
     %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
       %cst = "some_op"()
         {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
@@ -573,3 +573,82 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
+// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
+// CHECK: gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32>
+// CHECK: }
+// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32>
+gpu.module @xevm_module {
+  gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
+    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+      %cst = "some_op"()
+        {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+        : () -> (vector<16xf32>)
+      %cast = vector.shape_cast %cst
+        {
+          layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+        }
+        : vector<16xf32> to vector<1x16xf32>
+      gpu.yield %cast : vector<1x16xf32>
+    }
+    "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing(
+// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) {
+// CHECK: gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32>
+// CHECK: }
+// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32>
+gpu.module @xevm_module {
+  gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
+    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
+      %cst = "some_op"()
+        {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+        : () -> (vector<1x16xf32>)
+      %cast = vector.shape_cast %cst
+        {
+          layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+          layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+        }
+        : vector<1x16xf32> to vector<16xf32>
+      gpu.yield %cast : vector<16xf32>
+    }
+    "some_user_op"(%r) : (vector<1xf32>) -> ()
+    gpu.return
+  }
+}
+
+// -----
+// NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
+//
+// CHECK-LABEL: gpu.func @vector_shapecast_unsupported
+// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>) {
+// CHECK: %[[T1:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
+// CHECK: gpu.yield %[[T1]] : vector<1x16xf32>
+// CHECK: }
+// CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
+// CHECK: gpu.return
+gpu.module @xevm_module {
+  gpu.func @vector_shapecast_unsupported(%laneid: index) {
+    %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+      %cst = "some_op"()
+        {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+        : () -> (vector<16xf32>)
+      %cast = vector.shape_cast %cst
+        {
+          layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+          layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+        }
+        : vector<16xf32> to vector<1x16xf32>
+      gpu.yield %cast : vector<1x16xf32>
+    }
+    "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+    gpu.return
+  }
+}
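For reference, the distributed types in these shape_cast tests follow directly from the layouts: with `lane_layout = [1, 16]` a `vector<1x16xf32>` value is split across the 16 lanes into `vector<1x1xf32>` fragments, and the `#xegpu.slice<..., dims = [0]>` layout distributes `vector<16xf32>` into `vector<1xf32>` per lane. The distribution then re-creates the `shape_cast` outside the warp op on those per-lane types, as a sketch paraphrasing the CHECK lines of the rank-increasing test shows:

%w:2 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>, vector<1xf32>) {
  ...
  gpu.yield %cast, %cst : vector<1x16xf32>, vector<16xf32>
}
// Per-lane shape_cast on the distributed fragments:
%r = vector.shape_cast %w#1 : vector<1xf32> to vector<1x1xf32>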

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 31 additions & 0 deletions
@@ -247,6 +247,36 @@ struct TestXeGPUSGDistribute
   }
 };
 
+struct TestXeGPUMoveFuncBodyToWarpOp
+    : public PassWrapper<TestXeGPUMoveFuncBodyToWarpOp,
+                         OperationPass<gpu::GPUModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUMoveFuncBodyToWarpOp)
+
+  StringRef getArgument() const final {
+    return "test-xegpu-move-func-to-warp-op";
+  }
+
+  StringRef getDescription() const final {
+    return "Test the implementation of XeGPU move gpu function body to "
+           "WarpExecuteOnLane0 op.";
+  }
+
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<gpu::GPUDialect>();
+  }
+
+  TestXeGPUMoveFuncBodyToWarpOp() = default;
+  TestXeGPUMoveFuncBodyToWarpOp(const TestXeGPUMoveFuncBodyToWarpOp &pass) =
+      default;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestXeGPULayoutInterface
     : public PassWrapper<TestXeGPULayoutInterface,
                          OperationPass<gpu::GPUModuleOp>> {
@@ -312,6 +342,7 @@ void registerTestXeGPULowerings() {
   PassRegistration<TestXeGPUUnrollingPatterns>();
   PassRegistration<TestXeGPULayoutInterface>();
   PassRegistration<TestXeGPUSGDistribute>();
+  PassRegistration<TestXeGPUMoveFuncBodyToWarpOp>();
 }
 } // namespace test
 } // namespace mlir
