Skip to content

Commit 35ca92b

Browse files
committed
working version
1 parent ca5d902 commit 35ca92b

File tree

1 file changed

+50
-19
lines changed

1 file changed

+50
-19
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -296,18 +296,18 @@ class XeGPULoadNdDescOpPattern final
296296
Value data;
297297
// Orig data shape is 3D for the array length case.
298298
if (origTensorDescType.getArrayLength() > 1) {
299-
SmallVector<int64_t> arrayLenDataShape(origDataShape);
300-
arrayLenDataShape.insert(arrayLenDataShape.begin(),
301-
origTensorDescType.getArrayLength());
302-
auto arrayLenVecType =
303-
VectorType::get(arrayLenDataShape, adaptorType.getElementType());
304-
data = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
305-
arrayLenVecType,
306-
rewriter.getZeroAttr(arrayLenVecType));
299+
// SmallVector<int64_t> arrayLenDataShape(origDataShape);
300+
// arrayLenDataShape.insert(arrayLenDataShape.begin(),
301+
// origTensorDescType.getArrayLength());
302+
// auto arrayLenVecType =
303+
// VectorType::get(arrayLenDataShape, adaptorType.getElementType());
304+
// auto = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
305+
// arrayLenVecType,
306+
// rewriter.getZeroAttr(arrayLenVecType));
307+
SmallVector<Value> arraySlices;
307308
for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
308309
Value slice = arith::ConstantOp::create(
309-
rewriter, loadNdOp->getLoc(),
310-
VectorType::get(origDataShape, adaptorType.getElementType()),
310+
rewriter, loadNdOp->getLoc(), origVectorType,
311311
rewriter.getZeroAttr(origVectorType));
312312
// Increase the Y offset for each array slice.
313313
Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
@@ -323,14 +323,20 @@ class XeGPULoadNdDescOpPattern final
323323
modifiedOffsets, hwSupportedShape,
324324
cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
325325
loadNdOp);
326-
// Insert slice to data.
327-
data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
328-
data, ArrayRef<int64_t>{i});
326+
// // Insert slice to data.
327+
// data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
328+
// data, ArrayRef<int64_t>{i});
329+
// Bitcast back to original load shape without array length.
330+
auto bitcastType = VectorType::get(origTensorDescType.getShape(),
331+
origTensorDescType.getElementType());
332+
slice = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
333+
bitcastType, slice);
334+
arraySlices.push_back(slice);
329335
}
330-
// Cast back to the original type and replace all uses.
331-
data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
332-
loadNdOp.getType(), data);
333-
rewriter.replaceOp(loadNdOp, data);
336+
// // Cast back to the original type and replace all uses.
337+
// data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
338+
// loadNdOp.getType(), data);
339+
rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
334340
return success();
335341
}
336342
data = arith::ConstantOp::create(
@@ -352,12 +358,33 @@ class XeGPULoadNdDescOpPattern final
352358
return success();
353359
}
354360
};
361+
362+
class VectorExtractOpPattern final
363+
: public OpConversionPattern<vector::ExtractOp> {
364+
public:
365+
using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
366+
LogicalResult
367+
matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
368+
ConversionPatternRewriter &rewriter) const override {
369+
if (adaptor.getSource().size() == 1)
370+
return failure();
371+
auto mixedPos = extractOp.getMixedPosition();
372+
if (mixedPos.size() != 1)
373+
return failure();
374+
auto mayBeInt = getConstantIntValue(mixedPos[0]);
375+
if (!mayBeInt)
376+
return failure();
377+
rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
378+
return success();
379+
}
380+
};
381+
355382
} // namespace
356383

357384
void xegpu::populateXeGPUOptimizeTransposePatterns(
358385
RewritePatternSet &patterns) {
359-
patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern>(
360-
patterns.getContext());
386+
patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
387+
VectorExtractOpPattern>(patterns.getContext());
361388
}
362389

363390
namespace {
@@ -381,6 +408,10 @@ struct XeGPUOptimizeTransposePass final
381408
[&](xegpu::LoadNdOp loadNdOp) {
382409
return !hasInvalidTranposeLayout(loadNdOp.getTensorDescType());
383410
});
411+
target.addDynamicallyLegalOp<vector::ExtractOp>(
412+
[&](vector::ExtractOp extractOp) {
413+
return extractOp.getSourceVectorType().getRank() != 3;
414+
});
384415
converter.addConversion([](Type type) { return type; });
385416

386417
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,

0 commit comments

Comments
 (0)