Skip to content

Commit 25d9d60

Browse files
authored
[LinalgExt] Add TilingInterface support to GatherOp (2/5) (iree-org#20462)
Adds `TilingInterface` methods so that `iree_linalg_ext.gather` can be tiled and fused with its consumer. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 6900621 commit 25d9d60

File tree

3 files changed

+201
-2
lines changed

3 files changed

+201
-2
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,13 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
224224
}
225225

226226
def IREELinalgExt_GatherOp : IREELinalgExt_Op<"gather",
227-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
227+
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
228+
DeclareOpInterfaceMethods<TilingInterface,
229+
["getIterationDomain",
230+
"getLoopIteratorTypes",
231+
"getResultTilePosition",
232+
"getTiledImplementation",
233+
"generateResultTileValue"]>]> {
228234
let summary = "Gather operator";
229235
let description = [{
230236
Takes two inputs (`source` and `indices`) and outputs value (`output`).

compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
8888
SmallVector<Range> ranges;
8989
for (auto dim : llvm::seq<int64_t>(0, getUpdateType().getRank())) {
9090
OpFoldResult ub = getDim(builder, loc, getUpdates(), dim);
91-
ranges.emplace_back(Range{zero, ub, one});
91+
ranges.push_back(Range{zero, ub, one});
9292
}
9393
return ranges;
9494
}
@@ -277,6 +277,117 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
277277
return success();
278278
}
279279

280+
//===----------------------------------------------------------------------===//
281+
// GatherOp
282+
//===----------------------------------------------------------------------===//
283+
284+
/// Returns the iterator type of every loop of the gather: one entry per
/// dimension of the output, each of them `parallel`.
SmallVector<utils::IteratorType> GatherOp::getLoopIteratorTypes() {
  const int64_t numLoops = getOutputType().getRank();
  SmallVector<utils::IteratorType> iteratorTypes(
      numLoops, utils::IteratorType::parallel);
  return iteratorTypes;
}
288+
289+
/// Returns the iteration domain of the gather: one range per dimension of the
/// output, each starting at 0, ending at the size of that output dimension and
/// advancing with step 1.
SmallVector<Range> GatherOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  const OpFoldResult lowerBound = builder.getIndexAttr(0);
  const OpFoldResult step = builder.getIndexAttr(1);
  const int64_t outputRank = getOutputType().getRank();

  SmallVector<Range> domain;
  for (int64_t dim = 0; dim < outputRank; ++dim) {
    // The upper bound of each loop is the extent of the matching output dim.
    OpFoldResult upperBound = getDim(builder, loc, getOutput(), dim);
    domain.push_back(Range{lowerBound, upperBound, step});
  }
  return domain;
}
300+
301+
/// Creates the tiled version of this gather for the tile of the iteration
/// space described by `offsets` and `sizes`.
///
/// Three slices are created, all with unit strides:
///   - output:  sliced with `offsets`/`sizes` as given.
///   - indices: sliced with the leading `getBatchRank()` entries of
///              `offsets`/`sizes`; when the indices rank differs from the
///              batch rank, a trailing dimension with offset 0 and size
///              `getIndexDepth()` is appended.
///   - source:  the first `getIndexDepth()` dimensions are taken whole
///              (offset 0, full size); the remaining dimensions use the
///              entries of `offsets`/`sizes` starting at position
///              `getBatchRank()`.
///
/// The op is then cloned onto the sliced operands. Returns the cloned op, its
/// results and the three slice ops, or an op error if any slice could not be
/// created.
FailureOr<TilingResult>
GatherOp::getTiledImplementation(OpBuilder &builder,
                                 ArrayRef<OpFoldResult> offsets,
                                 ArrayRef<OpFoldResult> sizes) {
  assert(offsets.size() >= 1 && sizes.size() >= 1);
  Location loc = getLoc();
  auto zero = builder.getI64IntegerAttr(0);
  auto one = builder.getI64IntegerAttr(1);
  auto batchRank = getBatchRank();

  // Output tile: the iteration space maps one-to-one onto the output dims.
  SmallVector<OpFoldResult> outputStrides(getOutputType().getRank(), one);
  Operation *outputSlice =
      getSlice(builder, loc, getOutput(), offsets, sizes, outputStrides);
  if (!outputSlice) {
    return emitOpError("failed to get result slice");
  }

  // Indices tile: only the batch dims follow the tile; a trailing index-depth
  // dim, when present, is kept in full.
  auto indicesRank = getIndicesType().getRank();
  SmallVector<OpFoldResult> indicesOffsets(offsets.take_front(batchRank));
  SmallVector<OpFoldResult> indicesSizes(sizes.take_front(batchRank));
  if (batchRank != indicesRank) {
    indicesOffsets.push_back(zero);
    indicesSizes.push_back(builder.getIndexAttr(getIndexDepth()));
  }
  SmallVector<OpFoldResult> indicesStrides(indicesRank, one);
  Operation *indicesSlice = getSlice(builder, loc, getIndices(), indicesOffsets,
                                     indicesSizes, indicesStrides);
  if (!indicesSlice) {
    return emitOpError("failed to get indices slices");
  }

  // Source tile: the first `indexDepth` dims are not tiled, the rest reuse the
  // tile offsets/sizes that follow the batch dims.
  auto sourceRank = getSourceType().getRank();
  auto indexDepth = getIndexDepth();
  auto numTiledSourceDims = sourceRank - indexDepth;
  SmallVector<OpFoldResult> sourceOffsets(indexDepth, zero);
  SmallVector<OpFoldResult> sourceSizes;
  for (int64_t dim = 0; dim < indexDepth; ++dim) {
    sourceSizes.push_back(getDim(builder, loc, getSource(), dim));
  }
  llvm::append_range(sourceOffsets,
                     offsets.slice(batchRank, numTiledSourceDims));
  llvm::append_range(sourceSizes, sizes.slice(batchRank, numTiledSourceDims));
  SmallVector<OpFoldResult> sourceStrides(sourceRank, one);
  Operation *sourceSlice = getSlice(builder, loc, getSource(), sourceOffsets,
                                    sourceSizes, sourceStrides);
  if (!sourceSlice) {
    return emitOpError("failed to get source tensor slice");
  }

  // Clone the op onto the sliced operands. A result type is only needed when
  // the op actually produces a result.
  Value tiledSource = sourceSlice->getResult(0);
  Value tiledIndices = indicesSlice->getResult(0);
  Value tiledOutput = outputSlice->getResult(0);
  SmallVector<Type> tiledResultTypes;
  if (getNumResults() != 0) {
    tiledResultTypes.push_back(tiledOutput.getType());
  }
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), tiledResultTypes,
                  ValueRange{tiledSource, tiledIndices, tiledOutput});

  SmallVector<Operation *> slices = {sourceSlice, indicesSlice, outputSlice};
  return TilingResult{
      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
}
374+
375+
/// Reports the position of the result tile produced for the iteration-space
/// tile `offsets`/`sizes`. The offsets and sizes are copied through unchanged
/// into `resultOffsets`/`resultSizes` (any previous contents are discarded);
/// `builder` and `resultNumber` are not used. Always succeeds.
LogicalResult GatherOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets.clear();
  llvm::append_range(resultOffsets, offsets);

  resultSizes.clear();
  llvm::append_range(resultSizes, sizes);

  return success();
}
383+
384+
/// Generates the tile of the result described by `offsets`/`sizes` by
/// forwarding them directly to `getTiledImplementation`; `resultNumber` is
/// not consulted.
FailureOr<TilingResult>
GatherOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  FailureOr<TilingResult> tiledGather =
      getTiledImplementation(builder, offsets, sizes);
  return tiledGather;
}
390+
280391
//===----------------------------------------------------------------------===//
281392
// SortOp
282393
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2584,3 +2584,85 @@ module attributes { transform.with_named_sequence } {
25842584
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]](%[[IV]])[%[[NEW_INDEX]]]
25852585
// CHECK: linalg.generic
25862586
// CHECK-SAME: ins(%{{.+}}, %[[INDEX]] :
2587+
2588+
// -----
2589+
2590+
// Test input: a gather on dynamically shaped memrefs whose indices operand is
// rank-1 and whose `dimension_map` is [0].
func.func @gather_1d_indices(%source : memref<?x?xi32>, %indices : memref<?xi32>, %output : memref<?x?xi32>) {
  iree_linalg_ext.gather
    dimension_map = [0]
    ins(%source, %indices: memref<?x?xi32>, memref<?xi32>)
    outs(%output: memref<?x?xi32>) {
  ^bb0(%a: i32, %b: i32):
    iree_linalg_ext.yield %a : i32
  }
  return
}
2600+
// Transform script: matches the gather op and tiles it with scf.for loops
// using tile sizes [10, 20], producing two loops.
module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %gather = transform.structured.match ops{["iree_linalg_ext.gather"]} in %root : (!transform.any_op) -> !transform.any_op
    %tiled, %loops:2 = transform.structured.tile_using_for %gather tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
