Skip to content

Commit 9df82fd

Browse files
authored
[LinalgExt] Add support for fusing scatter with producers (iree-org#19584)
This adds implementations for "getIterationDomainTileFromOperandTile" and "getTiledImplementationFromOperandTile" to iree_linalg_ext.scatter. This allows fusing scatters with producer loops during tiling. The implementation of these methods is trivial because the iteration domain is already defined in terms of the input operands, so we can just invoke the tiling implementation.
1 parent 5a97523 commit 9df82fd

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,3 +672,32 @@ func.func @v_shaped_graph(%0: tensor<12xf32>, %1: tensor<12xf32>) -> tensor<12xf
672672
// CHECK-DAG: %[[RIGHT:.+]] = linalg.generic {{.*}} ins(%[[SLICE1]]
673673
// CHECK: linalg.generic {{.*}} ins(%[[LEFT]], %[[RIGHT]]
674674
// CHECK: return %[[RESULT]]
675+
676+
// -----
677+
678+
// The linalg.add producer carries the workgroup tiling config and its result
// is the update operand of the scatter. The CHECK lines after this function
// expect both ops to appear inside a single scf.forall.
func.func @consumer_fuse_scatter(%arg0: tensor<3x2048x2048xf32>,
679+
%arg1: tensor<3x2048x2048xf32>,
680+
%arg2: tensor<3x1xi32>) -> tensor<3x2048x2048xf32> {
681+
%cst = arith.constant 0.000000e+00 : f32
682+
%0 = tensor.empty() : tensor<3x2048x2048xf32>
683+
%1 = linalg.add {lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 256]}>}
684+
ins(%arg0, %arg1 : tensor<3x2048x2048xf32>, tensor<3x2048x2048xf32>) outs(%0 : tensor<3x2048x2048xf32>) -> tensor<3x2048x2048xf32>
685+
%2 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
686+
ins(%1, %arg2 : tensor<3x2048x2048xf32>, tensor<3x1xi32>) outs(%0 : tensor<3x2048x2048xf32>) {
687+
^bb0(%arg3: f32, %arg4: f32):
688+
iree_linalg_ext.yield %arg3 : f32
689+
} -> tensor<3x2048x2048xf32>
690+
return %2 : tensor<3x2048x2048xf32>
691+
}
692+
693+
// CHECK-LABEL: func @consumer_fuse_scatter(
694+
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<3x2048x2048xf32>
695+
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<3x2048x2048xf32>
696+
// CHECK-SAME: %[[IND:[A-Za-z0-9]+]]: tensor<3x1xi32>
697+
// CHECK: %[[RESULT:.+]] = scf.forall (%[[ID0:.+]], %[[ID1:.+]], %[[ID2:[A-Za-z0-9]+]]) {{.*}} shared_outs(%[[DEST:.+]] = %{{.*}})
698+
// CHECK-DAG: %[[SRC:.+]] = linalg.add
699+
// CHECK-DAG: %[[IND_SLICE:.+]] = tensor.extract_slice %[[IND]][%[[ID0]], 0] {{.*}} : tensor<3x1xi32> to tensor<1x1xi32>
700+
// CHECK-DAG: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]][0, %[[ID1]], %[[ID2]]] {{.*}} to tensor<3x1x256xf32>
701+
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
702+
// CHECK-SAME: ins(%[[SRC]], %[[IND_SLICE]]{{.*}} outs(%[[DEST_SLICE]]
703+
// CHECK: tensor.parallel_insert_slice %[[SCATTER]] into %[[DEST]][0, %[[ID1]], %[[ID2]]]

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
102102
"getIterationDomain",
103103
"getLoopIteratorTypes",
104104
"getResultTilePosition",
105-
"getTiledImplementation"]>]> {
105+
"getTiledImplementation",
106+
"getIterationDomainTileFromOperandTile",
107+
"getTiledImplementationFromOperandTile"]>]> {
106108
let summary = "Scatter operator";
107109
let description = [{
108110
Based on XLA operation semantics, takes two `inputs` (`update` and

compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,44 @@ LogicalResult ScatterOp::getResultTilePosition(
181181
return success();
182182
}
183183

184+
/// Maps a tile of operand `operandNumber`, described by `offsets`/`sizes`, to
/// the matching tile of the iteration domain, which is written to
/// `iterDomainOffsets`/`iterDomainSizes`. Returns failure unless
185+
/// `unique_indices` is set and `operandNumber` is where the inputs begin.
186+
LogicalResult ScatterOp::getIterationDomainTileFromOperandTile(
187+
OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
188+
ArrayRef<OpFoldResult> sizes,
189+
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
190+
SmallVectorImpl<OpFoldResult> &iterDomainSizes) {
191+
// Fusion with producers is not possible in general if `unique_indices` is not
192+
// true as reductions along the scattered indices are not tilable in parallel.
193+
if (!getUniqueIndices()) {
194+
return failure();
195+
}
196+
// TODO: Support fusion along the index operand. For the index operand, the
197+
// offset + size must be the full size for the inner most dim.
198+
if (getInputs().getBeginOperandIndex() != operandNumber) {
199+
return failure();
200+
}
201+
202+
// The iteration domain is defined in terms of the |input|, so the operand
203+
// tile is copied through unchanged as the iteration-domain tile.
204+
iterDomainOffsets.assign(offsets.begin(), offsets.end());
205+
iterDomainSizes.assign(sizes.begin(), sizes.end());
206+
return success();
207+
}
208+
209+
/// Method to generate the tiled implementation of the scatter from a tile
/// (`offsets`/`sizes`) of operand `operandNumber`. The operand tile is first
/// mapped to an iteration-domain tile with
210+
/// getIterationDomainTileFromOperandTile; if that mapping fails, so does this.
211+
FailureOr<TilingResult> ScatterOp::getTiledImplementationFromOperandTile(
212+
OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
213+
ArrayRef<OpFoldResult> sizes) {
214+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
215+
if (failed(getIterationDomainTileFromOperandTile(
216+
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
217+
return failure();
218+
}
219+
return getTiledImplementation(b, mappedOffsets, mappedSizes);
220+
}
221+
184222
LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
185223
Location loc,
186224
ValueRange ivs) {

0 commit comments

Comments
 (0)