Skip to content

Commit 43c35be

Browse files
committed
add some tests
1 parent 76f7323 commit 43c35be

File tree

2 files changed

+264
-43
lines changed

2 files changed

+264
-43
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp

Lines changed: 155 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,19 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Arith/IR/Arith.h"
910
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1011
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
12+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1113
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1214
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1315
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
1416
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
1517
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
18+
#include "mlir/IR/OpDefinition.h"
19+
#include "mlir/IR/Types.h"
20+
#include "mlir/IR/Value.h"
21+
#include "mlir/Support/LLVM.h"
1622
#include "mlir/Transforms/DialectConversion.h"
1723
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1824
#include <optional>
@@ -42,6 +48,46 @@ get2DLaneData(xegpu::TensorDescType tdescType) {
4248
return laneData;
4349
}
4450

51+
/// Computes the tensor descriptor type that should be used in place of
/// `tdescType`.
///
/// If the 2D lane data of `tdescType` is [1, N] with N != 1, the result
/// describes the same data with N inner elements folded into one wider
/// integer element:
///   * the innermost dimension of the shape is divided by N,
///   * the element type becomes an integer type of N times the original
///     bitwidth,
///   * the layout keeps the original lane layout but uses lane data [1, 1],
///   * array length, boundary check and memory space are carried over.
///
/// In every other case (including when no 2D lane data is available) the
/// input type is returned unchanged.
static xegpu::TensorDescType
getModifiedTensorDescType(xegpu::TensorDescType tdescType) {
  auto maybeLaneData = get2DLaneData(tdescType);
  if (!maybeLaneData)
    return tdescType;

  auto lanes = maybeLaneData.value();
  int64_t packFactor = lanes[1];
  // Nothing to rewrite unless the lane data is exactly [1, N] with N != 1.
  if (lanes[0] != 1 || packFactor == 1)
    return tdescType;

  int elemBits = tdescType.getElementType().getIntOrFloatBitWidth();
  assert(elemBits < 32 &&
         "Expected element type bitwidth < 32 with laneData[1] != 1");

  auto *ctx = tdescType.getContext();

  // Shrink the innermost dimension by the packing factor.
  SmallVector<int64_t> packedShape(tdescType.getShape());
  packedShape.back() = packedShape.back() / packFactor;

  // Widen the element type by the same factor.
  Type packedElemTy = IntegerType::get(ctx, elemBits * packFactor);

  // Same lane layout as before, but each lane now owns a single element.
  xegpu::LayoutAttr packedLayout = xegpu::LayoutAttr::get(
      ctx, tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});

  return xegpu::TensorDescType::get(
      packedShape, packedElemTy, tdescType.getArrayLength(),
      tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), packedLayout);
}
75+
76+
/// Materializes `ofr` as an SSA value.
///
/// When `ofr` holds a constant integer, a new `arith.constant` of index type
/// is created at `loc` with that value. Otherwise `ofr` is expected to already
/// hold a `Value`, which is returned as is.
static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                            OpFoldResult ofr) {
  if (std::optional<int64_t> constant = getConstantIntValue(ofr); constant)
    return arith::ConstantIndexOp::create(rewriter, loc, *constant);
  return llvm::cast<Value>(ofr);
}
83+
84+
/// Emits `val / constant` at `loc` and returns the quotient.
///
/// The divisor is materialized as an index constant and the division is an
/// unsigned integer division (`arith.divui`).
static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
                              Value val, int64_t constant) {
  Value divisor =
      arith::ConstantIndexOp::create(rewriter, loc, constant).getResult();
  auto quotientOp = arith::DivUIOp::create(rewriter, loc, val, divisor);
  return quotientOp.getResult();
}
90+
4591
class XeGPUCreateNdDescOpPattern final
4692
: public OpConversionPattern<xegpu::CreateNdDescOp> {
4793
public:
@@ -50,17 +96,106 @@ class XeGPUCreateNdDescOpPattern final
5096
matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
5197
ConversionPatternRewriter &rewriter) const override {
5298
auto tdescTy = createNdOp.getType();
53-
auto convertType = this->getTypeConverter()->convertType(tdescTy);
99+
auto convertType = getModifiedTensorDescType(tdescTy);
54100
if (convertType == tdescTy)
55101
return failure();
102+
auto strides = createNdOp.getMixedStrides();
103+
auto maybeConstInnerStride = getConstantIntValue(strides.back());
104+
// Only row-major memrefs are expected for now.
105+
if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
106+
return failure();
107+
Value source = createNdOp.getSource();
108+
auto optionalLaneData = get2DLaneData(tdescTy);
109+
assert(optionalLaneData && "Expected 2D lane data");
110+
auto laneData = optionalLaneData.value();
111+
int64_t innerLaneData = laneData[1];
112+
auto memrefType = dyn_cast<MemRefType>(source.getType());
113+
// Inner dimension of the shape must be adjusted based on innerLaneData.
114+
SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
115+
modifiedShape.back() = divideByConstant(
116+
rewriter, createNdOp.getLoc(),
117+
convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
118+
innerLaneData);
119+
// Similarly, second to last stride must be adjusted.
120+
assert(strides.size() >= 2 &&
121+
"Expected at least 2 strides for CreateNdDescOp");
122+
SmallVector<OpFoldResult> modifiedStrides(strides);
123+
modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
124+
rewriter, createNdOp.getLoc(),
125+
convertToValue(rewriter, createNdOp.getLoc(),
126+
modifiedStrides[modifiedStrides.size() - 2]),
127+
innerLaneData);
128+
129+
// If the source is a static memref, we need to extract the pointer to
130+
// base address.
131+
if (memrefType && memrefType.hasStaticShape()) {
132+
auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
133+
rewriter, createNdOp.getLoc(), source);
134+
source = arith::IndexCastOp::create(
135+
rewriter, createNdOp.getLoc(),
136+
IntegerType::get(rewriter.getContext(), 64),
137+
extractOp.getResult())
138+
.getResult();
139+
}
140+
// Create a new CreateNdDescOp with the modified shape and converted type.
141+
auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
142+
rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
143+
modifiedStrides);
144+
rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
56145
return success();
57146
}
58147
};
59148
} // namespace
60149

