diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index c3ee3968abc16..4c8a214049ea9 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -2081,14 +2081,14 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ /// A subview result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. - static Type inferResultType(MemRefType sourceMemRefType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides); - static Type inferResultType(MemRefType sourceMemRefType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides); + static MemRefType inferResultType(MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides); + static MemRefType inferResultType(MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides); /// A rank-reducing result type can be inferred from the desired result /// shape. Only the layout map is inferred. @@ -2097,16 +2097,16 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ /// and the desired sizes. In case there are more "ones" among the sizes /// than the difference in source/result rank, it is not clear which dims of /// size one should be dropped. - static Type inferRankReducedResultType(ArrayRef resultShape, - MemRefType sourceMemRefType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides); - static Type inferRankReducedResultType(ArrayRef resultShape, - MemRefType sourceMemRefType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides); + static MemRefType inferRankReducedResultType( + ArrayRef resultShape, MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides); + static MemRefType inferRankReducedResultType( + ArrayRef resultShape, MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides); /// Return the expected rank of each of the`static_offsets`, `static_sizes` /// and `static_strides` attributes. diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e0930abc1887d..11597505e7888 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2702,10 +2702,10 @@ void SubViewOp::getAsmResultNames( /// A subview result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. -Type SubViewOp::inferResultType(MemRefType sourceMemRefType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides) { +MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides) { unsigned rank = sourceMemRefType.getRank(); (void)rank; assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); @@ -2744,10 +2744,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType, sourceMemRefType.getMemorySpace()); } -Type SubViewOp::inferResultType(MemRefType sourceMemRefType, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { +MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -2763,13 +2763,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType, staticSizes, staticStrides); } -Type SubViewOp::inferRankReducedResultType(ArrayRef resultShape, - MemRefType sourceRankedTensorType, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - auto inferredType = llvm::cast( - inferResultType(sourceRankedTensorType, offsets, sizes, strides)); +MemRefType SubViewOp::inferRankReducedResultType( + ArrayRef resultShape, MemRefType sourceRankedTensorType, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides) { + MemRefType inferredType = + inferResultType(sourceRankedTensorType, offsets, sizes, strides); assert(inferredType.getRank() >= static_cast(resultShape.size()) && "expected "); if (inferredType.getRank() == static_cast(resultShape.size())) @@ -2795,11 +2794,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef resultShape, inferredType.getMemorySpace()); } -Type SubViewOp::inferRankReducedResultType(ArrayRef resultShape, - MemRefType sourceRankedTensorType, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { +MemRefType SubViewOp::inferRankReducedResultType( + ArrayRef resultShape, MemRefType sourceRankedTensorType, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -2826,8 +2824,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, auto sourceMemRefType = llvm::cast(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = llvm::cast(SubViewOp::inferResultType( - sourceMemRefType, staticOffsets, staticSizes, staticStrides)); + resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, + staticSizes, staticStrides); } result.addAttributes(attrs); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, @@ -2992,8 +2990,8 @@ LogicalResult SubViewOp::verify() { // Compute the expected result type, assuming that there are no rank // reductions. - auto expectedType = cast(SubViewOp::inferResultType( - baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides())); + MemRefType expectedType = SubViewOp::inferResultType( + baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()); // Verify all properties of a shaped type: rank, element type and dimension // sizes. This takes into account potential rank reductions. @@ -3075,8 +3073,8 @@ static MemRefType getCanonicalSubViewResultType( MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef mixedOffsets, ArrayRef mixedSizes, ArrayRef mixedStrides) { - auto nonRankReducedType = llvm::cast(SubViewOp::inferResultType( - sourceType, mixedOffsets, mixedSizes, mixedStrides)); + MemRefType nonRankReducedType = SubViewOp::inferResultType( + sourceType, mixedOffsets, mixedSizes, mixedStrides); FailureOr unusedDims = computeMemRefRankReductionMask( currentSourceType, currentResultType, mixedSizes); if (failed(unusedDims)) @@ -3110,9 +3108,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp( SmallVector offsets(rank, b.getIndexAttr(0)); SmallVector sizes = getMixedSizes(b, loc, memref); SmallVector strides(rank, b.getIndexAttr(1)); - auto targetType = - llvm::cast(SubViewOp::inferRankReducedResultType( - targetShape, memrefType, offsets, sizes, strides)); + MemRefType targetType = SubViewOp::inferRankReducedResultType( + targetShape, memrefType, offsets, sizes, strides); return b.createOrFold(loc, targetType, memref, offsets, sizes, strides); } @@ -3256,11 +3253,11 @@ struct SubViewReturnTypeCanonicalizer { ArrayRef mixedSizes, ArrayRef mixedStrides) { // Infer a memref type without taking into account any rank reductions. - auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets, - mixedSizes, mixedStrides); + MemRefType resTy = SubViewOp::inferResultType( + op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides); if (!resTy) return {}; - MemRefType nonReducedType = cast(resTy); + MemRefType nonReducedType = resTy; // Directly return the non-rank reduced type if there are no dropped dims. llvm::SmallBitVector droppedDims = op.getDroppedDims(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index e237858d208a0..21361d2e9a2d7 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -70,9 +70,9 @@ propagateSubViewOp(RewriterBase &rewriter, UnrealizedConversionCastOp conversionOp, SubViewOp op) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); - auto newResultType = cast(SubViewOp::inferRankReducedResultType( + MemRefType newResultType = SubViewOp::inferRankReducedResultType( op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), - op.getMixedSizes(), op.getMixedStrides())); + op.getMixedSizes(), op.getMixedStrides()); Value newSubview = rewriter.create( op.getLoc(), newResultType, conversionOp.getOperand(0), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index bc0dd034f6385..c475d92e0658e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -60,14 +60,13 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter, // `subview(old_op)` is replaced by a new `subview(val)`. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(subviewUse); - Type newType = memref::SubViewOp::inferRankReducedResultType( + MemRefType newType = memref::SubViewOp::inferRankReducedResultType( subviewUse.getType().getShape(), cast(val.getType()), subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), subviewUse.getStaticStrides()); Value newSubview = rewriter.create( - subviewUse->getLoc(), cast(newType), val, - subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), - subviewUse.getMixedStrides()); + subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(), + subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); // Ouch recursion ... is this really necessary? replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); @@ -211,9 +210,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, for (int64_t i = 0, e = originalShape.size(); i != e; ++i) sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]); // Strides is [1, 1 ... 1 ]. - auto dstMemref = - cast(memref::SubViewOp::inferRankReducedResultType( - originalShape, mbMemRefType, offsets, sizes, strides)); + MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType( + originalShape, mbMemRefType, offsets, sizes, strides); Value subview = rewriter.create(loc, dstMemref, mbAlloc, offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index ed3ba321b37ab..81404fa664cd4 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -407,10 +407,10 @@ struct ExtractSliceOpInterface SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); SmallVector mixedSizes = extractSliceOp.getMixedSizes(); SmallVector mixedStrides = extractSliceOp.getMixedStrides(); - return cast(memref::SubViewOp::inferRankReducedResultType( + return memref::SubViewOp::inferRankReducedResultType( extractSliceOp.getType().getShape(), llvm::cast(*srcMemrefType), mixedOffsets, mixedSizes, - mixedStrides)); + mixedStrides); } }; @@ -692,10 +692,10 @@ struct InsertSliceOpInterface // Take a subview of the destination buffer. auto dstMemrefType = cast(dstMemref->getType()); - auto subviewMemRefType = - cast(memref::SubViewOp::inferRankReducedResultType( + MemRefType subviewMemRefType = + memref::SubViewOp::inferRankReducedResultType( insertSliceOp.getSourceType().getShape(), dstMemrefType, - mixedOffsets, mixedSizes, mixedStrides)); + mixedOffsets, mixedSizes, mixedStrides); Value subView = rewriter.create( loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, mixedStrides); @@ -960,12 +960,12 @@ struct ParallelInsertSliceOpInterface // Take a subview of the destination buffer. auto destBufferType = cast(destBuffer->getType()); - auto subviewMemRefType = - cast(memref::SubViewOp::inferRankReducedResultType( + MemRefType subviewMemRefType = + memref::SubViewOp::inferRankReducedResultType( parallelInsertSliceOp.getSourceType().getShape(), destBufferType, parallelInsertSliceOp.getMixedOffsets(), parallelInsertSliceOp.getMixedSizes(), - parallelInsertSliceOp.getMixedStrides())); + parallelInsertSliceOp.getMixedStrides()); Value subview = rewriter.create( parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer, parallelInsertSliceOp.getMixedOffsets(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 5871d6dd5b3e6..f13e54901f690 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -265,9 +265,9 @@ static MemRefType dropUnitDims(MemRefType inputType, ArrayRef sizes, ArrayRef strides) { auto targetShape = getReducedShape(sizes); - Type rankReducedType = memref::SubViewOp::inferRankReducedResultType( + MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType( targetShape, inputType, offsets, sizes, strides); - return cast(rankReducedType).canonicalizeStridedLayout(); + return rankReducedType.canonicalizeStridedLayout(); } /// Creates a rank-reducing memref.subview op that drops unit dims from its diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 47fca8e72b573..dc46ed17a374d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1326,10 +1326,9 @@ class DropInnerMostUnitDimsTransferRead rewriter.getIndexAttr(0)); SmallVector strides(srcType.getRank(), rewriter.getIndexAttr(1)); - auto resultMemrefType = - cast(memref::SubViewOp::inferRankReducedResultType( - srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, - strides)); + MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType( + srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, + strides); ArrayAttr inBoundsAttr = rewriter.getArrayAttr( readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); Value rankedReducedView = rewriter.create( @@ -1417,10 +1416,9 @@ class DropInnerMostUnitDimsTransferWrite rewriter.getIndexAttr(0)); SmallVector strides(srcType.getRank(), rewriter.getIndexAttr(1)); - auto resultMemrefType = - cast(memref::SubViewOp::inferRankReducedResultType( - srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, - strides)); + MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType( + srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, + strides); ArrayAttr inBoundsAttr = rewriter.getArrayAttr( writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));