Skip to content

Commit 8ae2e8d

Browse files
committed
[mlir][memref]: Add Check and Negative Test
Add a missing check and negative test. Signed-off-by: Jack Frankland <[email protected]>
1 parent 41d948a commit 8ae2e8d

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,17 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
377377
if (!op.getPermutationMap().isMinorIdentity())
378378
return failure();
379379

380+
// We only support the case where the source of the expand shape has
381+
// rank greater than or equal to the vector rank.
382+
const int64_t sourceRank = sourceIndices.size();
383+
const int64_t vectorRank = op.getVectorType().getRank();
384+
if (sourceRank < vectorRank)
385+
return failure();
386+
380387
// We need to construct a new minor identity map since we will have lost
381388
// some dimensions in folding away the expand shape.
382-
auto minorIdMap = AffineMap::getMinorIdentityMap(
383-
sourceIndices.size(), op.getVectorType().getRank(),
384-
op.getContext());
389+
auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
390+
op.getContext());
385391

386392
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
387393
op, op.getVectorType(), expandShapeOp.getViewSource(),

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,21 @@ func.func @fold_vector_transfer_read_with_perm_map(
10261026

10271027
// -----
10281028

1029+
func.func @fold_vector_transfer_read_rank_mismatch(
1030+
%arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
1031+
%c0 = arith.constant 0 : index
1032+
%pad = ub.poison : f32
1033+
%0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32>
1034+
%1 = vector.transfer_read %0[%arg1, %c0, %c0], %pad {in_bounds = [true, true]} : memref<2x4x4xf32>, vector<4x4xf32>
1035+
return %1 : vector<4x4xf32>
1036+
}
1037+
1038+
// CHECK-LABEL: func @fold_vector_transfer_read_rank_mismatch
1039+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
1040+
// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32>
1041+
1042+
// -----
1043+
10291044
func.func @fold_vector_load_collapse_shape(
10301045
%arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
10311046
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>

0 commit comments

Comments
 (0)