150+
class XeGPULoadNdDescOpPattern final
151+
: public OpConversionPattern<xegpu::LoadNdOp> {
152+
public:
153+
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
154+
LogicalResult
155+
matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
156+
ConversionPatternRewriter &rewriter) const override {
157+
auto origTensorDescType = loadNdOp.getTensorDescType();
158+
auto adaptorType =
159+
cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
160+
if (adaptorType == origTensorDescType)
161+
return failure();
162+
// Offsets must be adjusted based on innerLaneData.
163+
auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
164+
assert(optionalLaneData && "Expected 2D lane data");
165+
int64_t innerLaneData = optionalLaneData.value()[1];
166+
auto offsets = loadNdOp.getMixedOffsets();
167+
if (offsets.empty())
168+
return rewriter.notifyMatchFailure(loadNdOp,
169+
"Expecting offsets in LoadNd");
170+
SmallVector<OpFoldResult> modifiedOffsets(offsets);
171+
modifiedOffsets.back() = divideByConstant(
172+
rewriter, loadNdOp.getLoc(),
173+
convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
174+
innerLaneData);
175+
VectorType modifiedType =
176+
VectorType::get(adaptorType.getShape(), adaptorType.getElementType());
177+
// Create a new LoadNdOp with modified offsets and type.
178+
auto newLoadNdOp = xegpu::LoadNdOp::create(
179+
rewriter, loadNdOp->getLoc(), modifiedType, adaptor.getTensorDesc(),
180+
modifiedOffsets, loadNdOp.getPackedAttr(), loadNdOp.getTransposeAttr(),
181+
loadNdOp.getL1HintAttr(), loadNdOp.getL2HintAttr(),
182+
loadNdOp.getL3HintAttr());
183+
// Bitcast back to the original type.
184+
auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
185+
loadNdOp.getType(), newLoadNdOp);
186+
// Cast op must have the same layout as the original LoadNdOp result.
187+
xegpu::setDistributeLayoutAttr(
188+
castOp->getOpResult(0),
189+
xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
190+
rewriter.replaceOp(loadNdOp, castOp.getResult());
191+
return success();
192+
}
193+
};
194+
61195
void xegpu::populateXeGPUOptimizeTransposePatterns(
62196
RewritePatternSet &patterns) {
63-
patterns.add<XeGPUCreateNdDescOpPattern>(patterns.getContext());
197+
patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern>(
198+
patterns.getContext());
64199
}
65200

66201
namespace {
@@ -74,41 +209,28 @@ struct XeGPUOptimizeTransposePass final
74209
RewritePatternSet patterns(&context);
75210
ConversionTarget target(context);
76211

212+
auto checkValidInnerLaneData =
213+
[](std::optional<SmallVector<int64_t>> optionalLaneData) -> bool {
214+
if (!optionalLaneData)
215+
return true;
216+
auto laneData = optionalLaneData.value();
217+
return laneData[0] != 1 || laneData[1] == 1;
218+
};
219+
77220
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
78-
[](xegpu::CreateNdDescOp createNdOp) {
221+
[&](xegpu::CreateNdDescOp createNdOp) {
79222
auto optionalLaneData = get2DLaneData(createNdOp.getType());
80-
if (!optionalLaneData)
81-
return true;
82-
auto laneData = optionalLaneData.value();
83-
return laneData[0] != 1 || laneData[1] == 1;
223+
return checkValidInnerLaneData(optionalLaneData);
84224
});
225+
target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
226+
[&](xegpu::LoadNdOp loadNdOp) {
227+
auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
228+
return checkValidInnerLaneData(optionalLaneData);
229+
});
230+
converter.addConversion([](Type type) { return type; });
85231

86-
converter.addConversion([](xegpu::TensorDescType tdescType) {
87-
auto optionalLaneData = get2DLaneData(tdescType);
88-
if (!optionalLaneData)
89-
return tdescType;
90-
auto laneData = optionalLaneData.value();
91-
int64_t innerLaneData = laneData[1];
92-
if (laneData[0] == 1 && innerLaneData != 1) {
93-
int elementTyBitwidth =
94-
tdescType.getElementType().getIntOrFloatBitWidth();
95-
assert(elementTyBitwidth < 32 &&
96-
"Expected element type bitwidth < 32 with laneData[1] != 1");
97-
SmallVector<int64_t> newShape(tdescType.getShape());
98-
newShape.back() = newShape.back() / innerLaneData;
99-
Type newElemTy = IntegerType::get(tdescType.getContext(),
100-
elementTyBitwidth * innerLaneData);
101-
xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
102-
tdescType.getContext(),
103-
tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
104-
return xegpu::TensorDescType::get(
105-
newShape, newElemTy, tdescType.getArrayLength(),
106-
tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
107-
newLayout);
108-
}
109-
return tdescType;
110-
});
111-
232+
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
233+
vector::VectorDialect>();
112234
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
113235
target);
114236
xegpu::populateXeGPUOptimizeTransposePatterns(patterns);

0 commit comments

Comments
 (0)