@@ -7,25 +7,11 @@ func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offse
77 return %value : f32
88}
99// CHECK-LABEL: func @load_scalar_from_memref
10- // CHECK: %[[C10:.*]] = arith.constant 10 : index
11- // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
10+ // CHECK-NEXT : %[[C10:.*]] = arith.constant 10 : index
11+ // CHECK-NEXT : %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
1212// CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
13- // CHECK: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 100>>
14-
15- // -----
16-
17- func.func @load_scalar_from_memref_static_dim_col_major (%input: memref <4 x8 xf32 , strided <[1 , 4 ], offset : 100 >>, %row: index , %col: index ) -> f32 {
18- %value = memref.load %input [%col , %row ] : memref <4 x8 xf32 , strided <[1 , 4 ], offset : 100 >>
19- return %value : f32
20- }
13+ // CHECK-NEXT: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 100>>
2114
22- // CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)>
23- // CHECK: func @load_scalar_from_memref_static_dim_col_major
24- // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[1, 4], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
25- // CHECK: %[[IDX:.*]] = affine.apply [[MAP]]()[%[[ARG2]], %[[ARG1]]]
26- // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1]
27- // CHECK-SAME: to memref<32xf32, strided<[1], offset: 100>>
28- // CHECK: memref.load %[[REINT]][%[[IDX]]]
2915
3016// -----
3117
@@ -46,6 +32,21 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[
4632
4733// -----
4834
35+ func.func @load_scalar_from_memref_static_dim (%input: memref <8 x12 xf32 , strided <[24 , 2 ], offset : 100 >>) -> f32 {
36+ %c7 = arith.constant 7 : index
37+ %c10 = arith.constant 10 : index
38+ %value = memref.load %input [%c7 , %c10 ] : memref <8 x12 xf32 , strided <[24 , 2 ], offset : 100 >>
39+ return %value : f32
40+ }
41+
42+ // CHECK-LABEL: func @load_scalar_from_memref_static_dim
43+ // CHECK-SAME: (%[[ARG0:.*]]: memref<8x12xf32, strided<[24, 2], offset: 100>>)
44+ // CHECK: %[[C188:.*]] = arith.constant 188 : index
45+ // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [192], strides: [1] : memref<8x12xf32, strided<[24, 2], offset: 100>> to memref<192xf32, strided<[1], offset: 100>>
46+ // CHECK: memref.load %[[REINT]][%[[C188]]] : memref<192xf32, strided<[1], offset: 100>>
47+
48+ // -----
49+
4950func.func @store_scalar_from_memref_padded (%input: memref <4 x8 xf32 , strided <[18 , 2 ], offset : 100 >>, %row: index , %col: index , %value: f32 ) {
5051 memref.store %value , %input [%col , %row ] : memref <4 x8 xf32 , strided <[18 , 2 ], offset : 100 >>
5152 return
@@ -256,3 +257,17 @@ func.func @chained_alloc_load() -> vector<8xf32> {
256257// CHECK-NEXT: %[[C30:.*]] = arith.constant 30 : index
257258// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>>
258259// CHECK-NEXT: vector.load %[[ALLOC]][%[[C30]]] : memref<32xf32, strided<[1]>>, vector<8xf32>
260+
261+ // -----
262+
263+ func.func @load_scalar_from_memref_static_dim_col_major (%input: memref <4 x8 xf32 , strided <[1 , 4 ], offset : 100 >>, %row: index , %col: index ) -> f32 {
264+ %value = memref.load %input [%col , %row ] : memref <4 x8 xf32 , strided <[1 , 4 ], offset : 100 >>
265+ return %value : f32
266+ }
267+
268+ // CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)>
269+ // CHECK: func @load_scalar_from_memref_static_dim_col_major
270+ // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[1, 4], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
271+ // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
272+ // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[1, 4], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
273+ // CHECK: memref.load %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[1], offset: 100>>
0 commit comments