Skip to content

Commit 05664b6

Browse files
authored
[LinalgExt] Support converting gather to loops (3/5) (iree-org#20464)
Adds `TilingInterface` methods to be able to convert `iree_linalg_ext.gather` to loops. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent cab3435 commit 05664b6

File tree

3 files changed

+162
-1
lines changed

3 files changed

+162
-1
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
226226
def IREELinalgExt_GatherOp : IREELinalgExt_Op<"gather",
227227
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
228228
DeclareOpInterfaceMethods<TilingInterface,
229-
["getIterationDomain",
229+
["generateScalarImplementation",
230+
"getIterationDomain",
230231
"getLoopIteratorTypes",
231232
"getResultTilePosition",
232233
"getTiledImplementation",

compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,56 @@ GatherOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber,
388388
return getTiledImplementation(builder, offsets, sizes);
389389
}
390390

391+
/// Emits the scalar body of the gather for a single point of the iteration
/// space, identified by the loop induction variables `ivs`.
///
/// The emitted IR (1) loads the current output element at `ivs`, (2) loads
/// `getIndexDepth()` values from `indices` and uses them, via the dimension
/// map, to build the coordinates into the source, (3) loads the source
/// element, (4) clones the op's region with the source element bound to
/// block argument 0 and the output element bound to block argument 1, and
/// (5) stores the yielded value back to the output at `ivs`.
///
/// \param b    Builder used to create the new operations.
/// \param loc  Location attached to every created operation.
/// \param ivs  Loop induction variables; the leading `getBatchRank()` entries
///             index into `indices`, and all of them index into the output.
/// \returns success() unconditionally.
LogicalResult GatherOp::generateScalarImplementation(OpBuilder &b, Location loc,
392+
ValueRange ivs) {
393+
auto indexDepth = getIndexDepth();
394+
Value result = b.create<memref::LoadOp>(loc, getOutput(), ivs);
395+
SmallVector<Value> loadIndices(ivs.take_front(getBatchRank()));
396+
397+
// `starts` holds the coordinates used to load from the source. Its leading
// `sourceRank - resultIvs.size()` entries start out as null Values (filled
// in from `indices` below); the trailing entries are the non-batch ivs.
398+
auto sourceTy = getSourceType();
399+
auto resultIvs = ivs.drop_front(getBatchRank());
400+
SmallVector<Value> starts(sourceTy.getRank() - resultIvs.size(), Value());
401+
llvm::append_range(starts, resultIvs);
402+
403+
// If `indices` has more dims than the batch rank, it has an extra trailing
// dim that selects which of the index components to read. Reserve a slot
// for it in `loadIndices`; it is set to the constant `i` in the loop below.
404+
bool hasIndexDim = getIndicesType().getRank() > getBatchRank();
405+
if (hasIndexDim) {
406+
loadIndices.push_back(Value());
407+
}
408+
409+
// Populate `starts` by loading index components from `indices`. The i-th
// component is cast to `index` and placed at `starts[dimMap[i]]`; if that
// slot already holds a value, the component is added to it instead of
// replacing it.
410+
ArrayRef<int64_t> dimMap = getDimensionMap();
411+
for (int64_t i = 0; i < indexDepth; ++i) {
412+
if (hasIndexDim) {
413+
loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
414+
}
415+
Value idx = b.create<memref::LoadOp>(loc, getIndices(), loadIndices);
416+
Value ret = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
417+
auto dim = dimMap[i];
418+
if (starts[dim])
419+
ret = b.create<arith::AddIOp>(loc, ret, starts[dim]);
420+
starts[dim] = ret;
421+
}
422+
423+
Value init = b.create<memref::LoadOp>(loc, getSource(), starts);
424+
425+
// Clone the region body in place, substituting the gathered source element
// for block argument 0 and the current output element for block argument 1.
IRMapping bvm;
426+
Block &block = getRegion().front();
427+
bvm.map(block.getArgument(0), init);
428+
bvm.map(block.getArgument(1), result);
429+
for (auto &blockOp : block.without_terminator()) {
430+
b.clone(blockOp, bvm);
431+
}
432+
433+
// The block terminator is the linalg_ext.yield op. Store its (remapped)
434+
// operand to the output at `ivs`.
435+
b.create<memref::StoreOp>(
436+
loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
437+
getOutput(), ivs);
438+
return success();
439+
}
440+
391441
//===----------------------------------------------------------------------===//
392442
// SortOp
393443
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,3 +1410,113 @@ func.func @unpack(%arg0: memref<1x4x6x6x2xf32>, %arg1: memref<1x6x6x8xf32>) {
14101410
// CHECK: }
14111411
// CHECK: }
14121412
// CHECK: }
1413+
1414+
// -----
1415+
1416+
// Test input: gather whose `indices` operand is rank-1 (memref<1xi32>) with
// dimension_map = [0]. The region yields its first block argument unchanged.
func.func @gather_1d_indices(%arg0 : memref<10x10xi32>, %arg1 : memref<1xi32>, %arg2 : memref<1x10xi32>) {
1417+
iree_linalg_ext.gather
1418+
dimension_map = [0]
1419+
ins(%arg0, %arg1: memref<10x10xi32>, memref<1xi32>)
1420+
outs(%arg2: memref<1x10xi32>) {
1421+
^bb0(%bb0: i32, %bb1: i32):
1422+
iree_linalg_ext.yield %bb0 : i32
1423+
}
1424+
return
1425+
}
1426+
// CHECK-LABEL: func @gather_1d_indices
1427+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1428+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1429+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1430+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1431+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1432+
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
1433+
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
1434+
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C10]] step %[[C1]] {
1435+
// CHECK: %[[IDX:.+]] = memref.load %[[ARG1]][%[[I]]] : memref<1xi32>
1436+
// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] : i32 to index
1437+
// CHECK: %[[LOAD:.+]] = memref.load %[[ARG0]][%[[CAST]], %[[J]]] : memref<10x10xi32>
1438+
1439+
// -----
1440+
1441+
// Test input: gather whose `indices` operand is rank-2 (memref<2x2xi32>) with
// the identity dimension_map = [0, 1]. The region yields its first block
// argument unchanged.
func.func @gather_2d_indices(%arg0 : memref<2x2xi32>, %arg1 : memref<2x2xi32>, %arg2 : memref<2xi32>) {
1442+
iree_linalg_ext.gather
1443+
dimension_map = [0, 1]
1444+
ins(%arg0, %arg1: memref<2x2xi32>, memref<2x2xi32>)
1445+
outs(%arg2: memref<2xi32>) {
1446+
^bb0(%bb0: i32, %bb1: i32):
1447+
iree_linalg_ext.yield %bb0 : i32
1448+
}
1449+
return
1450+
}
1451+
// CHECK-LABEL: func @gather_2d_indices
1452+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1453+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1454+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1455+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1456+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
1457+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1458+
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
1459+
// CHECK: %[[IDX0:.+]] = memref.load %[[ARG1]][%[[I]], %[[C0]]] : memref<2x2xi32>
1460+
// CHECK: %[[CAST0:.+]] = arith.index_cast %[[IDX0]] : i32 to index
1461+
// CHECK: %[[IDX1:.+]] = memref.load %[[ARG1]][%[[I]], %[[C1]]] : memref<2x2xi32>
1462+
// CHECK: %[[CAST1:.+]] = arith.index_cast %[[IDX1]] : i32 to index
1463+
// CHECK: %[[LOAD0:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[CAST1]]] : memref<2x2xi32>
1464+
// CHECK: memref.store %[[LOAD0]], %[[ARG2]][%[[I]]] : memref<2xi32>
1465+
1466+
// -----
1467+
1468+
// Test input: same shapes as the rank-2 `indices` case, but with the permuted
// dimension_map = [1, 0], so the loaded index components are expected to be
// applied to the source dims in swapped order.
func.func @gather_perm_dim_map(%arg0 : memref<2x2xi32>, %arg1 : memref<2x2xi32>, %arg2 : memref<2xi32>) {
1469+
iree_linalg_ext.gather
1470+
dimension_map = [1, 0]
1471+
ins(%arg0, %arg1: memref<2x2xi32>, memref<2x2xi32>)
1472+
outs(%arg2: memref<2xi32>) {
1473+
^bb0(%bb0: i32, %bb1: i32):
1474+
iree_linalg_ext.yield %bb0 : i32
1475+
}
1476+
return
1477+
}
1478+
// CHECK-LABEL: func @gather_perm_dim_map
1479+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1480+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1481+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1482+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1483+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
1484+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1485+
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
1486+
// CHECK: %[[IDX0:.+]] = memref.load %[[ARG1]][%[[I]], %[[C0]]] : memref<2x2xi32>
1487+
// CHECK: %[[CAST0:.+]] = arith.index_cast %[[IDX0]] : i32 to index
1488+
// CHECK: %[[IDX1:.+]] = memref.load %[[ARG1]][%[[I]], %[[C1]]] : memref<2x2xi32>
1489+
// CHECK: %[[CAST1:.+]] = arith.index_cast %[[IDX1]] : i32 to index
1490+
// CHECK: %[[LOAD0:.+]] = memref.load %[[ARG0]][%[[CAST1]], %[[CAST0]]] : memref<2x2xi32>
1491+
// CHECK: memref.store %[[LOAD0]], %[[ARG2]][%[[I]]] : memref<2xi32>
1492+
1493+
// -----
1494+
1495+
// Test input: gather with a non-trivial region. The region multiplies its
// first block argument by a constant defined outside the op (%cst = 3) and
// yields the product.
func.func @gather_inline_region(%arg0 : memref<2x2xi32>, %arg1 : memref<2x2xi32>, %arg2 : memref<2xi32>) {
1496+
%cst = arith.constant 3 : i32
1497+
iree_linalg_ext.gather
1498+
dimension_map = [0, 1]
1499+
ins(%arg0, %arg1: memref<2x2xi32>, memref<2x2xi32>)
1500+
outs(%arg2: memref<2xi32>) {
1501+
^bb0(%bb0: i32, %bb1: i32):
1502+
%0 = arith.muli %bb0, %cst : i32
1503+
iree_linalg_ext.yield %0 : i32
1504+
}
1505+
return
1506+
}
1507+
// CHECK-LABEL: func @gather_inline_region
1508+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1509+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1510+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1511+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1512+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
1513+
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : i32
1514+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1515+
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
1516+
// CHECK: %[[IDX0:.+]] = memref.load %[[ARG1]][%[[I]], %[[C0]]] : memref<2x2xi32>
1517+
// CHECK: %[[CAST0:.+]] = arith.index_cast %[[IDX0]] : i32 to index
1518+
// CHECK: %[[IDX1:.+]] = memref.load %[[ARG1]][%[[I]], %[[C1]]] : memref<2x2xi32>
1519+
// CHECK: %[[CAST1:.+]] = arith.index_cast %[[IDX1]] : i32 to index
1520+
// CHECK: %[[LOAD0:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[CAST1]]] : memref<2x2xi32>
1521+
// CHECK: %[[MUL:.+]] = arith.muli %[[LOAD0]], %[[C3]] : i32
1522+
// CHECK: memref.store %[[MUL]], %[[ARG2]][%[[I]]] : memref<2xi32>

0 commit comments

Comments
 (0)