Skip to content

Commit 9d0341d

Browse files
committed
add pass
1 parent 9c42b3e commit 9d0341d

File tree

4 files changed

+82
-1
lines changed

4 files changed

+82
-1
lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,14 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
8080
"scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
8181
}
8282

83+
def XeGPUOptimizeTranspose : Pass<"xegpu-optimize-transpose"> {
84+
let summary = "Optimize XeGPU loadNd operations feeding into vector.transpose";
85+
let description = [{
86+
This pass rewrites XeGPU loadNd operations that feed into vector.transpose
87+
into more optimal forms to improve performance.
88+
}];
89+
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
90+
"vector::VectorDialect"];
91+
}
92+
8393
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ struct UnrollOptions {
6161

6262
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
6363
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
64-
64+
/// Appends patterns for optimizing transpose operations into `patterns`.
65+
void populateXeGPUOptimizeTransposePatterns(RewritePatternSet &patterns);
6566
/// Appends patterns for XeGPU SIMT distribution into `patterns`.
6667
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
6768
/// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.

mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
66
XeGPUWgToSgDistribute.cpp
77
XeGPUPropagateLayout.cpp
88
XeGPUVectorLinearize.cpp
9+
XeGPUOptimizeTranspose.cpp
910

1011
ADDITIONAL_HEADER_DIRS
1112
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===- XeGPUOptimizeTranspose.cpp - XeGPU optimize transpose ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
10+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
11+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
12+
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
13+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
14+
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
15+
#include "mlir/Transforms/DialectConversion.h"
16+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
18+
namespace mlir {
19+
namespace xegpu {
20+
#define GEN_PASS_DEF_XEGPUOPTIMIZETRANSPOSE
21+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
22+
} // namespace xegpu
23+
} // namespace mlir
24+
25+
#define DEBUG_TYPE "xegpu-optimize-transpose"
26+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27+
28+
using namespace mlir;
29+
30+
namespace {
31+
32+
class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
33+
public:
34+
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
35+
LogicalResult
36+
matchAndRewrite(xegpu::LoadNdOp loadOp, OpAdaptor adaptor,
37+
ConversionPatternRewriter &rewriter) const override {
38+
return success();
39+
}
40+
};
41+
} // namespace
42+
43+
void xegpu::populateXeGPUOptimizeTransposePatterns(
44+
RewritePatternSet &patterns) {
45+
patterns.add<XeGPULoadNdPattern>(patterns.getContext());
46+
}
47+
48+
namespace {
49+
50+
struct XeGPUOptimizeTransposePass final
51+
: public xegpu::impl::XeGPUOptimizeTransposeBase<
52+
XeGPUOptimizeTransposePass> {
53+
void runOnOperation() override {
54+
MLIRContext &context = getContext();
55+
TypeConverter converter;
56+
RewritePatternSet patterns(&context);
57+
ConversionTarget target(context);
58+
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
59+
target);
60+
xegpu::populateXeGPUOptimizeTransposePatterns(patterns);
61+
if (failed(applyPartialConversion(getOperation(), target,
62+
std::move(patterns)))) {
63+
DBGS() << "Optimize transpose pass failed.\n";
64+
return signalPassFailure();
65+
}
66+
}
67+
};
68+
69+
} // namespace

0 commit comments

Comments
 (0)