Skip to content

Commit 790e974

Browse files
committed
bail out transforms using PackOp, UnPackOp
1 parent 276069d commit 790e974

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
9191
linalg::PackOp packOp, AffineMap operandMap,
9292
ArrayRef<unsigned> blocksStartDimPos,
9393
bool transposeOuterBlocks, bool transposeInnerBlocks) {
94+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
95+
if (!packOp.hasPureTensorSemantics()) {
96+
return failure();
97+
}
98+
9499
assert(operandMap.getNumDims() >= 4 &&
95100
"expected at least 4D prepacked matmul");
96101
assert(blocksStartDimPos.size() >= 2 &&

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
6363
OpTy packOrUnPackOp) {
6464
static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
6565
"applies to only pack or unpack operations");
66+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
67+
if (isa<linalg::LinalgOp>(packOrUnPackOp)) {
68+
if (!packOrUnPackOp.hasPureTensorSemantics()) {
69+
return failure();
70+
}
71+
}
6672
LLVM_DEBUG(
6773
{ llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
6874

@@ -373,6 +379,11 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
373379
static FailureOr<GenericOp>
374380
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
375381
const ControlPropagationFn &controlFn) {
382+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
383+
if (!packOp.hasPureTensorSemantics()) {
384+
return failure();
385+
}
386+
376387
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
377388
if (!genericOp)
378389
return failure();
@@ -461,6 +472,11 @@ struct BubbleUpPackOpThroughGenericOpPattern
461472

462473
LogicalResult matchAndRewrite(linalg::PackOp packOp,
463474
PatternRewriter &rewriter) const override {
475+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
476+
if (!packOp.hasPureTensorSemantics()) {
477+
return failure();
478+
}
479+
464480
auto genericOp =
465481
bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
466482
if (failed(genericOp))
@@ -483,6 +499,11 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
483499

484500
LogicalResult matchAndRewrite(linalg::PackOp packOp,
485501
PatternRewriter &rewriter) const override {
502+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
503+
if (!packOp.hasPureTensorSemantics()) {
504+
return failure();
505+
}
506+
486507
auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
487508
if (!padOp)
488509
return failure();
@@ -651,6 +672,11 @@ static LogicalResult
651672
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
652673
linalg::PackOp packOp,
653674
PatternRewriter &rewriter) {
675+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
676+
if (!packOp.hasPureTensorSemantics()) {
677+
return failure();
678+
}
679+
654680
SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
655681
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
656682
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -757,6 +783,11 @@ static LogicalResult
757783
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
758784
linalg::PackOp packOp,
759785
PatternRewriter &rewriter) {
786+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
787+
if (!packOp.hasPureTensorSemantics()) {
788+
return failure();
789+
}
790+
760791
// Outer dimensions permutation is not supported currently.
761792
// TODO: Handle outer_dims_perm variants.
762793
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -840,6 +871,11 @@ class BubbleUpPackOpThroughReshapeOp final
840871

841872
LogicalResult matchAndRewrite(linalg::PackOp packOp,
842873
PatternRewriter &rewriter) const override {
874+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
875+
if (!packOp.hasPureTensorSemantics()) {
876+
return failure();
877+
}
878+
843879
Operation *srcOp = packOp.getSource().getDefiningOp();
844880
// Currently only support when the pack op is the only user.
845881
if (!srcOp || !(srcOp->getNumResults() == 1) ||
@@ -893,6 +929,11 @@ class BubbleUpPackOpThroughReshapeOp final
893929
static LogicalResult pushDownUnPackOpThroughExpandShape(
894930
linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
895931
PatternRewriter &rewriter, ControlPropagationFn controlFn) {
932+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
933+
if (!unPackOp.hasPureTensorSemantics()) {
934+
return failure();
935+
}
936+
896937
// User controlled propagation function.
897938
if (!controlFn(&expandOp.getSrcMutable()))
898939
return failure();
@@ -970,6 +1011,11 @@ class PushDownUnPackOpThroughReshapeOp final
9701011

9711012
LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
9721013
PatternRewriter &rewriter) const override {
1014+
// TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
1015+
if (!unPackOp.hasPureTensorSemantics()) {
1016+
return failure();
1017+
}
1018+
9731019
Value result = unPackOp.getResult();
9741020
// Currently only support unpack op with the single user.
9751021
if (!result.hasOneUse()) {
@@ -1146,11 +1192,17 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
11461192

11471193
LogicalResult matchAndRewrite(tensor::PadOp padOp,
11481194
PatternRewriter &rewriter) const override {
1195+
11491196
linalg::UnPackOp unpackOp =
11501197
padOp.getSource().getDefiningOp<linalg::UnPackOp>();
1198+
11511199
if (!unpackOp)
11521200
return failure();
11531201

1202+
// TODO(issues/129004): Support MemRef PadOp. Temporarily return failure.
1203+
if (!unpackOp.hasPureTensorSemantics())
1204+
return failure();
1205+
11541206
if (!controlFn(&padOp.getSourceMutable()))
11551207
return failure();
11561208

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,6 +1588,11 @@ static LogicalResult
15881588
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
15891589
ArrayRef<int64_t> inputVectorSizes,
15901590
SmallVectorImpl<Value> &newResults) {
1591+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
1592+
if (!packOp.hasPureTensorSemantics()) {
1593+
return failure();
1594+
}
1595+
15911596
// TODO: Introduce a parent class that will handle the insertion point update.
15921597
OpBuilder::InsertionGuard g(rewriter);
15931598
rewriter.setInsertionPoint(packOp);
@@ -1664,6 +1669,10 @@ static LogicalResult
16641669
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
16651670
ArrayRef<int64_t> inputVectorSizes,
16661671
SmallVectorImpl<Value> &newResults) {
1672+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
1673+
if (!unpackOp.hasPureTensorSemantics()) {
1674+
return failure();
1675+
}
16671676

16681677
// TODO: Introduce a parent class that will handle the insertion point update.
16691678
OpBuilder::InsertionGuard g(rewriter);
@@ -1891,6 +1900,10 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
18911900
static LogicalResult
18921901
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
18931902
ArrayRef<int64_t> inputVectorSizes) {
1903+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
1904+
if (!unpackOp.hasPureTensorSemantics()) {
1905+
return failure();
1906+
}
18941907

18951908
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
18961909
return !getConstantIntValue(res).has_value();
@@ -2136,6 +2149,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
21362149
static LogicalResult
21372150
vectorizePackOpPrecondition(linalg::PackOp packOp,
21382151
ArrayRef<int64_t> inputVectorSizes) {
2152+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
2153+
if (!packOp.hasPureTensorSemantics()) {
2154+
return failure();
2155+
}
2156+
21392157
auto padValue = packOp.getPaddingValue();
21402158
Attribute cstAttr;
21412159
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
@@ -2358,6 +2376,13 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
23582376
}
23592377

23602378
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2379+
// TODO(issues/129004): Support MemRef PackOp. Temporarily return false.
2380+
// Actually do we need this?
2381+
if (isa<linalg::PackOp, linalg::UnPackOp>(op)) {
2382+
if (!cast<LinalgOp>(op).hasPureTensorSemantics()) {
2383+
return false;
2384+
}
2385+
}
23612386
return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
23622387
tensor::InsertSliceOp>(op);
23632388
}

0 commit comments

Comments
 (0)