@@ -4487,17 +4487,13 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
44874487 // Verify result shape is greater than the minimum expected
44884488 // by the pack operation, and that the output shape
44894489 // represents full tiles.
4490- if (hasTensorSemantics) {
4491- RankedTensorType expectedPackedType = PackOp::inferPackedTensorType (
4492- cast<RankedTensorType>(unpackedType), packOrUnPack.getStaticTiles (),
4493- innerDimsPos, outerDimPerm);
4494- if (!areAllInBound (expectedPackedType.getShape (), packedType.getShape ())) {
4495- return op->emitError (
4496- " the shape of output is not large enough to hold the "
4497- " packed data. Expected at least " )
4498- << expectedPackedType << " , got " << packedType;
4499- }
4500- } else {
4490+ auto expectedPackedShape = PackOp::inferPackedShape (
4491+ unpackedType.getShape (), packOrUnPack.getStaticTiles (),
4492+ packOrUnPack.getInnerDimsPos (), packOrUnPack.getOuterDimsPerm ());
4493+ if (!areAllInBound (expectedPackedShape, packedType.getShape ())) {
4494+ return op->emitError (" the shape of output is not large enough to hold the "
4495+ " packed data. Expected at least " )
4496+ << expectedPackedShape << " , got " << packedType.getShape ();
45014497 }
45024498 if (!llvm::all_of (
45034499 llvm::zip (packedType.getShape ().take_back (mixedTiles.size ()),
@@ -4784,6 +4780,14 @@ MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
47844780 return MemRefType::get (resultShape, sourceType.getElementType ());
47854781}
47864782
/// Returns the packed result shape for `inputShape` by forwarding all four
/// arguments, unchanged and in order, to getPackOpResultTypeShape.
///
/// The pack/unpack verifier calls this with the unpacked type's shape, the
/// op's static tiles, its inner dims positions and its outer dims permutation,
/// and compares the result against the packed type's shape.
///
/// NOTE(review): presumably dynamic tile sizes / dynamic dims are represented
/// by a sentinel in `innerTileSizes` / `inputShape` and handled inside
/// getPackOpResultTypeShape -- confirm against that helper.
4783+ SmallVector<int64_t > PackOp::inferPackedShape (ArrayRef<int64_t > inputShape,
4784+ ArrayRef<int64_t > innerTileSizes,
4785+ ArrayRef<int64_t > innerDimsPos,
4786+ ArrayRef<int64_t > outerDimsPerm) {
4787+ return getPackOpResultTypeShape (inputShape, innerTileSizes, innerDimsPos,
4788+ outerDimsPerm);
4789+ }
4790+
47874791Value PackOp::createDestinationTensor (OpBuilder &b, Location loc, Value source,
47884792 ArrayRef<OpFoldResult> innerTileSizes,
47894793 ArrayRef<int64_t > innerDimsPos,
@@ -5030,7 +5034,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50305034 // Insert a cast if needed
50315035 if (needUpdateDestType) {
50325036 rewriter.setInsertionPointAfter (packOp);
5033- // / 1
50345037 if (hasTensorSemantics) {
50355038 auto castOp =
50365039 rewriter.create <tensor::CastOp>(loc, originalResultType, packOp);
@@ -5040,16 +5043,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50405043 rewriter.create <memref::CastOp>(loc, originalResultType, packOp);
50415044 rewriter.replaceAllUsesExcept (packOp, castOp, castOp);
50425045 }
5043- // / 2
5044- Operation *castOp;
5045- if (hasTensorSemantics) {
5046- castOp =
5047- rewriter.create <tensor::CastOp>(loc, originalResultType, packOp);
5048- } else {
5049- castOp =
5050- rewriter.create <memref::CastOp>(loc, originalResultType, packOp);
5051- }
5052- rewriter.replaceAllUsesExcept (packOp, castOp->getResult (0 ), castOp);
50535046 }
50545047 return success ();
50555048 }
0 commit comments