@@ -2823,19 +2823,30 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
       }));
 }
 
+enum SubViewVerificationResult {
+  Success,
+  RankTooLarge,
+  SizeMismatch,
+  StrideMismatch,
+  ElemTypeMismatch,
+  MemSpaceMismatch,
+  AffineMapMismatch
+};
+
 /// Checks if `original` Type type can be rank reduced to `reduced` type.
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
-static bool isRankReducedType(Type originalType, Type reducedType) {
+static SubViewVerificationResult isRankReducedType(Type originalType,
+                                                   Type reducedType) {
   if (originalType == reducedType)
-    return true;
+    return SubViewVerificationResult::Success;
   if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
-    return true;
+    return SubViewVerificationResult::Success;
   if (originalType.isa<RankedTensorType>() &&
       !reducedType.isa<RankedTensorType>())
-    return true;
+    return SubViewVerificationResult::Success;
   if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
-    return true;
+    return SubViewVerificationResult::Success;
 
   ShapedType originalShapedType = originalType.cast<ShapedType>();
   ShapedType reducedShapedType = reducedType.cast<ShapedType>();
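The doc comment above describes the rule as a subsequence match where every skipped dimension must have size 1. A minimal standalone sketch of just the shape part of that check (plain C++ with illustrative names, not the MLIR helper itself):

    #include <cstdint>
    #include <vector>

    // Returns true if `reduced` can be formed from `original` by dropping only
    // size-1 dimensions while keeping the remaining sizes in their original order.
    // Example: {1, 16, 1, 4} -> {16, 4} is accepted; {1, 16, 1, 4} -> {16} is not.
    static bool isRankReducedShape(const std::vector<int64_t> &original,
                                   const std::vector<int64_t> &reduced) {
      size_t reducedIdx = 0;
      for (int64_t size : original) {
        if (reducedIdx < reduced.size() && size == reduced[reducedIdx])
          ++reducedIdx;  // Greedily match this dimension.
        else if (size != 1)
          return false;  // Only unit dimensions may be dropped.
      }
      return reducedIdx == reduced.size(); // Every reduced size must be matched.
    }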
@@ -2846,7 +2857,7 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
   unsigned originalRank = originalShape.size(),
            reducedRank = reducedShape.size();
   if (reducedRank > originalRank)
-    return false;
+    return SubViewVerificationResult::RankTooLarge;
 
   unsigned reducedIdx = 0;
   SmallVector<bool, 4> keepMask(originalRank);
@@ -2858,41 +2869,78 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
       reducedIdx++;
     // 1 is the only non-matching allowed.
     else if (originalShape[originalIdx] != 1)
-      return false;
+      return SubViewVerificationResult::SizeMismatch;
   }
   // Must match the reduced rank.
   if (reducedIdx != reducedRank)
-    return false;
+    return SubViewVerificationResult::SizeMismatch;
 
   // We are done for the tensor case.
   if (originalType.isa<RankedTensorType>())
-    return true;
+    return SubViewVerificationResult::Success;
 
   // Strided layout logic is relevant for MemRefType only.
   MemRefType original = originalType.cast<MemRefType>();
   MemRefType reduced = reducedType.cast<MemRefType>();
   MLIRContext *c = original.getContext();
-  int64_t originalOffset, symCounter = 0, dimCounter = 0;
-  SmallVector<int64_t, 4> originalStrides;
+  int64_t originalOffset, reducedOffset;
+  SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
   getStridesAndOffset(original, originalStrides, originalOffset);
-  auto getSymbolOrConstant = [&](int64_t offset) {
-    return offset == ShapedType::kDynamicStrideOrOffset
-               ? getAffineSymbolExpr(symCounter++, c)
-               : getAffineConstantExpr(offset, c);
-  };
-
-  AffineExpr expr = getSymbolOrConstant(originalOffset);
-  for (unsigned i = 0, e = originalStrides.size(); i < e; i++) {
-    if (keepMask[i])
-      expr = expr + getSymbolOrConstant(originalStrides[i]) *
-                        getAffineDimExpr(dimCounter++, c);
+  getStridesAndOffset(reduced, reducedStrides, reducedOffset);
+
+  // Filter strides based on the mask and check that they are the same
+  // as reduced ones.
+  reducedIdx = 0;
+  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
+    if (keepMask[originalIdx]) {
+      if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
+        return SubViewVerificationResult::StrideMismatch;
+      keepStrides.push_back(originalStrides[originalIdx]);
+    }
   }
 
-  auto reducedMap = AffineMap::get(dimCounter, symCounter, expr, c);
-  return original.getElementType() == reduced.getElementType() &&
-         original.getMemorySpace() == reduced.getMemorySpace() &&
-         (reduced.getAffineMaps().empty() ||
-          reducedMap == reduced.getAffineMaps().front());
+  if (original.getElementType() != reduced.getElementType())
+    return SubViewVerificationResult::ElemTypeMismatch;
+
+  if (original.getMemorySpace() != reduced.getMemorySpace())
+    return SubViewVerificationResult::MemSpaceMismatch;
+
+  auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
+  if (!reduced.getAffineMaps().empty() &&
+      reducedMap != reduced.getAffineMaps().front())
+    return SubViewVerificationResult::AffineMapMismatch;
+
+  return SubViewVerificationResult::Success;
+}
+
+template <typename OpTy>
+static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
+                                            OpTy op, Type expectedType) {
+  auto memrefType = expectedType.cast<ShapedType>();
+  switch (result) {
+  case SubViewVerificationResult::Success:
+    return success();
+  case SubViewVerificationResult::RankTooLarge:
+    return op.emitError("expected result rank to be smaller or equal to ")
+           << "the source rank.";
+  case SubViewVerificationResult::SizeMismatch:
+    return op.emitError("expected result type to be ")
+           << expectedType
+           << " or a rank-reduced version. (mismatch of result sizes)";
+  case SubViewVerificationResult::StrideMismatch:
+    return op.emitError("expected result type to be ")
+           << expectedType
+           << " or a rank-reduced version. (mismatch of result strides)";
+  case SubViewVerificationResult::ElemTypeMismatch:
+    return op.emitError("expected result element type to be ")
+           << memrefType.getElementType();
+  case SubViewVerificationResult::MemSpaceMismatch:
+    return op.emitError("expected result and source memory spaces to match.");
+  case SubViewVerificationResult::AffineMapMismatch:
+    return op.emitError("expected result type to be ")
+           << expectedType
+           << " or a rank-reduced version. (mismatch of result affine map)";
+  }
+}
 
 template <typename OpType>
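The net effect of this hunk is to separate classification from reporting: isRankReducedType now returns a SubViewVerificationResult instead of a bool, and produceSubViewErrorMsg is the single place that turns each enum value into a diagnostic, so the SubViewOp and SubTensorOp verifiers below can share identical error wording. A reduced sketch of that pattern in isolation (standalone C++; FakeOp and the message strings are illustrative stand-ins, not the MLIR API):

    #include <cstdio>

    // Stand-in for an op handle that can emit diagnostics (illustrative only).
    struct FakeOp {
      bool emitError(const char *msg) const {
        std::fprintf(stderr, "error: %s\n", msg);
        return false; // Report failure after emitting the diagnostic.
      }
    };

    enum class VerificationResult { Success, SizeMismatch, StrideMismatch };

    // Checker: classifies the failure but never prints anything itself.
    static VerificationResult classify(bool sizesOk, bool stridesOk) {
      if (!sizesOk)
        return VerificationResult::SizeMismatch;
      if (!stridesOk)
        return VerificationResult::StrideMismatch;
      return VerificationResult::Success;
    }

    // Reporter: the one place that maps each classification to a message, so
    // every caller produces the same wording for the same failure.
    static bool report(VerificationResult result, const FakeOp &op) {
      switch (result) {
      case VerificationResult::Success:
        return true;
      case VerificationResult::SizeMismatch:
        return op.emitError("mismatch of result sizes");
      case VerificationResult::StrideMismatch:
        return op.emitError("mismatch of result strides");
      }
      return false;
    }

    // Usage mirrors the verifiers below: classify, then report.
    static bool verifyExample(const FakeOp &op, bool sizesOk, bool stridesOk) {
      return report(classify(sizesOk, stridesOk), op);
    }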
@@ -2937,11 +2985,9 @@ static LogicalResult verify(SubViewOp op) {
       baseType, extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
-  if (!isRankReducedType(expectedType, subViewType))
-    return op.emitError("expected result type to be ")
-           << expectedType << " or a rank-reduced version.";
 
-  return success();
+  auto result = isRankReducedType(expectedType, subViewType);
+  return produceSubViewErrorMsg(result, op, expectedType);
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
@@ -3352,11 +3398,8 @@ static LogicalResult verify(SubTensorOp op) {
       op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
-  if (!isRankReducedType(expectedType, op.getType()))
-    return op.emitError("expected result type to be ")
-           << expectedType << " or a rank-reduced version.";
-
-  return success();
+  auto result = isRankReducedType(expectedType, op.getType());
+  return produceSubViewErrorMsg(result, op, expectedType);
 }
 
 void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,