@@ -2697,10 +2697,10 @@ void SubViewOp::getAsmResultNames(
26972697// / A subview result type can be fully inferred from the source type and the
26982698// / static representation of offsets, sizes and strides. Special sentinels
26992699// / encode the dynamic case.
2700- Type SubViewOp::inferResultType (MemRefType sourceMemRefType,
2701- ArrayRef<int64_t > staticOffsets,
2702- ArrayRef<int64_t > staticSizes,
2703- ArrayRef<int64_t > staticStrides) {
2700+ MemRefType SubViewOp::inferResultType (MemRefType sourceMemRefType,
2701+ ArrayRef<int64_t > staticOffsets,
2702+ ArrayRef<int64_t > staticSizes,
2703+ ArrayRef<int64_t > staticStrides) {
27042704 unsigned rank = sourceMemRefType.getRank ();
27052705 (void )rank;
27062706 assert (staticOffsets.size () == rank && " staticOffsets length mismatch" );
@@ -2739,10 +2739,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
27392739 sourceMemRefType.getMemorySpace ());
27402740}
27412741
2742- Type SubViewOp::inferResultType (MemRefType sourceMemRefType,
2743- ArrayRef<OpFoldResult> offsets,
2744- ArrayRef<OpFoldResult> sizes,
2745- ArrayRef<OpFoldResult> strides) {
2742+ MemRefType SubViewOp::inferResultType (MemRefType sourceMemRefType,
2743+ ArrayRef<OpFoldResult> offsets,
2744+ ArrayRef<OpFoldResult> sizes,
2745+ ArrayRef<OpFoldResult> strides) {
27462746 SmallVector<int64_t > staticOffsets, staticSizes, staticStrides;
27472747 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
27482748 dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
@@ -2758,13 +2758,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
27582758 staticSizes, staticStrides);
27592759}
27602760
2761- Type SubViewOp::inferRankReducedResultType (ArrayRef<int64_t > resultShape,
2762- MemRefType sourceRankedTensorType,
2763- ArrayRef<int64_t > offsets,
2764- ArrayRef<int64_t > sizes,
2765- ArrayRef<int64_t > strides) {
2766- auto inferredType = llvm::cast<MemRefType>(
2767- inferResultType (sourceRankedTensorType, offsets, sizes, strides));
2761+ MemRefType SubViewOp::inferRankReducedResultType (
2762+ ArrayRef<int64_t > resultShape, MemRefType sourceRankedTensorType,
2763+ ArrayRef<int64_t > offsets, ArrayRef<int64_t > sizes,
2764+ ArrayRef<int64_t > strides) {
2765+ MemRefType inferredType =
2766+ inferResultType (sourceRankedTensorType, offsets, sizes, strides);
27682767 assert (inferredType.getRank () >= static_cast <int64_t >(resultShape.size ()) &&
27692768 " expected " );
27702769 if (inferredType.getRank () == static_cast <int64_t >(resultShape.size ()))
@@ -2790,11 +2789,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
27902789 inferredType.getMemorySpace ());
27912790}
27922791
2793- Type SubViewOp::inferRankReducedResultType (ArrayRef<int64_t > resultShape,
2794- MemRefType sourceRankedTensorType,
2795- ArrayRef<OpFoldResult> offsets,
2796- ArrayRef<OpFoldResult> sizes,
2797- ArrayRef<OpFoldResult> strides) {
2792+ MemRefType SubViewOp::inferRankReducedResultType (
2793+ ArrayRef<int64_t > resultShape, MemRefType sourceRankedTensorType,
2794+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2795+ ArrayRef<OpFoldResult> strides) {
27982796 SmallVector<int64_t > staticOffsets, staticSizes, staticStrides;
27992797 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
28002798 dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
@@ -2821,8 +2819,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
28212819 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType ());
28222820 // Structuring implementation this way avoids duplication between builders.
28232821 if (!resultType) {
2824- resultType = llvm::cast<MemRefType>( SubViewOp::inferResultType (
2825- sourceMemRefType, staticOffsets, staticSizes, staticStrides) );
2822+ resultType = SubViewOp::inferResultType (sourceMemRefType, staticOffsets,
2823+ staticSizes, staticStrides);
28262824 }
28272825 result.addAttributes (attrs);
28282826 build (b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2987,8 +2985,8 @@ LogicalResult SubViewOp::verify() {
29872985
29882986 // Compute the expected result type, assuming that there are no rank
29892987 // reductions.
2990- auto expectedType = cast<MemRefType>( SubViewOp::inferResultType (
2991- baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ())) ;
2988+ MemRefType expectedType = SubViewOp::inferResultType (
2989+ baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ());
29922990
29932991 // Verify all properties of a shaped type: rank, element type and dimension
29942992 // sizes. This takes into account potential rank reductions.
@@ -3070,8 +3068,8 @@ static MemRefType getCanonicalSubViewResultType(
30703068 MemRefType currentResultType, MemRefType currentSourceType,
30713069 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
30723070 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3073- auto nonRankReducedType = llvm::cast<MemRefType>( SubViewOp::inferResultType (
3074- sourceType, mixedOffsets, mixedSizes, mixedStrides)) ;
3071+ MemRefType nonRankReducedType = SubViewOp::inferResultType (
3072+ sourceType, mixedOffsets, mixedSizes, mixedStrides);
30753073 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask (
30763074 currentSourceType, currentResultType, mixedSizes);
30773075 if (failed (unusedDims))
@@ -3105,9 +3103,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
31053103 SmallVector<OpFoldResult> offsets (rank, b.getIndexAttr (0 ));
31063104 SmallVector<OpFoldResult> sizes = getMixedSizes (b, loc, memref);
31073105 SmallVector<OpFoldResult> strides (rank, b.getIndexAttr (1 ));
3108- auto targetType =
3109- llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType (
3110- targetShape, memrefType, offsets, sizes, strides));
3106+ MemRefType targetType = SubViewOp::inferRankReducedResultType (
3107+ targetShape, memrefType, offsets, sizes, strides);
31113108 return b.createOrFold <memref::SubViewOp>(loc, targetType, memref, offsets,
31123109 sizes, strides);
31133110}
@@ -3251,11 +3248,11 @@ struct SubViewReturnTypeCanonicalizer {
32513248 ArrayRef<OpFoldResult> mixedSizes,
32523249 ArrayRef<OpFoldResult> mixedStrides) {
32533250 // Infer a memref type without taking into account any rank reductions.
3254- auto resTy = SubViewOp::inferResultType (op. getSourceType (), mixedOffsets,
3255- mixedSizes, mixedStrides);
3251+ MemRefType resTy = SubViewOp::inferResultType (
3252+ op. getSourceType (), mixedOffsets, mixedSizes, mixedStrides);
32563253 if (!resTy)
32573254 return {};
3258- MemRefType nonReducedType = cast<MemRefType>( resTy) ;
3255+ MemRefType nonReducedType = resTy;
32593256
32603257 // Directly return the non-rank reduced type if there are no dropped dims.
32613258 llvm::SmallBitVector droppedDims = op.getDroppedDims ();
0 commit comments