@@ -1410,3 +1410,113 @@ func.func @unpack(%arg0: memref<1x4x6x6x2xf32>, %arg1: memref<1x6x6x8xf32>) {
14101410// CHECK: }
14111411// CHECK: }
14121412// CHECK: }
1413+
1414+ // -----
1415+
1416+ func.func @gather_1d_indices (%arg0 : memref <10 x10 xi32 >, %arg1 : memref <1 xi32 >, %arg2 : memref <1 x10 xi32 >) {
1417+ iree_linalg_ext.gather
1418+ dimension_map = [0 ]
1419+ ins (%arg0 , %arg1: memref <10 x10 xi32 >, memref <1 xi32 >)
1420+ outs (%arg2: memref <1 x10 xi32 >) {
1421+ ^bb0 (%bb0: i32 , %bb1: i32 ):
1422+ iree_linalg_ext.yield %bb0 : i32
1423+ }
1424+ return
1425+ }
1426+ // CHECK-LABEL: func @gather_1d_indices
1427+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1428+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1429+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1430+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1431+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1432+ // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
1433+ // CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
1434+ // CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C10]] step %[[C1]] {
1435+ // CHECK: %[[IDX:.+]] = memref.load %[[ARG1]][%[[I]]] : memref<1xi32>
1436+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] : i32 to index
1437+ // CHECK: %[[LOAD:.+]] = memref.load %[[ARG0]][%[[CAST]], %[[J]]] : memref<10x10xi32>
1438+
1439+ // -----
1440+
1441+ func.func @gather_2d_indices (%arg0 : memref <2 x2 xi32 >, %arg1 : memref <2 x2 xi32 >, %arg2 : memref <2 xi32 >) {
1442+ iree_linalg_ext.gather
1443+ dimension_map = [0 , 1 ]
1444+ ins (%arg0 , %arg1: memref <2 x2 xi32 >, memref <2 x2 xi32 >)
1445+ outs (%arg2: memref <2 xi32 >) {
1446+ ^bb0 (%bb0: i32 , %bb1: i32 ):
1447+ iree_linalg_ext.yield %bb0 : i32
1448+ }
1449+ return
1450+ }
1451+ // CHECK-LABEL: func @gather_2d_indices
1452+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1453+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1454+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1455+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1456+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
1457+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1458+ // CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
1459+ // CHECK: %[[IDX0:.+]] = memref.load %[[ARG1]][%[[I]], %[[C0]]] : memref<2x2xi32>
1460+ // CHECK: %[[CAST0:.+]] = arith.index_cast %[[IDX0]] : i32 to index
1461+ // CHECK: %[[IDX1:.+]] = memref.load %[[ARG1]][%[[I]], %[[C1]]] : memref<2x2xi32>
1462+ // CHECK: %[[CAST1:.+]] = arith.index_cast %[[IDX1]] : i32 to index
1463+ // CHECK: %[[LOAD0:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[CAST1]]] : memref<2x2xi32>
1464+ // CHECK: memref.store %[[LOAD0]], %[[ARG2]][%[[I]]] : memref<2xi32>
1465+
1466+ // -----
1467+
1468+ func.func @gather_perm_dim_map (%arg0 : memref <2 x2 xi32 >, %arg1 : memref <2 x2 xi32 >, %arg2 : memref <2 xi32 >) {
1469+ iree_linalg_ext.gather
1470+ dimension_map = [1 , 0 ]
1471+ ins (%arg0 , %arg1: memref <2 x2 xi32 >, memref <2 x2 xi32 >)
1472+ outs (%arg2: memref <2 xi32 >) {
1473+ ^bb0 (%bb0: i32 , %bb1: i32 ):
1474+ iree_linalg_ext.yield %bb0 : i32
1475+ }
1476+ return
1477+ }
1478+ // CHECK-LABEL: func @gather_perm_dim_map
1479+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1480+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1481+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1482+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1483+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
1484+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1485+ // CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
1486+ // CHECK: %[[IDX0:.+]] = memref.load %[[ARG1]][%[[I]], %[[C0]]] : memref<2x2xi32>
1487+ // CHECK: %[[CAST0:.+]] = arith.index_cast %[[IDX0]] : i32 to index
1488+ // CHECK: %[[IDX1:.+]] = memref.load %[[ARG1]][%[[I]], %[[C1]]] : memref<2x2xi32>
1489+ // CHECK: %[[CAST1:.+]] = arith.index_cast %[[IDX1]] : i32 to index
1490+ // CHECK: %[[LOAD0:.+]] = memref.load %[[ARG0]][%[[CAST1]], %[[CAST0]]] : memref<2x2xi32>
1491+ // CHECK: memref.store %[[LOAD0]], %[[ARG2]][%[[I]]] : memref<2xi32>
1492+
1493+ // -----
1494+
1495+ func.func @gather_inline_region (%arg0 : memref <2 x2 xi32 >, %arg1 : memref <2 x2 xi32 >, %arg2 : memref <2 xi32 >) {
1496+ %cst = arith.constant 3 : i32
1497+ iree_linalg_ext.gather
1498+ dimension_map = [0 , 1 ]
1499+ ins (%arg0 , %arg1: memref <2 x2 xi32 >, memref <2 x2 xi32 >)
1500+ outs (%arg2: memref <2 xi32 >) {
1501+ ^bb0 (%bb0: i32 , %bb1: i32 ):
1502+ %0 = arith.muli %bb0 , %cst : i32
1503+ iree_linalg_ext.yield %0 : i32
1504+ }
1505+ return
1506+ }
1507+ // CHECK-LABEL: func @gather_inline_region
1508+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1509+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1510+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1511+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1512+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
1513+ // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : i32
1514+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1515+ // CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
1516+ // CHECK: %[[IDX0:.+]] = memref.load %[[ARG1]][%[[I]], %[[C0]]] : memref<2x2xi32>
1517+ // CHECK: %[[CAST0:.+]] = arith.index_cast %[[IDX0]] : i32 to index
1518+ // CHECK: %[[IDX1:.+]] = memref.load %[[ARG1]][%[[I]], %[[C1]]] : memref<2x2xi32>
1519+ // CHECK: %[[CAST1:.+]] = arith.index_cast %[[IDX1]] : i32 to index
1520+ // CHECK: %[[LOAD0:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[CAST1]]] : memref<2x2xi32>
1521+ // CHECK: %[[MUL:.+]] = arith.muli %[[LOAD0]], %[[C3]] : i32
1522+ // CHECK: memref.store %[[MUL]], %[[ARG2]][%[[I]]] : memref<2xi32>
0 commit comments