@@ -2702,10 +2702,10 @@ void SubViewOp::getAsmResultNames(
27022702// / A subview result type can be fully inferred from the source type and the
27032703// / static representation of offsets, sizes and strides. Special sentinels
27042704// / encode the dynamic case.
2705- Type SubViewOp::inferResultType (MemRefType sourceMemRefType,
2706- ArrayRef<int64_t > staticOffsets,
2707- ArrayRef<int64_t > staticSizes,
2708- ArrayRef<int64_t > staticStrides) {
2705+ MemRefType SubViewOp::inferResultType (MemRefType sourceMemRefType,
2706+ ArrayRef<int64_t > staticOffsets,
2707+ ArrayRef<int64_t > staticSizes,
2708+ ArrayRef<int64_t > staticStrides) {
27092709 unsigned rank = sourceMemRefType.getRank ();
27102710 (void )rank;
27112711 assert (staticOffsets.size () == rank && " staticOffsets length mismatch" );
@@ -2744,10 +2744,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
27442744 sourceMemRefType.getMemorySpace ());
27452745}
27462746
2747- Type SubViewOp::inferResultType (MemRefType sourceMemRefType,
2748- ArrayRef<OpFoldResult> offsets,
2749- ArrayRef<OpFoldResult> sizes,
2750- ArrayRef<OpFoldResult> strides) {
2747+ MemRefType SubViewOp::inferResultType (MemRefType sourceMemRefType,
2748+ ArrayRef<OpFoldResult> offsets,
2749+ ArrayRef<OpFoldResult> sizes,
2750+ ArrayRef<OpFoldResult> strides) {
27512751 SmallVector<int64_t > staticOffsets, staticSizes, staticStrides;
27522752 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
27532753 dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
@@ -2763,13 +2763,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
27632763 staticSizes, staticStrides);
27642764}
27652765
2766- Type SubViewOp::inferRankReducedResultType (ArrayRef<int64_t > resultShape,
2767- MemRefType sourceRankedTensorType,
2768- ArrayRef<int64_t > offsets,
2769- ArrayRef<int64_t > sizes,
2770- ArrayRef<int64_t > strides) {
2771- auto inferredType = llvm::cast<MemRefType>(
2772- inferResultType (sourceRankedTensorType, offsets, sizes, strides));
2766+ MemRefType SubViewOp::inferRankReducedResultType (
2767+ ArrayRef<int64_t > resultShape, MemRefType sourceRankedTensorType,
2768+ ArrayRef<int64_t > offsets, ArrayRef<int64_t > sizes,
2769+ ArrayRef<int64_t > strides) {
2770+ MemRefType inferredType =
2771+ inferResultType (sourceRankedTensorType, offsets, sizes, strides);
27732772 assert (inferredType.getRank () >= static_cast <int64_t >(resultShape.size ()) &&
27742773 " expected " );
27752774 if (inferredType.getRank () == static_cast <int64_t >(resultShape.size ()))
@@ -2795,11 +2794,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
27952794 inferredType.getMemorySpace ());
27962795}
27972796
2798- Type SubViewOp::inferRankReducedResultType (ArrayRef<int64_t > resultShape,
2799- MemRefType sourceRankedTensorType,
2800- ArrayRef<OpFoldResult> offsets,
2801- ArrayRef<OpFoldResult> sizes,
2802- ArrayRef<OpFoldResult> strides) {
2797+ MemRefType SubViewOp::inferRankReducedResultType (
2798+ ArrayRef<int64_t > resultShape, MemRefType sourceRankedTensorType,
2799+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2800+ ArrayRef<OpFoldResult> strides) {
28032801 SmallVector<int64_t > staticOffsets, staticSizes, staticStrides;
28042802 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
28052803 dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
@@ -2826,8 +2824,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
28262824 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType ());
28272825 // Structuring implementation this way avoids duplication between builders.
28282826 if (!resultType) {
2829- resultType = llvm::cast<MemRefType>( SubViewOp::inferResultType (
2830- sourceMemRefType, staticOffsets, staticSizes, staticStrides) );
2827+ resultType = SubViewOp::inferResultType (sourceMemRefType, staticOffsets,
2828+ staticSizes, staticStrides);
28312829 }
28322830 result.addAttributes (attrs);
28332831 build (b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2992,8 +2990,8 @@ LogicalResult SubViewOp::verify() {
29922990
29932991 // Compute the expected result type, assuming that there are no rank
29942992 // reductions.
2995- auto expectedType = cast<MemRefType>( SubViewOp::inferResultType (
2996- baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ())) ;
2993+ MemRefType expectedType = SubViewOp::inferResultType (
2994+ baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ());
29972995
29982996 // Verify all properties of a shaped type: rank, element type and dimension
29992997 // sizes. This takes into account potential rank reductions.
@@ -3075,8 +3073,8 @@ static MemRefType getCanonicalSubViewResultType(
30753073 MemRefType currentResultType, MemRefType currentSourceType,
30763074 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
30773075 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3078- auto nonRankReducedType = llvm::cast<MemRefType>( SubViewOp::inferResultType (
3079- sourceType, mixedOffsets, mixedSizes, mixedStrides)) ;
3076+ MemRefType nonRankReducedType = SubViewOp::inferResultType (
3077+ sourceType, mixedOffsets, mixedSizes, mixedStrides);
30803078 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask (
30813079 currentSourceType, currentResultType, mixedSizes);
30823080 if (failed (unusedDims))
@@ -3110,9 +3108,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
31103108 SmallVector<OpFoldResult> offsets (rank, b.getIndexAttr (0 ));
31113109 SmallVector<OpFoldResult> sizes = getMixedSizes (b, loc, memref);
31123110 SmallVector<OpFoldResult> strides (rank, b.getIndexAttr (1 ));
3113- auto targetType =
3114- llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType (
3115- targetShape, memrefType, offsets, sizes, strides));
3111+ MemRefType targetType = SubViewOp::inferRankReducedResultType (
3112+ targetShape, memrefType, offsets, sizes, strides);
31163113 return b.createOrFold <memref::SubViewOp>(loc, targetType, memref, offsets,
31173114 sizes, strides);
31183115}
@@ -3256,11 +3253,11 @@ struct SubViewReturnTypeCanonicalizer {
32563253 ArrayRef<OpFoldResult> mixedSizes,
32573254 ArrayRef<OpFoldResult> mixedStrides) {
32583255 // Infer a memref type without taking into account any rank reductions.
3259- auto resTy = SubViewOp::inferResultType (op. getSourceType (), mixedOffsets,
3260- mixedSizes, mixedStrides);
3256+ MemRefType resTy = SubViewOp::inferResultType (
3257+ op. getSourceType (), mixedOffsets, mixedSizes, mixedStrides);
32613258 if (!resTy)
32623259 return {};
3263- MemRefType nonReducedType = cast<MemRefType>( resTy) ;
3260+ MemRefType nonReducedType = resTy;
32643261
32653262 // Directly return the non-rank reduced type if there are no dropped dims.
32663263 llvm::SmallBitVector droppedDims = op.getDroppedDims ();
0 commit comments