2607+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
2608+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
2609+
// CHECK: func @gather_1d_indices
2610+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
2611+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
2612+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
2613+
// CHECK-DAG: %[[C20:.+]] = arith.constant 20
2614+
// CHECK-DAG: %[[C10:.+]] = arith.constant 10
2615+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
2616+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
2617+
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG2]], %[[C0]]
2618+
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG2]], %[[C1]]
2619+
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]]
2620+
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C20]]
2621+
// CHECK-DAG: %[[MIN0:.+]] = affine.min #[[MAP0]](%[[I]])[%[[D0]]]
2622+
// CHECK-DAG: %[[MIN1:.+]] = affine.min #[[MAP1]](%[[J]])[%[[D1]]]
2623+
// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[ARG2]][%[[I]], %[[J]]] [%[[MIN0]], %[[MIN1]]] [1, 1]
2624+
// CHECK-DAG: %[[INDEX:.+]] = memref.subview %[[ARG1]][%[[I]]] [%[[MIN0]]] [1]
2625+
// CHECK-DAG: %[[D2:.+]] = memref.dim %[[ARG0]], %[[C0]]
2626+
// CHECK-DAG: %[[SOURCE:.+]] = memref.subview %[[ARG0]][0, %[[J]]] [%[[D2]], %[[MIN1]]] [1, 1]
2627+
// CHECK: iree_linalg_ext.gather
2628+
// CHECK-SAME: ins(%[[SOURCE]], %[[INDEX]]
2629+
// CHECK-SAME: outs(%[[RESULT]]
2630+
2631+
// -----
2632+
2633+
// Test input: a gather on memrefs whose indices operand is rank-2 with a
// static trailing dimension of 2 and whose `dimension_map` is [0, 1].
func.func @gather_2d_indices(%source : memref<?x?xi32>, %indices : memref<?x2xi32>, %output : memref<?xi32>) {
  iree_linalg_ext.gather
    dimension_map = [0, 1]
    ins(%source, %indices: memref<?x?xi32>, memref<?x2xi32>)
    outs(%output: memref<?xi32>) {
  ^bb0(%a: i32, %b: i32):
    iree_linalg_ext.yield %a : i32
  }
  return
}
2643+
// Transform script: matches the gather op and tiles it with a single scf.for
// loop using tile size [13].
module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %gather = transform.structured.match ops{["iree_linalg_ext.gather"]} in %root : (!transform.any_op) -> !transform.any_op
    %tiled, %loop = transform.structured.tile_using_for %gather tile_sizes [13] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
2650+
// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 13)>
2651+
// CHECK: func @gather_2d_indices
2652+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
2653+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
2654+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
2655+
// CHECK-DAG: %[[C13:.+]] = arith.constant 13
2656+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
2657+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
2658+
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG2]], %[[C0]]
2659+
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C13]]
2660+
// CHECK-DAG: %[[MIN:.+]] = affine.min #[[MAP0]](%[[I]])[%[[D0]]]
2661+
// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[ARG2]][%[[I]]] [%[[MIN]]] [1]
2662+
// CHECK-DAG: %[[INDEX:.+]] = memref.subview %[[ARG1]][%[[I]], 0] [%[[MIN]], 2] [1, 1]
2663+
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C0]]
2664+
// CHECK-DAG: %[[D2:.+]] = memref.dim %[[ARG0]], %[[C1]]
2665+
// CHECK-DAG: %[[SOURCE:.+]] = memref.subview %[[ARG0]][0, 0] [%[[D1]], %[[D2]]] [1, 1]
2666+
// CHECK: iree_linalg_ext.gather
2667+
// CHECK-SAME: ins(%[[SOURCE]], %[[INDEX]]
2668+
// CHECK-SAME: outs(%[[RESULT]]

0 commit comments

Comments
 (0)