
Commit 17fd7c8

add tests
1 parent 44e6ac4 commit 17fd7c8

2 files changed: +154 -44 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp

Lines changed: 44 additions & 40 deletions
@@ -79,9 +79,20 @@ getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
   return laneLayout;
 }
 
+static bool canBeOptimized(ArrayRef<int64_t> laneLayout,
+                           ArrayRef<int64_t> laneData) {
+  if (laneLayout.size() != 2 || laneData.size() != 2)
+    return false;
+  if (laneLayout[0] == 1 || laneLayout[1] != 1)
+    return false;
+  if (laneData[0] != 1 || laneData[1] == 1)
+    return false;
+  return true;
+}
+
 // A transpose layout is invalid if lane layout is transposed (lane[0] != 1 &&
 // lane[1] == 1), but inner lane data is not equal to [1, 1].
-static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
+static bool canBeOptimized(xegpu::TensorDescType tdescType) {
   // If the dtype is greater or equal to 32 bits, layout must be valid.
   int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
   if (elementTyBitwidth >= 32)
@@ -90,18 +101,12 @@ static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
   auto maybeLaneData = getMaybeLaneData(tdescType);
   if (!maybeLaneData || !maybeLaneLayout)
     return false;
-  auto laneData = maybeLaneData.value();
-  auto laneLayout = maybeLaneLayout.value();
-  if (laneLayout[0] == 1 || laneLayout[1] != 1)
-    return false;
-  if (laneData[0] != 1 || laneData[1] == 1)
-    return false;
-  return true;
+  return canBeOptimized(*maybeLaneLayout, *maybeLaneData);
 }
 
 static xegpu::TensorDescType
 tryConvertToTransposable(xegpu::TensorDescType tdescType) {
-  if (!hasInvalidTranposeLayout(tdescType))
+  if (!canBeOptimized(tdescType))
     return tdescType;
   auto laneData = getMaybeLaneData(tdescType).value();
   int64_t innerLaneData = laneData[1];
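For context, the extracted predicate accepts exactly the transposed-layout, packed-lane-data case the pass targets. A minimal standalone sketch of its behavior on hypothetical inputs (using std::vector in place of llvm::ArrayRef):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Same checks as the canBeOptimized overload in the diff above.
static bool canBeOptimized(const std::vector<int64_t> &laneLayout,
                           const std::vector<int64_t> &laneData) {
  if (laneLayout.size() != 2 || laneData.size() != 2)
    return false;
  // Lane layout must be transposed: lanes distributed along dim 0 only.
  if (laneLayout[0] == 1 || laneLayout[1] != 1)
    return false;
  // Inner lane data must pack more than one element along dim 1.
  if (laneData[0] != 1 || laneData[1] == 1)
    return false;
  return true;
}

int main() {
  std::cout << canBeOptimized({16, 1}, {1, 2}) << "\n"; // 1: transposed, packed
  std::cout << canBeOptimized({16, 1}, {1, 1}) << "\n"; // 0: lane data already [1, 1]
  std::cout << canBeOptimized({1, 16}, {1, 2}) << "\n"; // 0: layout not transposed
}
```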
@@ -185,11 +190,17 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
           origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
           origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
           origLoadOp.getL3HintAttr());
+      // Set the layout for the loadOp.
+      auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
+      xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
       // Insert the loaded block into the right position in data.
-      data = vector::InsertStridedSliceOp::create(
+      auto insertOp = vector::InsertStridedSliceOp::create(
           rewriter, loc, loadOp.getResult(), data,
           ArrayRef<int64_t>{localOffsetX, localOffsetY},
           ArrayRef<int64_t>{1, 1});
+      // InsertOp must have the same layout as newTensorDesc.
+      xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr);
+      data = insertOp.getResult();
     }
   }
   return data;
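The enclosing generateLoads loops (outside this hunk) tile the original shape with hardware-supported blocks, issuing one load plus one strided-slice insert per block. A simplified standalone model of that tiling, with hypothetical shapes:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  std::array<int64_t, 2> origDataShape{32, 16};   // hypothetical tile to cover
  std::array<int64_t, 2> hwSupportedShape{16, 8}; // hypothetical HW block

  for (int64_t localOffsetX = 0; localOffsetX < origDataShape[0];
       localOffsetX += hwSupportedShape[0])
    for (int64_t localOffsetY = 0; localOffsetY < origDataShape[1];
         localOffsetY += hwSupportedShape[1])
      // Each iteration models one LoadNdOp plus the
      // vector::InsertStridedSliceOp at [localOffsetX, localOffsetY].
      std::cout << "load block at [" << localOffsetX << ", " << localOffsetY
                << "]\n";
}
```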
@@ -288,28 +299,20 @@ class XeGPULoadNdDescOpPattern final
     // Shape ratio is 2D and, it describes how many blocks need to be loaded in
     // HW supported shape to cover the original shape.
     auto ratio = computeShapeRatio(origDataShape, hwSupportedShape)
-                     .value(); // ratio must be defined if we reach here.
+                     .value(); // `ratio` must be defined if we reach here.
     // Create a zero-initialized vector to hold all loaded blocks.
     // TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
     VectorType origVectorType =
         VectorType::get(origDataShape, adaptorType.getElementType());
     Value data;
     // Orig data shape is 3D for the array length case.
     if (origTensorDescType.getArrayLength() > 1) {
-      // SmallVector<int64_t> arrayLenDataShape(origDataShape);
-      // arrayLenDataShape.insert(arrayLenDataShape.begin(),
-      //                          origTensorDescType.getArrayLength());
-      // auto arrayLenVecType =
-      //     VectorType::get(arrayLenDataShape, adaptorType.getElementType());
-      // auto = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
-      //                                  arrayLenVecType,
-      //                                  rewriter.getZeroAttr(arrayLenVecType));
       SmallVector<Value> arraySlices;
       for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
         Value slice = arith::ConstantOp::create(
             rewriter, loadNdOp->getLoc(), origVectorType,
             rewriter.getZeroAttr(origVectorType));
-        // Increse the Y offset for each array slice.
+        // Increase the Y offset for each array slice.
         Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
                                        modifiedOffsets.back());
         modifiedOffsets.back() =
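The `ratio` in the hunk above is an element-wise quotient of the two shapes. A sketch in the spirit of mlir::computeShapeRatio (illustrative values; assumes equal ranks and even divisibility, as the pass requires here):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// How many subShape blocks are needed per dimension to cover shape.
static std::optional<std::vector<int64_t>>
shapeRatio(const std::vector<int64_t> &shape,
           const std::vector<int64_t> &subShape) {
  std::vector<int64_t> ratio;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] % subShape[i] != 0)
      return std::nullopt; // not evenly tileable
    ratio.push_back(shape[i] / subShape[i]);
  }
  return ratio;
}

int main() {
  if (auto r = shapeRatio({32, 16}, {16, 8})) // hypothetical shapes
    std::cout << (*r)[0] << " x " << (*r)[1] << " blocks\n"; // 2 x 2 blocks
}
```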
@@ -323,19 +326,16 @@ class XeGPULoadNdDescOpPattern final
             modifiedOffsets, hwSupportedShape,
             cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
             loadNdOp);
-        // // Insert slice to data.
-        // data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
-        //                                 data, ArrayRef<int64_t>{i});
-        // Bitcast back to original load shape without array length.
+        // BitCast back to original load shape without array length.
         auto bitcastType = VectorType::get(origTensorDescType.getShape(),
                                            origTensorDescType.getElementType());
-        slice = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-                                          bitcastType, slice);
-        arraySlices.push_back(slice);
+        auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                                   bitcastType, slice);
+        // BitCastOp must have the same layout as the original loadNdOp.
+        xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+                                       origTensorDescType.getLayoutAttr());
+        arraySlices.push_back(bitCastOp.getResult());
       }
-      // // Cast back to the original type and replace all uses.
-      // data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-      //                                  loadNdOp.getType(), data);
       rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
       return success();
     }
@@ -348,13 +348,12 @@ class XeGPULoadNdDescOpPattern final
         hwSupportedShape,
         cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
         loadNdOp);
-    auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-                                            loadNdOp.getType(), data);
-    // // Cast op must have the same layout as the original LoadNdOp result.
-    // xegpu::setDistributeLayoutAttr(
-    //     castOp->getOpResult(0),
-    //     xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
-    rewriter.replaceOp(loadNdOp, castOp);
+    auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                               loadNdOp.getType(), data);
+    // BitCastOp must have the same layout as the original loadNdOp.
+    xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+                                   origTensorDescType.getLayoutAttr());
+    rewriter.replaceOp(loadNdOp, bitCastOp);
     return success();
   }
 };
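The bitcast undoes the widening applied by tryConvertToTransposable, which folds laneData[1] into a wider element type. A worked example of that arithmetic (hypothetical f16 tile, assuming innerLaneData == 2; not taken from this commit's tests):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  int64_t shape0 = 16, shape1 = 16; // original f16 tile shape (hypothetical)
  int64_t elemBits = 16;            // f16
  int64_t innerLaneData = 2;        // laneData[1] from the layout

  // The transposed load uses a widened element and a narrowed inner dim.
  int64_t loadBits = elemBits * innerLaneData; // widened element: 32 bits
  int64_t loadShape1 = shape1 / innerLaneData; // narrowed inner dim: 8

  std::cout << "load vector<" << shape0 << "x" << loadShape1 << "xi"
            << loadBits << ">, bitcast back to vector<" << shape0 << "x"
            << shape1 << "xf" << elemBits << ">\n";
}
```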
@@ -402,15 +401,20 @@ struct XeGPUOptimizeTransposePass final
     // converted.
     target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
         [&](xegpu::CreateNdDescOp createNdOp) {
-          return !hasInvalidTranposeLayout(createNdOp.getType());
+          return !canBeOptimized(createNdOp.getType());
         });
     target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
         [&](xegpu::LoadNdOp loadNdOp) {
-          return !hasInvalidTranposeLayout(loadNdOp.getTensorDescType());
+          return !canBeOptimized(loadNdOp.getTensorDescType());
         });
     target.addDynamicallyLegalOp<vector::ExtractOp>(
         [&](vector::ExtractOp extractOp) {
-          return extractOp.getSourceVectorType().getRank() != 3;
+          auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult());
+          if (!layout)
+            return true;
+          auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+          auto laneData = layout.getEffectiveLaneDataAsInt();
+          return !canBeOptimized(laneLayout, laneData);
         });
     converter.addConversion([](Type type) { return type; });
