@@ -3355,6 +3355,63 @@ Speculation::Speculatability PackOp::getSpeculatability() {
33553355 return Speculation::Speculatable;
33563356}
33573357
3358+ namespace {
3359+
3360+ // unpack(pack(x)) -> x
3361+ struct PackOfUnPack : public OpRewritePattern <PackOp> {
3362+ using OpRewritePattern<PackOp>::OpRewritePattern;
3363+ LogicalResult matchAndRewrite (PackOp packOp,
3364+ PatternRewriter &rewriter) const override {
3365+ UnPackOp unPackOp = packOp.getSource ().getDefiningOp <UnPackOp>();
3366+ if (!unPackOp || unPackOp.getSourceType () != packOp.getDestType ())
3367+ return failure ();
3368+ if (packOp.getInnerDimsPos () != unPackOp.getInnerDimsPos ())
3369+ return failure ();
3370+ if (packOp.getOuterDimsPerm () != unPackOp.getOuterDimsPerm ())
3371+ return failure ();
3372+ rewriter.replaceOp (packOp, unPackOp.getSource ());
3373+ return success ();
3374+ }
3375+ };
3376+
3377+ // packing one dimension can be expressed as an expand shape op.
3378+ struct PackToExpandShape : public OpRewritePattern <PackOp> {
3379+ using OpRewritePattern<PackOp>::OpRewritePattern;
3380+
3381+ Value insertExpand (RewriterBase &rewriter, Location loc, Value operand,
3382+ Type newOperandType, ArrayAttr reassociation) const {
3383+ Type operandType = operand.getType ();
3384+ if (operandType == newOperandType)
3385+ return operand;
3386+ return rewriter.create <tensor::ExpandShapeOp>(loc, newOperandType, operand,
3387+ reassociation);
3388+ }
3389+
3390+ LogicalResult matchAndRewrite (PackOp packOp,
3391+ PatternRewriter &rewriter) const override {
3392+ ShapedType sourceType = packOp.getSourceType ();
3393+ ShapedType destType = packOp.getDestType ();
3394+ if (sourceType.getRank () != 1 )
3395+ return failure ();
3396+ auto reassociation =
3397+ getReassociationIndicesForReshape (sourceType, destType);
3398+ if (!reassociation)
3399+ return failure ();
3400+ Value expanded = insertExpand (
3401+ rewriter, packOp.getLoc (), packOp.getSource (), destType,
3402+ getReassociationIndicesAttribute (rewriter, *reassociation));
3403+ rewriter.replaceOp (packOp, expanded);
3404+ return success ();
3405+ }
3406+ };
3407+
3408+ } // end namespace
3409+
3410+ void PackOp::getCanonicalizationPatterns (RewritePatternSet &results,
3411+ MLIRContext *ctx) {
3412+ results.add <PackOfUnPack, PackToExpandShape>(ctx);
3413+ }
3414+
33583415// ===----------------------------------------------------------------------===//
33593416// UnPackOp
33603417// ===----------------------------------------------------------------------===//
0 commit comments