@@ -100,6 +100,35 @@ func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memre
 
 // -----
 
+// CHECK-DAG: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
+// CHECK-DAG: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+
+// CHECK-LABEL: func.func @collapsable_memref_projected_ops(
+// CHECK-SAME: %[[ARG0:.*]]: memref<1x24x32x8xf32>, %[[ARG1:.*]]: memref<1x24x32x8xf32>, %[[ARG2:.*]]: memref<1x24x32x8xf32, #[[$ATTR_0]]>) {
+// CHECK: %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
+// CHECK: %[[VAL_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
+// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32, #[[$ATTR_0]]> into memref<1x768x8xf32, strided<[7680, 10, 1]>>
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x768x8xf32>, memref<1x768x8xf32>) outs(%[[VAL_2]] : memref<1x768x8xf32, strided<[7680, 10, 1]>>) {
+// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK: linalg.yield %[[VAL_6]] : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
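+// The #map1 layout below has strides [7680, 320, 10, 1]: dim 1's stride (320) equals 32 * 10, so dims 1 and 2 are contiguous
+// and can be merged (see the [[0], [1, 2], [3]] reassociation checked above) even though the layout is not the identity.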
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
+func.func @collapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
+  linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @uncollapsable_strided_memref(
 // CHECK: linalg.generic
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -119,6 +148,23 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
 
 // -----
 
+// CHECK-LABEL: func @uncollapsable_memref_projected_ops(
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
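+// Same op as above, but the output layout's strides are [7680, 320, 8, 1]: dim 1's stride (320) no longer equals 32 * 8,
+// so the merged dimension would not be contiguous and the generic is expected to stay in its 4-D form.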
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 8 + d3)>
+func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
+  linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func.func @linalg_copy(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {