// RUN: mlir-opt --split-input-file --enforce-immutable-func-args %s -o - | FileCheck %s

// A function with no arguments gives the pass nothing to rewrite; it must be
// left unchanged.

// CHECK-LABEL: func.func @func_no_input() {
// CHECK: return
// CHECK: }

func.func @func_no_input() {
  return
}

// -----
// An argument that is only read (returned, never written to) needs no
// defensive copy: the pass returns it directly, unchanged.

// CHECK-LABEL: func.func private @func_with_returned_argument(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: return %[[VAL_0]] : memref<1x13x21x3xf32>
// CHECK: }

func.func private @func_with_returned_argument(%arg0: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
  return %arg0 : memref<1x13x21x3xf32>
}

// -----
// %arg0 is written directly (it appears in the outs of the linalg.generic),
// so the pass must materialize a private copy (memref.alloc + memref.copy of
// %[[VAL_0]] into %[[VAL_2]]) and redirect every use of the argument to that
// copy, leaving the caller's buffer untouched.

// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
// CHECK: }

func.func private @func_with_modified_argument_directly(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
  linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
    ],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  }
  ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
  outs(%arg0 : memref<1x13x21x3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  return %alloc : memref<1x13x21x3xf32>
}

// -----
// %arg0 is both written (outs of the generic) and returned. The pass copies
// it into %[[VAL_2]], rewrites the uses, and returns the copy — so the
// caller-visible argument is never mutated.

// CHECK-LABEL: func.func private @func_with_modified_argument_directly_and_returned(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: linalg.yield %[[VAL_6]] : f32
// CHECK: }
// CHECK: return %[[VAL_2]] : memref<1x13x21x3xf32>
// CHECK: }

func.func private @func_with_modified_argument_directly_and_returned(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
  linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
    ],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  }
  ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
  outs(%arg0 : memref<1x13x21x3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  return %arg0 : memref<1x13x21x3xf32>
}

// -----
// %arg0 is written by two separate linalg.generic ops. The pass must insert
// exactly ONE copy (%[[VAL_2]]) and reuse it for both writes — not one copy
// per mutating op.

// CHECK-LABEL: func.func private @func_with_modified_argument_directly_twice(
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: }
// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
// CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32):
// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
// CHECK: linalg.yield %[[VAL_11]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
// CHECK: }

func.func private @func_with_modified_argument_directly_twice(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
  linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
    ],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  }
  ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
  outs(%arg0 : memref<1x13x21x3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
    ],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  }
  ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
  outs(%arg0 : memref<1x13x21x3xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  return %alloc : memref<1x13x21x3xf32>
}

// -----
// %arg2 is stored to inside an scf.for (a scatter-style update via indices
// loaded from %arg0). The pass copies %arg2 into %[[VAL_3]] before the loop
// and retargets the loads/stores to the copy. Note this function reuses the
// symbol name @func_with_modified_argument_directly from an earlier case;
// --split-input-file (see the RUN line) keeps the test cases isolated.

// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xi32, 1>, %[[VAL_1:.*]]: memref<5xi32, 1>, %[[VAL_2:.*]]: memref<5xi32, 1>) -> memref<5xi32, 1> {
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<5xi32, 1>
// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<5xi32, 1> to memref<5xi32, 1>
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : i32
// CHECK: memref.store %[[VAL_12]], %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
// CHECK: }
// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<5xi32, 1>
// CHECK: memref.copy %[[VAL_3]], %[[VAL_13]] : memref<5xi32, 1> to memref<5xi32, 1>
// CHECK: return %[[VAL_13]] : memref<5xi32, 1>
// CHECK: }

