Commit ca5d902: "save work"
Parent: e9211c8

File tree: 2 files changed (+246, -63 lines)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp

Lines changed: 202 additions & 63 deletions
@@ -9,18 +9,23 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <algorithm>
 #include <optional>
 
 namespace mlir {
@@ -35,12 +40,14 @@ namespace xegpu {
 
 using namespace mlir;
 
-struct TransposableBlockRange {
-  int minWidth, maxWidth, minHeight, maxHeight;
+namespace {
+
+struct Allowed2DShapeRange {
+  int64_t minWidth, maxWidth, minHeight, maxHeight;
 };
 
 // TODO: Use uArch to get supported block ranges.
-static TransposableBlockRange getBlockRange(int bitWidth) {
+static Allowed2DShapeRange getTransposableBlockRange(int bitWidth) {
   switch (bitWidth) {
   case 32:
     return {/**min width**/ 1, /**max width**/ 8, /**min height**/ 1,
@@ -50,10 +57,8 @@ static TransposableBlockRange getBlockRange(int bitWidth) {
   }
 }
 
-namespace {
-
 static std::optional<SmallVector<int64_t>>
-get2DLaneData(xegpu::TensorDescType tdescType) {
+getMaybeLaneData(xegpu::TensorDescType tdescType) {
   auto layout = tdescType.getLayoutAttr();
   if (!layout)
     return std::nullopt;
@@ -63,44 +68,131 @@ get2DLaneData(xegpu::TensorDescType tdescType) {
   return laneData;
 }
 
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
+  auto layout = tdescType.getLayoutAttr();
+  if (!layout)
+    return std::nullopt;
+  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+  if (laneLayout.size() != 2)
+    return std::nullopt;
+  return laneLayout;
+}
+
+// A transpose layout is invalid if the lane layout is transposed (lane[0] != 1
+// && lane[1] == 1) but the inner lane data is not equal to [1, 1].
+static bool hasInvalidTransposeLayout(xegpu::TensorDescType tdescType) {
+  // If the dtype is greater than or equal to 32 bits, the layout must be
+  // valid.
+  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  if (elementTyBitwidth >= 32)
+    return false;
+  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
+  auto maybeLaneData = getMaybeLaneData(tdescType);
+  if (!maybeLaneData || !maybeLaneLayout)
+    return false;
+  auto laneData = maybeLaneData.value();
+  auto laneLayout = maybeLaneLayout.value();
+  if (laneLayout[0] == 1 || laneLayout[1] != 1)
+    return false;
+  if (laneData[0] != 1 || laneData[1] == 1)
+    return false;
+  return true;
+}
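As a quick check of what the predicate above accepts, here is a minimal standalone sketch of the same logic in plain C++ (outside MLIR; the sample layouts below are hypothetical, not taken from the commit):

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors hasInvalidTransposeLayout: a sub-32-bit descriptor with a
// transposed lane layout and packed inner lane data needs the rewrite.
static bool hasInvalidTransposeLayoutSketch(
    int elementBitwidth, const std::vector<int64_t> &laneLayout,
    const std::vector<int64_t> &laneData) {
  if (elementBitwidth >= 32)
    return false; // 32-bit and wider element types are always fine.
  if (laneLayout.size() != 2 || laneData.size() != 2)
    return false; // Not 2D, nothing to check.
  if (laneLayout[0] == 1 || laneLayout[1] != 1)
    return false; // Lane layout is not transposed.
  if (laneData[0] != 1 || laneData[1] == 1)
    return false; // Inner lane data is effectively [1, 1], already valid.
  return true;
}

int main() {
  // 16-bit elements, transposed lane layout [16, 1], lane data [1, 2]:
  // this is the case the pass rewrites, so the predicate prints 1.
  std::printf("%d\n", hasInvalidTransposeLayoutSketch(16, {16, 1}, {1, 2}));
  // The same layout with 32-bit elements is left alone (prints 0).
  std::printf("%d\n", hasInvalidTransposeLayoutSketch(32, {16, 1}, {1, 2}));
}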
+
 static xegpu::TensorDescType
-getModifiedTensorDescType(xegpu::TensorDescType tdescType) {
-  auto optionalLaneData = get2DLaneData(tdescType);
-  if (!optionalLaneData)
+tryConvertToTransposable(xegpu::TensorDescType tdescType) {
+  if (!hasInvalidTransposeLayout(tdescType))
     return tdescType;
-  auto laneData = optionalLaneData.value();
+  auto laneData = getMaybeLaneData(tdescType).value();
   int64_t innerLaneData = laneData[1];
-  if (laneData[0] == 1 && innerLaneData != 1) {
-    int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
-    assert(elementTyBitwidth < 32 &&
-           "Expected element type bitwidth < 32 with laneData[1] != 1");
-    SmallVector<int64_t> newShape(tdescType.getShape());
-    newShape.back() = newShape.back() / innerLaneData;
-    Type newElemTy = IntegerType::get(tdescType.getContext(),
-                                      elementTyBitwidth * innerLaneData);
-    xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
-        tdescType.getContext(),
-        tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
-    return xegpu::TensorDescType::get(
-        newShape, newElemTy, tdescType.getArrayLength(),
-        tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
-  }
-  return tdescType;
+  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  // The required shape is the total shape of the vector result that this
+  // tensor desc must eventually load, after adjusting for the new bitwidth
+  // and array length.
+  SmallVector<int64_t> requiredShape(tdescType.getShape());
+  requiredShape.back() =
+      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
+  int newBitWidth = elementTyBitwidth * innerLaneData;
+  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
+  // The supported shape is the largest transpose shape the hardware can
+  // handle that is less than or equal to the required shape.
+  auto supportedHeight = std::min(
+      requiredShape[0], getTransposableBlockRange(newBitWidth).maxHeight);
+  auto supportedWidth = std::min(
+      requiredShape[1], getTransposableBlockRange(newBitWidth).maxWidth);
+  SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+
+  // The required shape must be a multiple of the supported shape. Otherwise,
+  // we cannot optimize it.
+  // TODO: The supported shape can be adjusted to handle non-multiple cases.
+  if (requiredShape[0] % supportedShape[0] != 0 ||
+      requiredShape[1] % supportedShape[1] != 0)
+    return tdescType;
+
+  xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+      tdescType.getContext(),
+      tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+  // Array length cannot be larger than 1 for the transpose case.
+  return xegpu::TensorDescType::get(
+      supportedShape, newElemTy, /**array length**/ 1,
+      tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
+}
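To make the shape arithmetic above concrete, here is a small self-contained example with assumed numbers: 16-bit elements, inner lane data 2, array length 2, and a 32-bit block range with max width 8 as shown in the earlier hunk (the max height is cut off by the hunk boundary, so 32 is an assumption here):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // A hypothetical 32x32 tensor desc of 16-bit elements, array_length = 2,
  // lane data [1, 2]: pairs of 16-bit elements get packed into one i32.
  int64_t innerLaneData = 2, elementBitwidth = 16, arrayLength = 2;
  int64_t shape[2] = {32, 32};

  int64_t newBitWidth = elementBitwidth * innerLaneData; // 32
  // Required shape: the total i32 vector the desc must eventually load.
  int64_t required[2] = {shape[0], shape[1] * arrayLength / innerLaneData};

  // Assumed i32 transpose block range: height <= 32, width <= 8.
  int64_t maxHeight = 32, maxWidth = 8;
  int64_t supported[2] = {std::min(required[0], maxHeight),
                          std::min(required[1], maxWidth)};

  // The rewrite only fires when required is a multiple of supported.
  bool ok = required[0] % supported[0] == 0 && required[1] % supported[1] == 0;
  std::printf("required=%lldx%lld supported=%lldx%lld ok=%d\n",
              (long long)required[0], (long long)required[1],
              (long long)supported[0], (long long)supported[1], ok);
}

With these numbers the required shape is 32x32 and the supported block is 32x8, so the shape ratio computed by the load pattern below works out to {1, 4}.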
+
+static Value createConstantIndex(ConversionPatternRewriter &rewriter,
+                                 Location loc, int64_t value) {
+  return arith::ConstantIndexOp::create(rewriter, loc, value).getResult();
 }
 
 static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                             OpFoldResult ofr) {
   std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
   if (mayBeInt)
-    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt);
+    return createConstantIndex(rewriter, loc, *mayBeInt);
   return llvm::cast<Value>(ofr);
 }
 
 static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
                               Value val, int64_t constant) {
-  auto constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
-  return arith::DivUIOp::create(rewriter, loc, val, constantOp.getResult())
-      .getResult();
+  auto constantOp = createConstantIndex(rewriter, loc, constant);
+  return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
+}
+
+static Value generateLoads(ConversionPatternRewriter &rewriter,
+                           TypedValue<VectorType> data,
+                           SmallVector<int64_t> &shapeRatio,
+                           SmallVector<OpFoldResult> offsets,
+                           SmallVector<int64_t> &supportedShape,
+                           TypedValue<xegpu::TensorDescType> newTensorDesc,
+                           xegpu::LoadNdOp origLoadOp) {
+  Location loc = data.getLoc();
+  assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
+  Value offsetX = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
+  Value offsetY = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+  for (int64_t h = 0; h < shapeRatio[0]; ++h) {
+    for (int64_t w = 0; w < shapeRatio[1]; ++w) {
+      int64_t localOffsetX = h * supportedShape[0];
+      int64_t localOffsetY = w * supportedShape[1];
+      Value loadOffsetX = arith::AddIOp::create(
+          rewriter, loc, offsetX,
+          createConstantIndex(rewriter, loc, localOffsetX));
+      Value loadOffsetY = arith::AddIOp::create(
+          rewriter, loc, offsetY,
+          createConstantIndex(rewriter, loc, localOffsetY));
+      auto loadOp = xegpu::LoadNdOp::create(
+          rewriter, loc,
+          VectorType::get(supportedShape, data.getType().getElementType()),
+          newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
+          origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
+          origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
+          origLoadOp.getL3HintAttr());
+      // Insert the loaded block at the right position in data.
+      data = vector::InsertStridedSliceOp::create(
+          rewriter, loc, loadOp.getResult(), data,
+          ArrayRef<int64_t>{localOffsetX, localOffsetY},
+          ArrayRef<int64_t>{1, 1});
+    }
+  }
+  return data;
 }
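The nested loops above walk the shape ratio and add block-local offsets to the base offsets of the original load. A plain C++ sketch of the same offset arithmetic, continuing the assumed 32x32-over-32x8 example (the base offsets are made up):

#include <cstdint>
#include <cstdio>

int main() {
  int64_t supported[2] = {32, 8}; // per-load block shape
  int64_t ratio[2] = {1, 4};      // blocks needed to cover the region
  int64_t ox = 0, oy = 16;        // hypothetical base offsets
  for (int64_t h = 0; h < ratio[0]; ++h)
    for (int64_t w = 0; w < ratio[1]; ++w)
      // Same arithmetic as generateLoads: base offset + block-local offset.
      std::printf("load %lldx%lld block at (%lld, %lld)\n",
                  (long long)supported[0], (long long)supported[1],
                  (long long)(ox + h * supported[0]),
                  (long long)(oy + w * supported[1]));
}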
 
 class XeGPUCreateNdDescOpPattern final
@@ -111,7 +203,7 @@ class XeGPUCreateNdDescOpPattern final
   matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto tdescTy = createNdOp.getType();
-    auto convertType = getModifiedTensorDescType(tdescTy);
+    auto convertType = tryConvertToTransposable(tdescTy);
     if (convertType == tdescTy)
       return failure();
     auto strides = createNdOp.getMixedStrides();
@@ -120,7 +212,7 @@ class XeGPUCreateNdDescOpPattern final
     if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
       return failure();
     Value source = createNdOp.getSource();
-    auto optionalLaneData = get2DLaneData(tdescTy);
+    auto optionalLaneData = getMaybeLaneData(tdescTy);
     assert(optionalLaneData && "Expected 2D lane data");
     auto laneData = optionalLaneData.value();
     int64_t innerLaneData = laneData[1];
@@ -160,7 +252,6 @@ class XeGPUCreateNdDescOpPattern final
     return success();
   }
 };
-} // namespace
 
 class XeGPULoadNdDescOpPattern final
     : public OpConversionPattern<xegpu::LoadNdOp> {
@@ -175,9 +266,8 @@ class XeGPULoadNdDescOpPattern final
     if (adaptorType == origTensorDescType)
       return failure();
     // Offsets must be adjusted based on innerLaneData.
-    auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
-    assert(optionalLaneData && "Expected 2D lane data");
-    int64_t innerLaneData = optionalLaneData.value()[1];
+    auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
+    int64_t innerLaneData = laneData[1];
     auto offsets = loadNdOp.getMixedOffsets();
     if (offsets.empty())
       return rewriter.notifyMatchFailure(loadNdOp,
@@ -187,25 +277,82 @@ class XeGPULoadNdDescOpPattern final
         rewriter, loadNdOp.getLoc(),
         convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
         innerLaneData);
-    VectorType modifiedType =
-        VectorType::get(adaptorType.getShape(), adaptorType.getElementType());
-    // Create a new LoadNdOp with modified offsets and type.
-    auto newLoadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loadNdOp->getLoc(), modifiedType, adaptor.getTensorDesc(),
-        modifiedOffsets, loadNdOp.getPackedAttr(), loadNdOp.getTransposeAttr(),
-        loadNdOp.getL1HintAttr(), loadNdOp.getL2HintAttr(),
-        loadNdOp.getL3HintAttr());
-    // Bitcast back to the original type.
+    // Get the 2D data shape of this loadNdOp in its original (unpacked) type.
+    SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
+    // Adjust the data shape based on innerLaneData.
+    origDataShape.back() /= innerLaneData;
+    // The HW supported shape is the new tensor desc shape after conversion.
+    SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
+    // The shape ratio is 2D; it describes how many blocks of the HW supported
+    // shape must be loaded to cover the original shape.
+    auto ratio = computeShapeRatio(origDataShape, hwSupportedShape)
+                     .value(); // The ratio must be defined if we reach here.
+    // Create a zero-initialized vector to hold all loaded blocks.
+    VectorType origVectorType =
+        VectorType::get(origDataShape, adaptorType.getElementType());
+    Value data;
+    // The data shape is 3D for the array-length case.
+    if (origTensorDescType.getArrayLength() > 1) {
+      SmallVector<int64_t> arrayLenDataShape(origDataShape);
+      arrayLenDataShape.insert(arrayLenDataShape.begin(),
+                               origTensorDescType.getArrayLength());
+      auto arrayLenVecType =
+          VectorType::get(arrayLenDataShape, adaptorType.getElementType());
+      data = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
+                                       arrayLenVecType,
+                                       rewriter.getZeroAttr(arrayLenVecType));
+      Value baseOffsetY =
+          convertToValue(rewriter, loadNdOp->getLoc(), modifiedOffsets.back());
+      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
+        Value slice = arith::ConstantOp::create(
+            rewriter, loadNdOp->getLoc(), origVectorType,
+            rewriter.getZeroAttr(origVectorType));
+        // Increase the Y offset from the base by one slice width per array
+        // slice.
+        modifiedOffsets.back() =
+            arith::AddIOp::create(rewriter, loadNdOp->getLoc(), baseOffsetY,
+                                  createConstantIndex(rewriter,
+                                                      loadNdOp->getLoc(),
+                                                      i * origDataShape[1]))
+                .getResult();
+        slice = generateLoads(
+            rewriter, cast<TypedValue<VectorType>>(slice), ratio,
+            modifiedOffsets, hwSupportedShape,
+            cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+            loadNdOp);
+        // Insert the slice into its position in data.
+        data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
+                                        data, ArrayRef<int64_t>{i});
+      }
+      // Cast back to the original type and replace all uses.
+      data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                       loadNdOp.getType(), data);
+      rewriter.replaceOp(loadNdOp, data);
+      return success();
+    }
+    data = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
+                                     origVectorType,
+                                     rewriter.getZeroAttr(origVectorType));
+    data = generateLoads(
+        rewriter, cast<TypedValue<VectorType>>(data), ratio, modifiedOffsets,
+        hwSupportedShape,
+        cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+        loadNdOp);
     auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-                                            loadNdOp.getType(), newLoadNdOp);
-    // Cast op must have the same layout as the original LoadNdOp result.
-    xegpu::setDistributeLayoutAttr(
-        castOp->getOpResult(0),
-        xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
-    rewriter.replaceOp(loadNdOp, castOp.getResult());
+                                            loadNdOp.getType(), data);
+    // TODO: the cast op must carry the same layout as the original LoadNdOp
+    // result:
+    // xegpu::setDistributeLayoutAttr(
+    //     castOp->getOpResult(0),
+    //     xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
+    rewriter.replaceOp(loadNdOp, castOp);
     return success();
   }
 };
+} // namespace
 
 void xegpu::populateXeGPUOptimizeTransposePatterns(
     RewritePatternSet &patterns) {
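A note on why the final vector.bitcast recovers the original values: each generated load fetches pairs of narrow elements as single wider integers, and a bitcast only reinterprets the underlying bytes. A plain C++ analogue using memcpy (the values are arbitrary):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Two adjacent 16-bit elements are "loaded" as one packed 32-bit value...
  uint16_t original[2] = {0xBEEF, 0xCAFE};
  uint32_t packed;
  std::memcpy(&packed, original, sizeof(packed));
  // ...and bitcasting back yields the same bytes, hence the same elements.
  uint16_t recovered[2];
  std::memcpy(recovered, &packed, sizeof(recovered));
  std::printf("%x %x\n", recovered[0], recovered[1]); // beef cafe
}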
@@ -224,23 +371,15 @@ struct XeGPUOptimizeTransposePass final
     RewritePatternSet patterns(&context);
     ConversionTarget target(context);
 
-    auto checkValidInnerLaneData =
-        [](std::optional<SmallVector<int64_t>> optionalLaneData) -> bool {
-      if (!optionalLaneData)
-        return true;
-      auto laneData = optionalLaneData.value();
-      return laneData[0] != 1 || laneData[1] == 1;
-    };
-
+    // CreateNdDescOp and LoadNdOp with an invalid transpose layout must be
+    // converted.
     target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
         [&](xegpu::CreateNdDescOp createNdOp) {
-          auto optionalLaneData = get2DLaneData(createNdOp.getType());
-          return checkValidInnerLaneData(optionalLaneData);
+          return !hasInvalidTransposeLayout(createNdOp.getType());
         });
     target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
         [&](xegpu::LoadNdOp loadNdOp) {
-          auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
-          return checkValidInnerLaneData(optionalLaneData);
+          return !hasInvalidTransposeLayout(loadNdOp.getTensorDescType());
         });
     converter.addConversion([](Type type) { return type; });
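The conversion target above simply declares an op legal when the predicate does not hold, so only descriptors with an invalid transpose layout are funneled into the rewrite patterns. Conceptually (a plain C++ sketch with hypothetical data, not the MLIR API):

#include <cstdio>
#include <vector>

struct DescInfo {
  int bitwidth;
  bool transposedLaneLayout;
  bool packedInnerLaneData;
};

// Legal iff the invalid-transpose-layout predicate does not hold.
static bool isLegal(const DescInfo &d) {
  return !(d.bitwidth < 32 && d.transposedLaneLayout && d.packedInnerLaneData);
}

int main() {
  std::vector<DescInfo> descs = {
      {16, true, true}, {32, true, true}, {16, false, true}};
  for (const auto &d : descs)
    std::printf("bitwidth=%d transposed=%d packed=%d -> %s\n", d.bitwidth,
                (int)d.transposedLaneLayout, (int)d.packedInnerLaneData,
                isLegal(d) ? "legal" : "convert");
}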