Skip to content

Commit efb288b

Browse files
committed
WIP
1 parent e0fa7c7 commit efb288b

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,6 +1790,8 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
17901790
Value source, ArrayRef<OpFoldResult> innerTileSizes,
17911791
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
17921792
}];
1793+
1794+
let hasCanonicalizer = 1;
17931795
}
17941796

17951797
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3355,6 +3355,63 @@ Speculation::Speculatability PackOp::getSpeculatability() {
33553355
return Speculation::Speculatable;
33563356
}
33573357

3358+
namespace {
3359+
3360+
// unpack(pack(x)) -> x
3361+
struct PackOfUnPack : public OpRewritePattern<PackOp> {
3362+
using OpRewritePattern<PackOp>::OpRewritePattern;
3363+
LogicalResult matchAndRewrite(PackOp packOp,
3364+
PatternRewriter &rewriter) const override {
3365+
UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
3366+
if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
3367+
return failure();
3368+
if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
3369+
return failure();
3370+
if (packOp.getOuterDimsPerm() != unPackOp.getOuterDimsPerm())
3371+
return failure();
3372+
rewriter.replaceOp(packOp, unPackOp.getSource());
3373+
return success();
3374+
}
3375+
};
3376+
3377+
// packing one dimension can be expressed as an expand shape op.
3378+
struct PackToExpandShape : public OpRewritePattern<PackOp> {
3379+
using OpRewritePattern<PackOp>::OpRewritePattern;
3380+
3381+
Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
3382+
Type newOperandType, ArrayAttr reassociation) const {
3383+
Type operandType = operand.getType();
3384+
if (operandType == newOperandType)
3385+
return operand;
3386+
return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
3387+
reassociation);
3388+
}
3389+
3390+
LogicalResult matchAndRewrite(PackOp packOp,
3391+
PatternRewriter &rewriter) const override {
3392+
ShapedType sourceType = packOp.getSourceType();
3393+
ShapedType destType = packOp.getDestType();
3394+
if (sourceType.getRank() != 1)
3395+
return failure();
3396+
auto reassociation =
3397+
getReassociationIndicesForReshape(sourceType, destType);
3398+
if (!reassociation)
3399+
return failure();
3400+
Value expanded = insertExpand(
3401+
rewriter, packOp.getLoc(), packOp.getSource(), destType,
3402+
getReassociationIndicesAttribute(rewriter, *reassociation));
3403+
rewriter.replaceOp(packOp, expanded);
3404+
return success();
3405+
}
3406+
};
3407+
3408+
} // end namespace
3409+
3410+
void PackOp::getCanonicalizationPatterns(RewritePatternSet &results,
3411+
MLIRContext *ctx) {
3412+
results.add<PackOfUnPack, PackToExpandShape>(ctx);
3413+
}
3414+
33583415
//===----------------------------------------------------------------------===//
33593416
// UnPackOp
33603417
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)