4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -64,6 +64,10 @@ void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);

/// Appends patterns for XeGPU SIMT distribution into `patterns`.
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
/// Appends patterns for moving a function body into a gpu.warp_execute_on_lane_0 op.
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns);
/// Appends patterns for XeGPU workgroup to subgroup distribution into
/// `patterns`.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);

/// Collect a set of patterns to unroll xegpu operations into smaller shapes.
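For orientation, here is a hedged sketch of how a caller might stage these entry points; the helper name runXeGPUSimtPipeline is hypothetical, and the two-phase structure mirrors what XeGPUSubgroupDistributePass does further down this diff (first wrap each gpu.func body in a warp op, then distribute):

#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver (illustration only). Phase ordering matters: SIMT
// distribution assumes the gpu.warp_execute_on_lane_0 op already exists, so
// the move-func-body patterns must run to a fixpoint first.
static LogicalResult runXeGPUSimtPipeline(Operation *op) {
  MLIRContext *ctx = op->getContext();
  {
    RewritePatternSet patterns(ctx);
    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return failure();
  }
  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  return applyPatternsGreedily(op, std::move(patterns));
}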
10 changes: 7 additions & 3 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -195,8 +195,7 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
/// }
/// return %0
/// }
-struct MoveFuncBodyToWarpExecuteOnLane0
-    : public OpRewritePattern<gpu::GPUFuncOp> {
+struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
PatternRewriter &rewriter) const override {
@@ -1447,6 +1446,11 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
/*pattern benefit=*/highPatternBenefit);
}

void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
RewritePatternSet &patterns) {
patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
}

void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 1: Attach layouts to op operands.
// TODO: Following assumptions are made:
@@ -1473,7 +1477,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// gpu.warp_execute_on_lane_0 operation.
{
RewritePatternSet patterns(&getContext());
-    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
63 changes: 63 additions & 0 deletions mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir
@@ -0,0 +1,63 @@
// RUN: mlir-opt -test-xegpu-move-func-to-warp-op -split-input-file --allow-unregistered-dialect %s | FileCheck %s

gpu.module @test {
gpu.func @empty() {
gpu.return
}
}

// CHECK-LABEL: gpu.func @empty() {
// CHECK-NEXT: gpu.return
// CHECK-NEXT: }

// -----
gpu.module @test {
gpu.func @gemm(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
%2 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%3 = xegpu.load_nd %1[%c0, %c0] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
%5 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %4, %5[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}

// CHECK-LABEL: gpu.func @gemm(
// CHECK: %[[ARG0:[a-zA-Z0-9]+]]: memref<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<16x16xf16>,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<8x16xf32>) {
// CHECK: %[[LANEID:.*]] = gpu.lane_id
// CHECK-NEXT: gpu.warp_execute_on_lane_0(%[[LANEID]])[16]
// CHECK-SAME: args(%[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<8x16xf16>, memref<16x16xf16>, memref<8x16xf32>) {
// CHECK: ^bb0(%[[ARG3:[a-zA-Z0-9]+]]: memref<8x16xf16>, %[[ARG4:[a-zA-Z0-9]+]]: memref<16x16xf16>,
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: memref<8x16xf32>):
// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG3]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG4]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T1]][{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[T2]][{{.*}}] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
// CHECK-NEXT: %[[T5:.*]] = xegpu.dpas %[[T3]], %[[T4]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK-NEXT: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG5]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: xegpu.store_nd %[[T5]], %[[T6]][%{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: }
// CHECK-NEXT: gpu.return

// -----
gpu.module @test {
gpu.func @already_in_warp_op() {
%laneid = gpu.lane_id
gpu.warp_execute_on_lane_0(%laneid)[16] {
"some_op"() : () -> ()
gpu.yield
}
gpu.return
}
}

// CHECK-LABEL: gpu.func @already_in_warp_op() {
// CHECK: %[[LANEID:.*]] = gpu.lane_id
// CHECK: gpu.warp_execute_on_lane_0(%[[LANEID]])[16] {
// CHECK: "some_op"() : () -> ()
// CHECK: }
// CHECK: gpu.return
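The pattern implementation is collapsed in this diff, but the two negative tests above (@empty and @already_in_warp_op) pin down its guard conditions. A minimal sketch of those guards — hypothetical code, assuming the upstream pattern performs equivalent checks:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical helper (illustration only): true if MoveFuncBodyToWarpOp
// should rewrite `funcOp`.
static bool isMoveFuncBodyCandidate(gpu::GPUFuncOp funcOp) {
  auto ops = funcOp.getBody().getOps();
  // A body holding only its gpu.return has nothing to move; the @empty test
  // expects such functions to be left untouched.
  if (llvm::hasSingleElement(ops) && isa<gpu::ReturnOp>(*ops.begin()))
    return false;
  // If a warp op is already present, rewriting again would keep nesting warp
  // regions; the @already_in_warp_op test pins this down.
  if (llvm::any_of(ops, [](Operation &op) {
        return isa<gpu::WarpExecuteOnLane0Op>(op);
      }))
    return false;
  return true;
}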
83 changes: 81 additions & 2 deletions mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -530,7 +530,7 @@ gpu.module @xevm_module{
// CHECK-NEXT: }
// CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
gpu.module @xevm_module{
-  gpu.func @vector_transpose(%arg0: memref<2x16xf32>, %laneid: index) {
+  gpu.func @vector_transpose(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
@@ -556,7 +556,7 @@ gpu.module @xevm_module{
// CHECK: }
// CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
gpu.module @xevm_module{
-  gpu.func @vector_bitcast(%arg0: memref<4x16xi16>, %laneid: index) {
+  gpu.func @vector_bitcast(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
@@ -573,3 +573,82 @@ gpu.module @xevm_module{
gpu.return
}
}

// -----
// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
// CHECK: gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32>
// CHECK: }
// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32>
gpu.module @xevm_module {
gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
: () -> (vector<16xf32>)
%cast = vector.shape_cast %cst
{
layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16xf32> to vector<1x16xf32>
gpu.yield %cast : vector<1x16xf32>
}
"some_user_op"(%r) : (vector<1x1xf32>) -> ()
gpu.return
}
}

// -----
// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing(
// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) {
// CHECK: gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32>
// CHECK: }
// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32>
gpu.module @xevm_module {
gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
: () -> (vector<1x16xf32>)
%cast = vector.shape_cast %cst
{
layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
}
: vector<1x16xf32> to vector<16xf32>
gpu.yield %cast : vector<16xf32>
}
"some_user_op"(%r) : (vector<1xf32>) -> ()
gpu.return
}
}

// -----
// NOTE: The layouts are valid, but the distribution pattern requires a #xegpu.slice layout for the operand.
//
// CHECK-LABEL: gpu.func @vector_shapecast_unsupported
// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>) {
// CHECK: %[[T1:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
// CHECK: gpu.yield %[[T1]] : vector<1x16xf32>
// CHECK: }
// CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
// CHECK: gpu.return
gpu.module @xevm_module {
gpu.func @vector_shapecast_unsupported(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
: () -> (vector<16xf32>)
%cast = vector.shape_cast %cst
{
layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16xf32> to vector<1x16xf32>
gpu.yield %cast : vector<1x16xf32>
}
"some_user_op"(%r) : (vector<1x1xf32>) -> ()
gpu.return
}
}
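Contrasting @vector_shapecast_rank_increasing with @vector_shapecast_unsupported suggests the rule being tested: the lower-rank side of a rank-changing shape_cast must carry a #xegpu.slice layout derived from the higher-rank side's layout. A hedged sketch of that check — hypothetical helper, not the code under test, assuming the slice attribute's C++ class is xegpu::SliceAttr:

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "llvm/Support/Casting.h"

using namespace mlir;

// Hypothetical helper (illustration only): a plain #xegpu.layout on the
// lower-rank value (as in @vector_shapecast_unsupported) is rejected; only a
// #xegpu.slice layout relates it to the higher-rank side.
static bool hasSliceLayoutOnLowerRankSide(Attribute lowerRankLayout) {
  return isa_and_nonnull<xegpu::SliceAttr>(lowerRankLayout);
}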
32 changes: 32 additions & 0 deletions mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -247,6 +248,36 @@ struct TestXeGPUSGDistribute
}
};

struct TestXeGPUMoveFuncBodyToWarpOp
: public PassWrapper<TestXeGPUMoveFuncBodyToWarpOp,
OperationPass<gpu::GPUModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUMoveFuncBodyToWarpOp)

StringRef getArgument() const final {
return "test-xegpu-move-func-to-warp-op";
}

StringRef getDescription() const final {
return "Test the implementation of XeGPU move gpu function body to "
"WarpExecuteOnLane0 op.";
}

void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<xegpu::XeGPUDialect>();
registry.insert<gpu::GPUDialect>();
}

TestXeGPUMoveFuncBodyToWarpOp() = default;
TestXeGPUMoveFuncBodyToWarpOp(const TestXeGPUMoveFuncBodyToWarpOp &pass) =
default;

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

struct TestXeGPULayoutInterface
: public PassWrapper<TestXeGPULayoutInterface,
OperationPass<gpu::GPUModuleOp>> {
@@ -312,6 +343,7 @@ void registerTestXeGPULowerings() {
PassRegistration<TestXeGPUUnrollingPatterns>();
PassRegistration<TestXeGPULayoutInterface>();
PassRegistration<TestXeGPUSGDistribute>();
PassRegistration<TestXeGPUMoveFuncBodyToWarpOp>();
}
} // namespace test
} // namespace mlir