Skip to content

Commit 8c42da7

Browse files
committed
update tests
1 parent 1fee833 commit 8c42da7

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ static OpFoldResult computeStaticShape(OpBuilder &builder, Location loc,
7373
builder, loc, s0 * s1, ArrayRef<OpFoldResult>{dim, stride});
7474
auto constant = getConstantIntValue(size);
7575
assert(constant && "expected constant value");
76-
maxSize = *constant;
76+
maxSize = std::max(maxSize, *constant);
7777
}
7878
return builder.getIndexAttr(maxSize);
7979
}
@@ -104,7 +104,7 @@ static OpFoldResult computeDynamicShape(OpBuilder &builder, Location loc,
104104

105105
/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
106106
/// span of the memref.
107-
static OpFoldResult computeSpan(OpBuilder &builder, Location loc,
107+
static OpFoldResult computeSize(OpBuilder &builder, Location loc,
108108
ArrayRef<OpFoldResult> dims,
109109
ArrayRef<OpFoldResult> strides) {
110110
assert(dims.size() == strides.size() &&
@@ -147,7 +147,7 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
147147
loc, source,
148148
/* offset = */ linearizedInfo.linearizedOffset,
149149
/* shapes = */
150-
ArrayRef<OpFoldResult>{computeSpan(
150+
ArrayRef<OpFoldResult>{computeSize(
151151
rewriter, loc, stridedMetadata.getConstifiedMixedSizes(),
152152
stridedMetadata.getConstifiedMixedStrides())},
153153
/* strides = */

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@ func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offse
1414

1515
// -----
1616

17-
func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index) -> f32 {
18-
%value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
17+
func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32, strided<[1, 4], offset: 100>>, %row: index, %col: index) -> f32 {
18+
%value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[1, 4], offset: 100>>
1919
return %value : f32
2020
}
2121

22-
// CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)>
23-
// CHECK: func @load_scalar_from_memref_static_dim_2
24-
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
22+
// CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)>
23+
// CHECK: func @load_scalar_from_memref_static_dim_col_major
24+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[1, 4], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
2525
// CHECK: %[[IDX:.*]] = affine.apply [[MAP]]()[%[[ARG2]], %[[ARG1]]]
26-
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [12]
27-
// CHECK-SAME: to memref<32xf32, strided<[12], offset: 100>>
26+
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1]
27+
// CHECK-SAME: to memref<32xf32, strided<[1], offset: 100>>
2828
// CHECK: memref.load %[[REINT]][%[[IDX]]]
2929

3030
// -----
@@ -35,27 +35,27 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[
3535
}
3636

3737
// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
38-
// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
38+
// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
3939
// CHECK: func @load_scalar_from_memref_dynamic_dim
4040
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
4141
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
4242
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
43-
// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1]
44-
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
43+
// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[SIZES]]#0, %[[STRIDES]]#0, %[[SIZES]]#1, %[[STRIDES]]#1]
44+
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
4545
// CHECK: memref.load %[[REINT]][%[[IDX]]]
4646

4747
// -----
4848

49-
func.func @store_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index, %value: f32) {
50-
memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
49+
func.func @store_scalar_from_memref_padded(%input: memref<4x8xf32, strided<[18, 2], offset: 100>>, %row: index, %col: index, %value: f32) {
50+
memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[18, 2], offset: 100>>
5151
return
5252
}
53-
// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)>
54-
// CHECK: func @store_scalar_from_memref_static_dim
55-
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
53+
// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 18 + s1 * 2)>
54+
// CHECK: func @store_scalar_from_memref_padded
55+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[18, 2], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
5656
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
5757
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
58-
// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[12], offset: 100>>
58+
// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<72xf32, strided<[1], offset: 100>>
5959

6060
// -----
6161

@@ -64,13 +64,13 @@ func.func @store_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<
6464
return
6565
}
6666
// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
67-
// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
67+
// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
6868
// CHECK: func @store_scalar_from_memref_dynamic_dim
6969
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
7070
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
7171
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
72-
// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1]
73-
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
72+
// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[SIZES]]#0, %[[STRIDES]]#0, %[[SIZES]]#1, %[[STRIDES]]#1]
73+
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1]
7474
// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]]
7575

7676
// -----

0 commit comments

Comments (0)