Skip to content

Commit e7cf723

Browse files
committed
[mlir] Added strides check to rank reducing subview verification
Added a missing strides check to the verification method of the rank-reducing subview, which enforces stride specification for the resulting type. Differential Revision: https://reviews.llvm.org/D88879
1 parent 7c88d13 commit e7cf723

File tree

3 files changed

+118
-44
lines changed

3 files changed

+118
-44
lines changed

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,19 +2823,30 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
28232823
}));
28242824
}
28252825

2826+
enum SubViewVerificationResult {
2827+
Success,
2828+
RankTooLarge,
2829+
SizeMismatch,
2830+
StrideMismatch,
2831+
ElemTypeMismatch,
2832+
MemSpaceMismatch,
2833+
AffineMapMismatch
2834+
};
2835+
28262836
/// Checks if `original` Type type can be rank reduced to `reduced` type.
28272837
/// This function is slight variant of `is subsequence` algorithm where
28282838
/// not matching dimension must be 1.
2829-
static bool isRankReducedType(Type originalType, Type reducedType) {
2839+
static SubViewVerificationResult isRankReducedType(Type originalType,
2840+
Type reducedType) {
28302841
if (originalType == reducedType)
2831-
return true;
2842+
return SubViewVerificationResult::Success;
28322843
if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
2833-
return true;
2844+
return SubViewVerificationResult::Success;
28342845
if (originalType.isa<RankedTensorType>() &&
28352846
!reducedType.isa<RankedTensorType>())
2836-
return true;
2847+
return SubViewVerificationResult::Success;
28372848
if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
2838-
return true;
2849+
return SubViewVerificationResult::Success;
28392850

28402851
ShapedType originalShapedType = originalType.cast<ShapedType>();
28412852
ShapedType reducedShapedType = reducedType.cast<ShapedType>();
@@ -2846,7 +2857,7 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
28462857
unsigned originalRank = originalShape.size(),
28472858
reducedRank = reducedShape.size();
28482859
if (reducedRank > originalRank)
2849-
return false;
2860+
return SubViewVerificationResult::RankTooLarge;
28502861

28512862
unsigned reducedIdx = 0;
28522863
SmallVector<bool, 4> keepMask(originalRank);
@@ -2858,41 +2869,78 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
28582869
reducedIdx++;
28592870
// 1 is the only non-matching allowed.
28602871
else if (originalShape[originalIdx] != 1)
2861-
return false;
2872+
return SubViewVerificationResult::SizeMismatch;
28622873
}
28632874
// Must match the reduced rank.
28642875
if (reducedIdx != reducedRank)
2865-
return false;
2876+
return SubViewVerificationResult::SizeMismatch;
28662877

28672878
// We are done for the tensor case.
28682879
if (originalType.isa<RankedTensorType>())
2869-
return true;
2880+
return SubViewVerificationResult::Success;
28702881

28712882
// Strided layout logic is relevant for MemRefType only.
28722883
MemRefType original = originalType.cast<MemRefType>();
28732884
MemRefType reduced = reducedType.cast<MemRefType>();
28742885
MLIRContext *c = original.getContext();
2875-
int64_t originalOffset, symCounter = 0, dimCounter = 0;
2876-
SmallVector<int64_t, 4> originalStrides;
2886+
int64_t originalOffset, reducedOffset;
2887+
SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
28772888
getStridesAndOffset(original, originalStrides, originalOffset);
2878-
auto getSymbolOrConstant = [&](int64_t offset) {
2879-
return offset == ShapedType::kDynamicStrideOrOffset
2880-
? getAffineSymbolExpr(symCounter++, c)
2881-
: getAffineConstantExpr(offset, c);
2882-
};
2883-
2884-
AffineExpr expr = getSymbolOrConstant(originalOffset);
2885-
for (unsigned i = 0, e = originalStrides.size(); i < e; i++) {
2886-
if (keepMask[i])
2887-
expr = expr + getSymbolOrConstant(originalStrides[i]) *
2888-
getAffineDimExpr(dimCounter++, c);
2889+
getStridesAndOffset(reduced, reducedStrides, reducedOffset);
2890+
2891+
// Filter strides based on the mask and check that they are the same
2892+
// as reduced ones.
2893+
reducedIdx = 0;
2894+
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
2895+
if (keepMask[originalIdx]) {
2896+
if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
2897+
return SubViewVerificationResult::StrideMismatch;
2898+
keepStrides.push_back(originalStrides[originalIdx]);
2899+
}
28892900
}
28902901

2891-
auto reducedMap = AffineMap::get(dimCounter, symCounter, expr, c);
2892-
return original.getElementType() == reduced.getElementType() &&
2893-
original.getMemorySpace() == reduced.getMemorySpace() &&
2894-
(reduced.getAffineMaps().empty() ||
2895-
reducedMap == reduced.getAffineMaps().front());
2902+
if (original.getElementType() != reduced.getElementType())
2903+
return SubViewVerificationResult::ElemTypeMismatch;
2904+
2905+
if (original.getMemorySpace() != reduced.getMemorySpace())
2906+
return SubViewVerificationResult::MemSpaceMismatch;
2907+
2908+
auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
2909+
if (!reduced.getAffineMaps().empty() &&
2910+
reducedMap != reduced.getAffineMaps().front())
2911+
return SubViewVerificationResult::AffineMapMismatch;
2912+
2913+
return SubViewVerificationResult::Success;
2914+
}
2915+
2916+
template <typename OpTy>
2917+
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
2918+
OpTy op, Type expectedType) {
2919+
auto memrefType = expectedType.cast<ShapedType>();
2920+
switch (result) {
2921+
case SubViewVerificationResult::Success:
2922+
return success();
2923+
case SubViewVerificationResult::RankTooLarge:
2924+
return op.emitError("expected result rank to be smaller or equal to ")
2925+
<< "the source rank.";
2926+
case SubViewVerificationResult::SizeMismatch:
2927+
return op.emitError("expected result type to be ")
2928+
<< expectedType
2929+
<< " or a rank-reduced version. (mismatch of result sizes)";
2930+
case SubViewVerificationResult::StrideMismatch:
2931+
return op.emitError("expected result type to be ")
2932+
<< expectedType
2933+
<< " or a rank-reduced version. (mismatch of result strides)";
2934+
case SubViewVerificationResult::ElemTypeMismatch:
2935+
return op.emitError("expected result element type to be ")
2936+
<< memrefType.getElementType();
2937+
case SubViewVerificationResult::MemSpaceMismatch:
2938+
return op.emitError("expected result and source memory spaces to match.");
2939+
case SubViewVerificationResult::AffineMapMismatch:
2940+
return op.emitError("expected result type to be ")
2941+
<< expectedType
2942+
<< " or a rank-reduced version. (mismatch of result affine map)";
2943+
}
28962944
}
28972945

