
Commit 9fdd0b4

[MLIR][StaticValueUtils] Fold constant SSA values + fix isRankReduced
-- Ops with dynamic/static offsets/sizes/strides that use the `dispatchIndexOpFoldResult` API should be able to fold away any constant SSA values instead of keeping them as dynamic operands. This commit adds that fix.
-- Consequently, it also fixes `isRankReducedType` and `computeRankReductionMask`.
Signed-off-by: Abhishek Varma <[email protected]>
1 parent e25c556 commit 9fdd0b4
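
As a rough illustration of the intended effect (a hypothetical snippet mirroring the `shape-to-standard.mlir` test update below; `%shape` and `%n` stand in for arbitrary SSA values), constant SSA offsets and strides now land in the op's static segments when the slice is built:

func.func @fold_constant_operands(%shape: tensor<?xindex>, %n: index) -> tensor<?xindex> {
  // Previously the constants stayed as dynamic operands of the op:
  //   %c0 = arith.constant 0 : index
  //   %c1 = arith.constant 1 : index
  //   %head = tensor.extract_slice %shape[%c0] [%n] [%c1]
  //       : tensor<?xindex> to tensor<?xindex>
  // With this change they are folded into the static offset/stride segments:
  %head = tensor.extract_slice %shape[0] [%n] [1] : tensor<?xindex> to tensor<?xindex>
  return %head : tensor<?xindex>
}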

16 files changed, +142 −144 lines changed


mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 36 additions & 5 deletions
@@ -2328,7 +2328,22 @@ LogicalResult ExtractSliceOp::verify() {
   // Verify result type against inferred type.
   RankedTensorType expectedType = ExtractSliceOp::inferResultType(
       getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
-  SliceVerificationResult result = isRankReducedType(expectedType, getType());
+  SliceVerificationResult result;
+  if (getSizes().size() != 0) {
+    bool hasNonCstValue = false;
+    for (OpFoldResult size : getSizes()) {
+      std::optional<int64_t> cst = getConstantIntValue(size);
+      if (!cst) {
+        hasNonCstValue = true;
+        break;
+      }
+    }
+    if (hasNonCstValue && llvm::cast<ShapedType>(getType()).hasStaticShape()) {
+      result = SliceVerificationResult::SizeMismatch;
+      return produceSliceErrorMsg(result, *this, expectedType);
+    }
+  }
+  result = isRankReducedType(expectedType, getType());
   return produceSliceErrorMsg(result, *this, expectedType);
 }

@@ -2700,10 +2715,26 @@ static SliceVerificationResult verifyInsertSliceOp(

 /// Verifier for InsertSliceOp.
 LogicalResult InsertSliceOp::verify() {
-  RankedTensorType expectedType;
-  SliceVerificationResult result =
-      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
-                          getStaticSizes(), getStaticStrides(), &expectedType);
+  // insert_slice is the inverse of extract_slice, use the same type
+  // inference.
+  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
+      getType(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
+  SliceVerificationResult result;
+  if (getSizes().size() != 0) {
+    bool hasNonCstValue = false;
+    for (OpFoldResult size : getSizes()) {
+      std::optional<int64_t> cst = getConstantIntValue(size);
+      if (!cst) {
+        hasNonCstValue = true;
+        break;
+      }
+    }
+    if (hasNonCstValue && getSourceType().hasStaticShape()) {
+      result = SliceVerificationResult::SizeMismatch;
+      return produceSliceErrorMsg(result, *this, expectedType);
+    }
+  }
+  result = isRankReducedType(expectedType, getSourceType());
   return produceSliceErrorMsg(result, *this, expectedType);
 }

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 8 additions & 2 deletions
@@ -41,6 +41,7 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {

 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
 /// a) it is an IntegerAttr
+/// b) it is a constant integer value
 /// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
 /// In such dynamic cases, a copy of the `sentinel` value is also pushed to
 /// `staticVec`. This is useful to extract mixed static and dynamic entries that

@@ -54,8 +55,13 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
     staticVec.push_back(apInt.getSExtValue());
     return;
   }
-  dynamicVec.push_back(v);
-  staticVec.push_back(ShapedType::kDynamic);
+  std::optional<int64_t> maybeConstantInt = getConstantIntValue(ofr);
+  if (!maybeConstantInt) {
+    dynamicVec.push_back(v);
+    staticVec.push_back(ShapedType::kDynamic);
+  } else {
+    staticVec.push_back(*maybeConstantInt);
+  }
 }

 std::pair<int64_t, OpFoldResult>

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 5 additions & 5 deletions
@@ -431,8 +431,8 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
     int64_t origSize = originalShape[originalIdx];
     // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
     if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
-        (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
-         ShapedType::isDynamic(origSize))) {
+        (ShapedType::isDynamic(origSize) ||
+         ShapedType::isDynamic(reducedShape[reducedIdx]))) {
       reducedIdx++;
       continue;
     }

@@ -448,7 +448,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
     return std::nullopt;
   }
   // The whole reducedShape must be scanned, otherwise we bail.
-  if (reducedIdx != reducedRank)
+  if (reducedIdx != reducedRank && originalRank != 1)
     return std::nullopt;
   return unusedDims;
 }

@@ -472,8 +472,8 @@ mlir::isRankReducedType(ShapedType originalType,
   if (candidateReducedRank > originalRank)
     return SliceVerificationResult::RankTooLarge;

-  auto optionalUnusedDimsMask =
-      computeRankReductionMask(originalShape, candidateReducedShape);
+  auto optionalUnusedDimsMask = computeRankReductionMask(
+      originalShape, candidateReducedShape, /*matchDynamic=*/true);

   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalUnusedDimsMask)

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 2 additions & 2 deletions
@@ -622,9 +622,9 @@ func.func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>,
 // CHECK-NEXT: %[[ISNEG:.*]] = arith.cmpi slt, %[[INDEX]], %[[C0]] : index
 // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
 // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[HEAD:.*]] = tensor.extract_slice %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+// CHECK-NEXT: %[[HEAD:.*]] = tensor.extract_slice %[[SHAPE]][0] [%[[SELECT]]] [1] : tensor<?xindex> to tensor<?xindex>
 // CHECK-NEXT: %[[TAIL_SIZE:.*]] = arith.subi %[[RANK]], %[[SELECT]] : index
-// CHECK-NEXT: %[[TAIL:.*]] = tensor.extract_slice %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+// CHECK-NEXT: %[[TAIL:.*]] = tensor.extract_slice %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [1] : tensor<?xindex> to tensor<?xindex>
 // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
 %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
 return %head, %tail : tensor<?xindex>, tensor<?xindex>

mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir

Lines changed: 2 additions & 2 deletions
@@ -72,8 +72,8 @@ func.func @scatter_test(%values_in: tensor<3x7x5xi32>, %indices : tensor<3x6xi32
 // CHECK: [[RESULT_1:%.+]] = scf.for [[ITER_VAR_1:%.+]] = [[C_0_0]] to [[C_6]] step [[C_1_0]] iter_args([[ITER_ARG_1:%.+]] = [[ITER_ARG_0]]) -> (tensor<3x7x5xi32>) {
 // CHECK-DAG: [[EXTRACTED:%.+]] = tensor.extract [[INDICES]][[[ITER_VAR_0]], [[ITER_VAR_1]]] : tensor<3x6xi32>
 // CHECK-DAG: [[EXTRACTED_CAST:%.+]] = arith.index_cast [[EXTRACTED]] : i32 to index
-// CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
-// CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
+// CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], 0] [1, 1, 5] [1, 1, 1] : tensor<3x6x5xi32> to tensor<1x1x5xi32>
+// CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], 0] [1, 1, 5] [1, 1, 1] : tensor<1x1x5xi32> into tensor<3x7x5xi32>
 // CHECK: scf.yield [[INSERTED_SLICE]] : tensor<3x7x5xi32>
 // CHECK: }
 // CHECK: scf.yield [[RESULT_1]] : tensor<3x7x5xi32>

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Lines changed: 8 additions & 8 deletions
@@ -466,7 +466,7 @@ func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
 // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
 // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
 // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+// CHECK: tensor.pad %[[ARG0]] low[1, 3] high[2, 4] {
 // CHECK: tensor.yield [[CST]]
 // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
 %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)

@@ -501,7 +501,7 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
 // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
 // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
 // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
-// CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+// CHECK: tensor.pad %[[ARG0]] low[1, 3] high[2, 4] {
 // CHECK: tensor.yield [[CST]]
 // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
 %1 = arith.constant dense<42.0> : tensor<f32>

@@ -519,26 +519,26 @@ func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
 // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
 // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
 // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+// CHECK: tensor.pad %[[ARG0]] low[1, 3] high[2, 4] {
 // CHECK: tensor.yield [[CST]]
 // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
 %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
 return %1 : tensor<?x9xf32>
 }

-func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
+func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<2x9xf32>) {
 %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
 // TODO: Output contains multiple "arith.constant 1 : index".
 // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
 // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
 // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
 // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
 // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+// CHECK: tensor.pad %[[ARG0]] low[-1, 3] high[2, 4] {
 // CHECK: tensor.yield [[CST]]
-// CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
-%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
-return %1 : tensor<?x9xf32>
+// CHECK: } : tensor<1x2xf32> to tensor<2x9xf32>
+%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<2x9xf32>)
+return %1 : tensor<2x9xf32>
 }

 // -----

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 7 additions & 10 deletions
@@ -416,9 +416,8 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memre
 // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
 // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
 // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
-// CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-// CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-// CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]] (d0, d1) -> (d1, d0) : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<4x?xf32, strided<[1, ?], offset: ?>>
+// CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] {in_bounds = [true, false]} : memref<4x?xf32, strided<[1, ?], offset: ?>>, vector<4x[8]xf32>
 // CHECK-NEXT: return %[[LEGAL_READ]]
 %pad = arith.constant 0.0 : f32
 %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>

@@ -434,11 +433,10 @@ func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index
 // CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
 func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
 // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
-// CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
-// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
+// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]]
 // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
 // CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
-// CHECK-SAME: %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+// CHECK-SAME: %[[MASK]] {in_bounds = [true, false]} : memref<4x?xf32, strided<[1, ?], offset: ?>>, vector<4x[8]xf32>
 // CHECK-NEXT: return %[[LEGAL_READ]]
 %pad = arith.constant 0.0 : f32
 %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>

@@ -453,8 +451,7 @@ func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index
 // CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
 func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
 // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
-// CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
-// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
+// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]]
 // CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
 // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
 // CHECK-NEXT: return %[[EXT_TYPE]]

@@ -514,7 +511,7 @@ func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector

 // CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
 func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
-// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+// CHECK: vector.transfer_read {{.*}} : memref<1x?xf32, {{.*}}>, vector<1x[4]xf32>
 // CHECK-NOT: vector.shape_cast
 %pad = arith.constant 0.0 : f32
 %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

@@ -526,7 +523,7 @@ func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: m

 // CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
 func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
-// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+// CHECK: vector.transfer_read {{.*}} : memref<1x?xf32, {{.*}}>, vector<1x[4]xf32>
 // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
 %pad = arith.constant 0.0 : f32
 %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

mlir/test/Dialect/GPU/decompose-memrefs.mlir

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
 // CHECK: gpu.launch
 // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [2, 2, 2], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
 // CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
 func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
 %c0 = arith.constant 0 : index

@@ -118,7 +118,7 @@ func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[STRIDES]]#0]
 // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
 // CHECK: %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [2, 2, 2], strides: [%[[IDX]], %[[IDX1]], 4]
 // CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
 func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
 %c0 = arith.constant 0 : index
