Skip to content

Commit 3f73fda

Browse files
committed
clean up
1 parent e8b43fb commit 3f73fda

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
//===---- XeGPUInstructionlize.cpp -- XeGPU Instructionlize Pass ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
12+
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
13+
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
14+
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
15+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
16+
#include "mlir/Pass/Pass.h"
17+
#include "mlir/Pass/PassManager.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
namespace mlir {
21+
namespace xegpu {
22+
#define GEN_PASS_DEF_XEGPUINSTRUCTIONLIZE
23+
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
24+
} // namespace xegpu
25+
} // namespace mlir
26+
27+
#define DEBUG_TYPE "xegpu-instructionlize"
28+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
29+
30+
using namespace mlir;
31+
32+
namespace {
33+
34+
/// Unroll XeGPU ops to their instruction-level representation.
35+
class XeGPUInstructionlizePass final
36+
: public xegpu::impl::XeGPUInstructionlizeBase<XeGPUInstructionlizePass> {
37+
public:
38+
void runOnOperation() override;
39+
40+
private:
41+
SmallVector<int64_t> getTileShape(TypedValue<ShapedType> value) const;
42+
std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
43+
bool needsUnroll(Operation *op) const;
44+
};
45+
} // namespace
46+
47+
/// Returns the tile shape of `value`.
///
/// When the layout attribute attached to `value` exists, satisfies
/// `isSgLayout()`, and provides `inst_data`, that `inst_data` is the tile
/// shape. In every other case the full shape of the value's type is returned.
///
/// \param value a non-null shaped SSA value (asserted).
SmallVector<int64_t>
XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
  assert(value && "value must be non-null");

  xegpu::LayoutAttr layoutAttr = xegpu::getLayoutAttr(value);

  // No usable layout: the value is its own (single) tile.
  if (!layoutAttr || !layoutAttr.isSgLayout())
    return llvm::to_vector(value.getType().getShape());

  // A layout without inst_data also falls back to the full shape.
  auto instData = layoutAttr.getInstData();
  if (!instData)
    return llvm::to_vector(value.getType().getShape());

  return llvm::to_vector_of<int64_t>(instData.asArrayRef());
}
57+
58+
/// Returns the tile shape to unroll `op` with, or std::nullopt when `op` is
/// not handled or its tiles are inconsistent.
///
/// The shape is derived per op kind:
///   * CreateNdDescOp / UpdateNdOffsetOp: tile shape of result 0;
///   * PrefetchNdOp / LoadNdOp:           tile shape of operand 0;
///   * StoreNdOp:                         tile shape of operand 1;
///   * DpasOp: {M, K, N} built from the tiles of operand 0 ({M, K}) and
///     operand 1 ({K, N}); when a third operand is present its tile must be
///     exactly {M, N}.
std::optional<SmallVector<int64_t>>
XeGPUInstructionlizePass::getTileShape(Operation *op) const {
  // View an SSA value as a shaped value; `cast` asserts on anything else.
  auto asShaped = [](Value v) { return cast<TypedValue<ShapedType>>(v); };

  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
    return getTileShape(asShaped(op->getResult(0)));

  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
    return getTileShape(asShaped(op->getOperand(0)));

  if (isa<xegpu::StoreNdOp>(op))
    return getTileShape(asShaped(op->getOperand(1)));

  // Everything that is not a DPAS at this point is unsupported.
  if (!isa<xegpu::DpasOp>(op))
    return std::nullopt;

  SmallVector<int64_t> lhsTile = getTileShape(asShaped(op->getOperand(0)));
  SmallVector<int64_t> rhsTile = getTileShape(asShaped(op->getOperand(1)));

  // Both multiplicands must be 2-D and agree on the reduction dimension.
  // The rank test short-circuits, so the indexing below is always in bounds.
  const bool bothRankTwo = lhsTile.size() == 2 && rhsTile.size() == 2;
  if (!bothRankTwo || lhsTile[1] != rhsTile[0])
    return std::nullopt;

  const int64_t m = lhsTile[0];
  const int64_t k = lhsTile[1];
  const int64_t n = rhsTile[1];

  // An optional third operand must be tiled as {M, N}.
  if (op->getNumOperands() == 3) {
    SmallVector<int64_t> accTile = getTileShape(asShaped(op->getOperand(2)));
    const int64_t expectedAccTile[2] = {m, n};
    if (!llvm::equal(accTile, expectedAccTile))
      return std::nullopt;
  }

  return SmallVector<int64_t>({m, k, n});
}
93+
94+
/// Decides whether `op` should be handed to the unroll patterns.
///
/// Each shaped operand, then each shaped result, is compared against its tile
/// shape (see getTileShape(TypedValue<ShapedType>)). The first value that
/// yields a verdict decides:
///   * tile rank differs from the value's rank -> false (do not unroll);
///   * same rank but tile differs from the full shape -> true (unroll).
/// Values that are not shaped, or whose tile equals their full shape, are
/// skipped. If no value decides, the result is false.
///
/// Results are inspected in addition to operands because
/// getTileShape(Operation *) derives the tile shape of CreateNdDescOp from
/// its result 0; an op whose tiling information lives only on its result
/// would otherwise always be filtered out.
bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
  // Returns the verdict contributed by one value, or std::nullopt when the
  // value does not decide and scanning should continue.
  auto classify = [this](Value candidate) -> std::optional<bool> {
    auto shaped = dyn_cast<TypedValue<ShapedType>>(candidate);
    if (!shaped)
      return std::nullopt;

    SmallVector<int64_t> tileShape = getTileShape(shaped);
    // The tile must have the same rank as the original type.
    if (tileShape.size() != static_cast<size_t>(shaped.getType().getRank()))
      return false;
    if (!llvm::equal(tileShape, shaped.getType().getShape()))
      return true;
    return std::nullopt;
  };

  for (Value operand : op->getOperands())
    if (std::optional<bool> verdict = classify(operand))
      return *verdict;

  for (Value result : op->getResults())
    if (std::optional<bool> verdict = classify(result))
      return *verdict;

  return false;
}
107+
108+
/// Pass entry point: configures the XeGPU unroll options and greedily applies
/// the XeGPU unroll patterns to the operation this pass runs on.
///
/// The options are wired as follows:
///   * filter constraint: an op is accepted iff needsUnroll(op) is true;
///   * native shape:      getTileShape(op);
///   * unrolled types:    for a type and a tile shape, one tile-shaped type
///     per tile, i.e. product(shape / tileShape) copies of the same type.
void XeGPUInstructionlizePass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  xegpu::UnrollOptions options;

  // The lambdas capture by reference; `options` is only used inside this
  // function, so the captures cannot outlive `this` or `ctx`.
  options.setFilterConstraint([&](Operation *op) -> LogicalResult {
    return needsUnroll(op) ? success() : failure();
  });

  options.setNativeShapeFn(
      [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
        return getTileShape(op);
      });

  options.setUnrolledTypesFn(
      [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
        Type elemTy = type.getElementType();
        Type newTy;

        if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
          // The layout may be absent (getTileShape() null-checks it as well).
          // Only strip inst_data from a layout that actually exists; calling
          // dropInstData() on a null attribute would dereference null.
          auto layout = tdescTy.getLayoutAttr();
          if (layout)
            layout = layout.dropInstData();
          newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy,
                                             tdescTy.getEncoding(), layout);
        } else {
          newTy = type.clone(tileShape, elemTy);
        }

        std::optional<SmallVector<int64_t>> ratio =
            computeShapeRatio(type.getShape(), tileShape);
        assert(ratio &&
               "The shape of the type must be a multiple of tileShape.");
        // One unrolled type per tile.
        return SmallVector<Type>(computeProduct(*ratio), newTy);
      });

  RewritePatternSet patterns(ctx);

  populateXeGPUUnrollPatterns(patterns, options);
  // The result of the greedy driver is intentionally discarded.
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

0 commit comments

Comments
 (0)