28982946
template <typename OpType>
@@ -2937,11 +2985,9 @@ static LogicalResult verify(SubViewOp op) {
29372985
baseType, extractFromI64ArrayAttr(op.static_offsets()),
29382986
extractFromI64ArrayAttr(op.static_sizes()),
29392987
extractFromI64ArrayAttr(op.static_strides()));
2940-
if (!isRankReducedType(expectedType, subViewType))
2941-
return op.emitError("expected result type to be ")
2942-
<< expectedType << " or a rank-reduced version.";
29432988

2944-
return success();
2989+
auto result = isRankReducedType(expectedType, subViewType);
2990+
return produceSubViewErrorMsg(result, op, expectedType);
29452991
}
29462992

29472993
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
@@ -3352,11 +3398,8 @@ static LogicalResult verify(SubTensorOp op) {
33523398
op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
33533399
extractFromI64ArrayAttr(op.static_sizes()),
33543400
extractFromI64ArrayAttr(op.static_strides()));
3355-
if (!isRankReducedType(expectedType, op.getType()))
3356-
return op.emitError("expected result type to be ")
3357-
<< expectedType << " or a rank-reduced version.";
3358-
3359-
return success();
3401+
auto result = isRankReducedType(expectedType, op.getType());
3402+
return produceSubViewErrorMsg(result, op, expectedType);
33603403
}
33613404