func.func private @func_with_modified_argument_directly(%arg0: memref<5xi32, 1>, %arg1: memref<5xi32, 1>, %arg2: memref<5xi32, 1>) -> (memref<5xi32, 1>){
  %c1 = arith.constant 1 : index
  %c5 = arith.constant 5 : index
  %c0 = arith.constant 0 : index
  scf.for %arg3 = %c0 to %c5 step %c1 {
    %0 = memref.load %arg0[%arg3] : memref<5xi32, 1>
    %1 = arith.index_cast %0 : i32 to index
    %2 = memref.load %arg1[%arg3] : memref<5xi32, 1>
    %3 = memref.load %arg2[%1] : memref<5xi32, 1>
    %4 = arith.addi %2, %3 : i32
    memref.store %4, %arg2[%1] : memref<5xi32, 1>
  }
  %alloc = memref.alloc() : memref<5xi32, 1>
  memref.copy %arg2, %alloc : memref<5xi32, 1> to memref<5xi32, 1>
  return %alloc : memref<5xi32, 1>
}

// -----
// %arg0 is modified indirectly: the write goes through a
// collapse_shape/expand_shape view chain rooted at the argument. The pass
// must trace the view chain back to the argument, copy it into %[[VAL_1]],
// and rebuild the views on the copy.

// CHECK-LABEL: func.func private @func_with_modified_argument_indirectly(
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3x4xf32, 1>) -> memref<3x3x4xf32, 1> {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<3x3x4xf32, 1>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<3x3x4xf32, 1> to memref<3x3x4xf32, 1>
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%[[VAL_3]] : memref<3x3x4xf32, 1>) {
// CHECK: ^bb0(%[[VAL_4:.*]]: f32):
// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_4]] : f32
// CHECK: linalg.yield %[[VAL_5]] : f32
// CHECK: }
// CHECK: return %[[VAL_3]] : memref<3x3x4xf32, 1>
// CHECK: }

func.func private @func_with_modified_argument_indirectly(%arg0: memref<3x3x4xf32, 1>) -> (memref<3x3x4xf32, 1>) {
  %collapse_arg = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
  %expand_arg = memref.expand_shape %collapse_arg [[0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
  linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]
  }
  outs(%expand_arg : memref<3x3x4xf32, 1>) {
  ^bb0(%out: f32):
    %0 = arith.addf %out, %out : f32
    linalg.yield %0 : f32
  }
  return %expand_arg: memref<3x3x4xf32, 1>
}

// -----
// %arg0 is modified through a subview -> collapse_shape -> cast chain. The
// pass copies the argument into %[[VAL_1]] and rebuilds the whole view chain
// on the copy; the function's own alloc+copy for the return value is kept
// as-is.

// CHECK-LABEL: func.func private @func_with_modified_argument_subview(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x4x4xi32, 1>) -> memref<4x4xi32, 1> {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<2x4x4xi32, 1>
// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<2x4x4xi32, 1> to memref<2x4x4xi32, 1>
// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_1]][0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_4]] : memref<4x4xi32, 1>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i32):
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : i32
// CHECK: linalg.yield %[[VAL_6]] : i32
// CHECK: }
// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<4x4xi32, 1>
// CHECK: memref.copy %[[VAL_4]], %[[VAL_7]] : memref<4x4xi32, 1> to memref<4x4xi32, 1>
// CHECK: return %[[VAL_7]] : memref<4x4xi32, 1>
// CHECK: }

func.func private @func_with_modified_argument_subview(%arg0: memref<2x4x4xi32, 1>) -> ( memref<4x4xi32, 1>){
  %subview = memref.subview %arg0[0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
  %collapse_shape = memref.collapse_shape %subview [[0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
  %cast = memref.cast %collapse_shape : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
  linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  }
  outs(%cast : memref<4x4xi32, 1>) {
  ^bb0(%out: i32):
    %0 = arith.addi %out, %out : i32
    linalg.yield %0 : i32
  }
  %alloc = memref.alloc() : memref<4x4xi32, 1>
  memref.copy %cast, %alloc : memref<4x4xi32, 1> to memref<4x4xi32, 1>
  return %alloc : memref<4x4xi32, 1>
}