|
| 1 | +func.func @gather_from_splat_tensor() { |
| 2 | + %source = util.unfoldable_constant dense<0> : tensor<10x10xi32> |
| 3 | + %empty = tensor.empty() : tensor<1x10xi32> |
| 4 | + %indices = util.unfoldable_constant dense<0> : tensor<1xi32> |
| 5 | + %result = iree_linalg_ext.gather dimension_map = [0] |
| 6 | + ins(%source, %indices : tensor<10x10xi32>, tensor<1xi32>) |
| 7 | + outs(%empty : tensor<1x10xi32>) { |
| 8 | + ^bb0(%arg0: i32, %arg1: i32): |
| 9 | + iree_linalg_ext.yield %arg0 : i32 |
| 10 | + } -> tensor<1x10xi32> |
| 11 | + |
| 12 | + check.expect_eq_const(%result, dense<0> : tensor<1x10xi32>) |
| 13 | + : tensor<1x10xi32> |
| 14 | + return |
| 15 | +} |
| 16 | + |
| 17 | +func.func @gather_2d_index_with_batch() { |
| 18 | + %source = util.unfoldable_constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> |
| 19 | + %empty = tensor.empty() : tensor<2xi32> |
| 20 | + %indices = util.unfoldable_constant dense<[[0, 1], [1, 0]]> : tensor<2x2xi32> |
| 21 | + %result = iree_linalg_ext.gather dimension_map = [0, 1] |
| 22 | + ins(%source, %indices : tensor<2x2xi32>, tensor<2x2xi32>) |
| 23 | + outs(%empty: tensor<2xi32>) { |
| 24 | + ^bb0(%arg0: i32, %arg1: i32): |
| 25 | + iree_linalg_ext.yield %arg0 : i32 |
| 26 | + } -> tensor<2xi32> |
| 27 | + check.expect_eq_const(%result, dense<[1, 2]> : tensor<2xi32>) : tensor<2xi32> |
| 28 | + return |
| 29 | +} |
| 30 | + |
| 31 | +func.func @gather_2d_index_no_batch() { |
| 32 | + %source = util.unfoldable_constant dense<[[[0], [1]], [[0], [0]]]> : tensor<2x2x1xi32> |
| 33 | + %empty = tensor.empty() : tensor<1xi32> |
| 34 | + %indices = util.unfoldable_constant dense<[0, 1]> : tensor<2xi32> |
| 35 | + %result = iree_linalg_ext.gather dimension_map = [0, 1] |
| 36 | + ins(%source, %indices : tensor<2x2x1xi32>, tensor<2xi32>) |
| 37 | + outs(%empty: tensor<1xi32>) { |
| 38 | + ^bb0(%arg0: i32, %arg1: i32): |
| 39 | + iree_linalg_ext.yield %arg0 : i32 |
| 40 | + } -> tensor<1xi32> |
| 41 | + check.expect_eq_const(%result, dense<[1]> : tensor<1xi32>) : tensor<1xi32> |
| 42 | + return |
| 43 | +} |
| 44 | + |
| 45 | +func.func @gather_1d_index_no_batch() { |
| 46 | + %source = util.unfoldable_constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> |
| 47 | + %empty = tensor.empty() : tensor<2xi32> |
| 48 | + %indices = util.unfoldable_constant dense<[1]> : tensor<1xi32> |
| 49 | + %result = iree_linalg_ext.gather dimension_map = [0] |
| 50 | + ins(%source, %indices : tensor<2x2xi32>, tensor<1xi32>) |
| 51 | + outs(%empty: tensor<2xi32>) { |
| 52 | + ^bb0(%arg0: i32, %arg1: i32): |
| 53 | + iree_linalg_ext.yield %arg0 : i32 |
| 54 | + } -> tensor<2xi32> |
| 55 | + check.expect_eq_const(%result, dense<[2, 3]> : tensor<2xi32>) : tensor<2xi32> |
| 56 | + return |
| 57 | +} |
| 58 | + |
| 59 | +func.func @gather_muli_in_region() { |
| 60 | + %cst = arith.constant 2 : i32 |
| 61 | + %source = util.unfoldable_constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> |
| 62 | + %empty = tensor.empty() : tensor<2xi32> |
| 63 | + %indices = util.unfoldable_constant dense<[1]> : tensor<1xi32> |
| 64 | + %result = iree_linalg_ext.gather dimension_map = [0] |
| 65 | + ins(%source, %indices : tensor<2x2xi32>, tensor<1xi32>) |
| 66 | + outs(%empty: tensor<2xi32>) { |
| 67 | + ^bb0(%arg0: i32, %arg1: i32): |
| 68 | + %0 = arith.muli %arg0, %cst : i32 |
| 69 | + iree_linalg_ext.yield %0 : i32 |
| 70 | + } -> tensor<2xi32> |
| 71 | + check.expect_eq_const(%result, dense<[4, 6]> : tensor<2xi32>) : tensor<2xi32> |
| 72 | + return |
| 73 | +} |
| 74 | + |
| 75 | +func.func @gather_perm_map() { |
| 76 | + %source = util.unfoldable_constant dense<[[[0], [1]], [[2], [3]]]> : tensor<2x2x1xi32> |
| 77 | + %empty = tensor.empty() : tensor<1xi32> |
| 78 | + %indices = util.unfoldable_constant dense<[0, 1]> : tensor<2xi32> |
| 79 | + %result = iree_linalg_ext.gather dimension_map = [1, 0] |
| 80 | + ins(%source, %indices : tensor<2x2x1xi32>, tensor<2xi32>) |
| 81 | + outs(%empty: tensor<1xi32>) { |
| 82 | + ^bb0(%arg0: i32, %arg1: i32): |
| 83 | + iree_linalg_ext.yield %arg0 : i32 |
| 84 | + } -> tensor<1xi32> |
| 85 | + check.expect_eq_const(%result, dense<[2]> : tensor<1xi32>) : tensor<1xi32> |
| 86 | + return |
| 87 | +} |
0 commit comments