33623405
void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,

mlir/test/IR/core-ops.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
// CHECK-DAG: #[[$SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)>
2222
// CHECK-DAG: #[[$SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 36 + d1 * 36 + d2 * 4 + d3 * 4 + d4)>
2323
// CHECK-DAG: #[[$SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)>
24+
// CHECK-DAG: #[[$SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
2425

2526
// CHECK-LABEL: func @func_with_ops
2627
// CHECK-SAME: %[[ARG:.*]]: f32
@@ -811,11 +812,11 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
811812

812813
%15 = alloc(%arg1, %arg2)[%c0, %c1, %arg1, %arg0, %arg0, %arg2, %arg2] : memref<1x?x5x1x?x1xf32, affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * d0 + s2 * d1 + s3 * d2 + s4 * d3 + s5 * d4 + s6 * d5)>>
813814
// CHECK: subview %15[0, 0, 0, 0, 0, 0] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1] :
814-
// CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?xf32>
815-
%16 = subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?xf32>
815+
// CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?xf32, #[[$BASE_MAP3]]>
816+
%16 = subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?xf32, offset: ?, strides: [?, ?, ?]>
816817
// CHECK: subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1] :
817-
// CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?x1xf32>
818-
%17 = subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?x1xf32>
818+
// CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?x1xf32, #[[$SUBVIEW_MAP8]]>
819+
%17 = subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?x1xf32, offset: ?, strides: [?, ?, ?, ?]>
819820

820821
%18 = alloc() : memref<1x8xf32>
821822
// CHECK: subview %18[0, 0] [1, 8] [1, 1] : memref<1x8xf32> to memref<8xf32>

mlir/test/IR/invalid-ops.mlir

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
10111011

10121012
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
10131013
%0 = alloc() : memref<8x16x4xf32>
1014-
// expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>'}}
1014+
// expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result strides)}}
10151015
%1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
10161016
: memref<8x16x4xf32> to
10171017
memref<?x?x?xf32, offset: ?, strides: [64, 4, 1]>
@@ -1020,16 +1020,46 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
10201020

10211021
// -----
10221022

1023+
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
1024+
%0 = alloc() : memref<8x16x4xf32>
1025+
// expected-error@+1 {{expected result element type to be 'f32'}}
1026+
%1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
1027+
: memref<8x16x4xf32> to
1028+
memref<8x16x4xi32>
1029+
return
1030+
}
1031+
1032+
// -----
1033+
1034+
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
1035+
%0 = alloc() : memref<8x16x4xf32>
1036+
// expected-error@+1 {{expected result rank to be smaller or equal to the source rank.}}
1037+
%1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
1038+
: memref<8x16x4xf32> to
1039+
memref<8x16x4x3xi32>
1040+
return
1041+
}
1042+
1043+
// -----
1044+
10231045
func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
10241046
%0 = alloc() : memref<8x16x4xf32>
1025-
// expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>'}}
1047+
// expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}}
10261048
%1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
10271049
: memref<8x16x4xf32> to memref<16x4xf32>
10281050
return
10291051
}
10301052

10311053
// -----
10321054

1055+
func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
1056+
// expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}}
1057+
%0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
1058+
return
1059+
}
1060+
1061+
// -----
1062+
10331063
func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
10341064
// expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
10351065
%0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>
@@ -1259,7 +1289,7 @@ func @imaginary_part_from_incompatible_complex_type(%cplx: complex<f64>) {
12591289
// -----
12601290

12611291
func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
1262-
// expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>'}}
1292+
// expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}}
12631293
%0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
12641294
: tensor<8x16x4xf32> to tensor<?x4x4xf32>
12651295

@@ -1269,7 +1299,7 @@ func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
12691299
// -----
12701300

12711301
func @subtensor_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
1272-
// expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>'}}
1302+
// expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>' or a rank-reduced version. (mismatch of result sizes)}}
12731303
%0 = subtensor %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
12741304
: tensor<8x16x4xf32> to tensor<4x4x4xf32>
12751305

0 commit comments

Comments
 (0)