Skip to content

Commit f4c0c40

Browse files
authored
[mlir][xegpu] XeGPU alias ops folder pass (#88886)
Adds a pass that folds aliasing ops into XeGPU ops.
1 parent 78dca4a commit f4c0c40

File tree

11 files changed

+217
-0
lines changed

11 files changed

+217
-0
lines changed

mlir/docs/Passes.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,7 @@ This document describes the available MLIR passes and their contracts.
119119
## TOSA Dialect Passes
120120

121121
[include "TosaPasses.md"]
122+
123+
## XeGPU Dialect Passes
124+
125+
[include "XeGPUPasses.md"]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name XeGPU)
3+
add_public_tablegen_target(MLIRXeGPUPassIncGen)
4+
add_dependencies(mlir-headers MLIRXeGPUPassIncGen)
5+
6+
add_mlir_doc(Passes XeGPUPasses ./ -gen-pass-doc)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===- Passes.h - XeGPU Patterns and Passes ---------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_H
10+
#define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_H
11+
12+
#include "mlir/Pass/Pass.h"
13+
14+
namespace mlir {
15+
16+
namespace xegpu {
17+
18+
//===----------------------------------------------------------------------===//
19+
// Passes
20+
//===----------------------------------------------------------------------===//
21+
22+
#define GEN_PASS_DECL
23+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
24+
25+
//===----------------------------------------------------------------------===//
26+
// Registration
27+
//===----------------------------------------------------------------------===//
28+
29+
#define GEN_PASS_REGISTRATION
30+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
31+
32+
} // namespace xegpu
33+
} // namespace mlir
34+
35+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_H
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===-- Passes.td - XeGPU transformation definition file ---*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
10+
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
11+
#define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
12+
13+
include "mlir/Pass/PassBase.td"
14+
15+
def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
16+
let summary = "Fold alias ops into XeGPU ops";
17+
let description = [{
18+
The pass folds aliasing ops into XeGPU ops that they operate on the original
19+
source references.
20+
}];
21+
let dependentDialects = [
22+
"memref::MemRefDialect", "xegpu::XeGPUDialect"
23+
];
24+
}
25+
26+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- Transforms.h - XeGPU Dialect transformations -------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
10+
#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
11+
12+
namespace mlir {
13+
class RewritePatternSet;
14+
15+
namespace xegpu {
16+
17+
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
18+
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
19+
20+
} // namespace xegpu
21+
} // namespace mlir
22+
23+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H

mlir/include/mlir/InitAllPasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
4646
#include "mlir/Dialect/Transform/Transforms/Passes.h"
4747
#include "mlir/Dialect/Vector/Transforms/Passes.h"
48+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
4849
#include "mlir/Transforms/Passes.h"
4950

5051
#include <cstdlib>
@@ -92,6 +93,7 @@ inline void registerAllPasses() {
9293
arm_sme::registerArmSMEPasses();
9394
arm_sve::registerArmSVEPasses();
9495
emitc::registerEmitCPasses();
96+
xegpu::registerXeGPUPasses();
9597

9698
// Dialect pipelines
9799
bufferization::registerBufferizationPipelines();
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_dialect_library(MLIRXeGPUTransforms
2+
XeGPUFoldAliasOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
6+
7+
DEPENDS
8+
MLIRXeGPUPassIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRAffineUtils
12+
MLIRIR
13+
MLIRMemRefDialect
14+
MLIRXeGPUDialect
15+
MLIRPass
16+
MLIRTransforms
17+
)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===- XeGPUFoldAliasOps.cpp - XeGPU alias ops folders ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
12+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
13+
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
14+
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
#include "llvm/Support/Debug.h"
18+
19+
namespace mlir {
20+
namespace xegpu {
21+
#define GEN_PASS_DEF_XEGPUFOLDALIASOPS
22+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
23+
} // namespace xegpu
24+
} // namespace mlir
25+
26+
#define DEBUG_TYPE "xegpu-fold-alias-ops"
27+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
28+
29+
using namespace mlir;
30+
31+
namespace {
32+
/// Merges subview operation with xegpu.create_nd_tdesc operation.
33+
class XegpuCreateNdDescOpSubViewOpFolder final
34+
: public OpRewritePattern<xegpu::CreateNdDescOp> {
35+
public:
36+
using OpRewritePattern<xegpu::CreateNdDescOp>::OpRewritePattern;
37+
38+
LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp,
39+
PatternRewriter &rewriter) const override;
40+
};
41+
} // namespace
42+
43+
LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
44+
xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const {
45+
auto subViewOp = descOp.getSource().getDefiningOp<memref::SubViewOp>();
46+
47+
if (!subViewOp)
48+
return rewriter.notifyMatchFailure(descOp, "not a subview producer");
49+
if (!subViewOp.hasUnitStride())
50+
return rewriter.notifyMatchFailure(descOp, "requires unit strides");
51+
52+
SmallVector<Value> resolvedOffsets;
53+
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
54+
rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(),
55+
subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
56+
descOp.getMixedOffsets(), resolvedOffsets);
57+
58+
rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
59+
descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(),
60+
getAsOpFoldResult(resolvedOffsets));
61+
62+
return success();
63+
}
64+
65+
void xegpu::populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns) {
66+
patterns.add<XegpuCreateNdDescOpSubViewOpFolder>(patterns.getContext());
67+
}
68+
69+
namespace {
70+
71+
struct XeGPUFoldAliasOpsPass final
72+
: public xegpu::impl::XeGPUFoldAliasOpsBase<XeGPUFoldAliasOpsPass> {
73+
void runOnOperation() override;
74+
};
75+
76+
} // namespace
77+
78+
void XeGPUFoldAliasOpsPass::runOnOperation() {
79+
RewritePatternSet patterns(&getContext());
80+
xegpu::populateXeGPUFoldAliasOpsPatterns(patterns);
81+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
82+
}

0 commit comments

Comments
 (0)