Skip to content

Commit 35f4632

Browse files
authored
Add e2e tests for gather (4/5) (iree-org#20465)
Adds e2e tests for `iree_linalg_ext.gather` and adds two interfaces needed for codegen. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 05664b6 commit 35f4632

File tree

5 files changed

+103
-0
lines changed

5 files changed

+103
-0
lines changed

compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,8 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
653653
LinalgExtOpInterface<IREE::LinalgExt::ScanOp>>(*ctx);
654654
IREE::LinalgExt::ScatterOp::attachInterface<
655655
LinalgExtOpInterface<IREE::LinalgExt::ScatterOp>>(*ctx);
656+
IREE::LinalgExt::GatherOp::attachInterface<
657+
LinalgExtOpInterface<IREE::LinalgExt::GatherOp>>(*ctx);
656658
IREE::LinalgExt::SortOp::attachInterface<
657659
LinalgExtOpInterface<IREE::LinalgExt::SortOp>>(*ctx);
658660
IREE::LinalgExt::TopkOp::attachInterface<

compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry &registry) {
245245
AllParallelAsPartitionableLoops<IREE::LinalgExt::ScanOp>>(*ctx);
246246
IREE::LinalgExt::ScatterOp::attachInterface<
247247
OuterParallelAsPartitionableLoops<IREE::LinalgExt::ScatterOp>>(*ctx);
248+
IREE::LinalgExt::GatherOp::attachInterface<
249+
AllParallelAsPartitionableLoops<IREE::LinalgExt::GatherOp>>(*ctx);
248250
IREE::LinalgExt::SortOp::attachInterface<
249251
AllParallelAsPartitionableLoops<IREE::LinalgExt::SortOp>>(*ctx);
250252
IREE::LinalgExt::TopkOp::attachInterface<

tests/e2e/linalg_ext_ops/BUILD.bazel

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ ALL_SRCS = enforce_glob(
1616
# keep sorted
1717
[
1818
"attention.mlir",
19+
"gather.mlir",
1920
"scan.mlir",
2021
"scatter.mlir",
2122
"sort.mlir",
@@ -63,6 +64,7 @@ iree_check_single_backend_test_suite(
6364
VMVX_SRCS = enforce_glob(
6465
# keep sorted
6566
[
67+
"gather.mlir",
6668
"scan.mlir",
6769
"scatter.mlir",
6870
"sort.mlir",
@@ -87,6 +89,7 @@ iree_check_single_backend_test_suite(
8789
LLVM_GPU_SRCS = enforce_glob(
8890
# keep sorted
8991
[
92+
"gather.mlir",
9093
"scan.mlir",
9194
"scatter.mlir",
9295
"sort.mlir",
@@ -120,6 +123,7 @@ iree_check_single_backend_test_suite(
120123
ROCM_HIP_SRCS = enforce_glob(
121124
# keep sorted
122125
[
126+
"gather.mlir",
123127
"scan.mlir",
124128
"scatter.mlir",
125129
"sort.mlir",
@@ -146,6 +150,7 @@ iree_check_single_backend_test_suite(
146150
srcs = enforce_glob(
147151
# keep sorted
148152
[
153+
"gather.mlir",
149154
"scan.mlir",
150155
"scatter.mlir",
151156
"sort.mlir",
@@ -168,6 +173,7 @@ iree_check_single_backend_test_suite(
168173
srcs = enforce_glob(
169174
# keep sorted
170175
[
176+
"gather.mlir",
171177
"scan.mlir",
172178
"scatter.mlir",
173179
"sort.mlir",

tests/e2e/linalg_ext_ops/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ iree_check_single_backend_test_suite(
1515
check_llvm-cpu_local-task
1616
SRCS
1717
"attention.mlir"
18+
"gather.mlir"
1819
"scan.mlir"
1920
"scatter.mlir"
2021
"sort.mlir"
@@ -51,6 +52,7 @@ iree_check_single_backend_test_suite(
5152
NAME
5253
check_vmvx_local-task
5354
SRCS
55+
"gather.mlir"
5456
"scan.mlir"
5557
"scatter.mlir"
5658
"sort.mlir"
@@ -67,6 +69,7 @@ iree_check_single_backend_test_suite(
6769
NAME
6870
check_cuda
6971
SRCS
72+
"gather.mlir"
7073
"scan.mlir"
7174
"scatter.mlir"
7275
"sort.mlir"
@@ -89,6 +92,7 @@ iree_check_single_backend_test_suite(
8992
NAME
9093
check_rocm_hip
9194
SRCS
95+
"gather.mlir"
9296
"scan.mlir"
9397
"scatter.mlir"
9498
"sort.mlir"
@@ -104,6 +108,7 @@ iree_check_single_backend_test_suite(
104108
NAME
105109
check_metal-spirv_vulkan
106110
SRCS
111+
"gather.mlir"
107112
"scan.mlir"
108113
"scatter.mlir"
109114
"sort.mlir"
@@ -119,6 +124,7 @@ iree_check_single_backend_test_suite(
119124
NAME
120125
check_vulkan-spirv_vulkan
121126
SRCS
127+
"gather.mlir"
122128
"scan.mlir"
123129
"scatter.mlir"
124130
"sort.mlir"
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
func.func @gather_from_splat_tensor() {
2+
%source = util.unfoldable_constant dense<0> : tensor<10x10xi32>
3+
%empty = tensor.empty() : tensor<1x10xi32>
4+
%indices = util.unfoldable_constant dense<0> : tensor<1xi32>
5+
%result = iree_linalg_ext.gather dimension_map = [0]
6+
ins(%source, %indices : tensor<10x10xi32>, tensor<1xi32>)
7+
outs(%empty : tensor<1x10xi32>) {
8+
^bb0(%arg0: i32, %arg1: i32):
9+
iree_linalg_ext.yield %arg0 : i32
10+
} -> tensor<1x10xi32>
11+
12+
check.expect_eq_const(%result, dense<0> : tensor<1x10xi32>)
13+
: tensor<1x10xi32>
14+
return
15+
}
16+
17+
func.func @gather_2d_index_with_batch() {
18+
%source = util.unfoldable_constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
19+
%empty = tensor.empty() : tensor<2xi32>
20+
%indices = util.unfoldable_constant dense<[[0, 1], [1, 0]]> : tensor<2x2xi32>
21+
%result = iree_linalg_ext.gather dimension_map = [0, 1]
22+
ins(%source, %indices : tensor<2x2xi32>, tensor<2x2xi32>)
23+
outs(%empty: tensor<2xi32>) {
24+
^bb0(%arg0: i32, %arg1: i32):
25+
iree_linalg_ext.yield %arg0 : i32
26+
} -> tensor<2xi32>
27+
check.expect_eq_const(%result, dense<[1, 2]> : tensor<2xi32>) : tensor<2xi32>
28+
return
29+
}
30+
31+
func.func @gather_2d_index_no_batch() {
32+
%source = util.unfoldable_constant dense<[[[0], [1]], [[0], [0]]]> : tensor<2x2x1xi32>
33+
%empty = tensor.empty() : tensor<1xi32>
34+
%indices = util.unfoldable_constant dense<[0, 1]> : tensor<2xi32>
35+
%result = iree_linalg_ext.gather dimension_map = [0, 1]
36+
ins(%source, %indices : tensor<2x2x1xi32>, tensor<2xi32>)
37+
outs(%empty: tensor<1xi32>) {
38+
^bb0(%arg0: i32, %arg1: i32):
39+
iree_linalg_ext.yield %arg0 : i32
40+
} -> tensor<1xi32>
41+
check.expect_eq_const(%result, dense<[1]> : tensor<1xi32>) : tensor<1xi32>
42+
return
43+
}
44+
45+
func.func @gather_1d_index_no_batch() {
46+
%source = util.unfoldable_constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
47+
%empty = tensor.empty() : tensor<2xi32>
48+
%indices = util.unfoldable_constant dense<[1]> : tensor<1xi32>
49+
%result = iree_linalg_ext.gather dimension_map = [0]
50+
ins(%source, %indices : tensor<2x2xi32>, tensor<1xi32>)
51+
outs(%empty: tensor<2xi32>) {
52+
^bb0(%arg0: i32, %arg1: i32):
53+
iree_linalg_ext.yield %arg0 : i32
54+
} -> tensor<2xi32>
55+
check.expect_eq_const(%result, dense<[2, 3]> : tensor<2xi32>) : tensor<2xi32>
56+
return
57+
}
58+
59+
func.func @gather_muli_in_region() {
60+
%cst = arith.constant 2 : i32
61+
%source = util.unfoldable_constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
62+
%empty = tensor.empty() : tensor<2xi32>
63+
%indices = util.unfoldable_constant dense<[1]> : tensor<1xi32>
64+
%result = iree_linalg_ext.gather dimension_map = [0]
65+
ins(%source, %indices : tensor<2x2xi32>, tensor<1xi32>)
66+
outs(%empty: tensor<2xi32>) {
67+
^bb0(%arg0: i32, %arg1: i32):
68+
%0 = arith.muli %arg0, %cst : i32
69+
iree_linalg_ext.yield %0 : i32
70+
} -> tensor<2xi32>
71+
check.expect_eq_const(%result, dense<[4, 6]> : tensor<2xi32>) : tensor<2xi32>
72+
return
73+
}
74+
75+
func.func @gather_perm_map() {
76+
%source = util.unfoldable_constant dense<[[[0], [1]], [[2], [3]]]> : tensor<2x2x1xi32>
77+
%empty = tensor.empty() : tensor<1xi32>
78+
%indices = util.unfoldable_constant dense<[0, 1]> : tensor<2xi32>
79+
%result = iree_linalg_ext.gather dimension_map = [1, 0]
80+
ins(%source, %indices : tensor<2x2x1xi32>, tensor<2xi32>)
81+
outs(%empty: tensor<1xi32>) {
82+
^bb0(%arg0: i32, %arg1: i32):
83+
iree_linalg_ext.yield %arg0 : i32
84+
} -> tensor<1xi32>
85+
check.expect_eq_const(%result, dense<[2]> : tensor<1xi32>) : tensor<1xi32>
86+
return
87+
}

0 commit comments

Comments
 (0)