Skip to content

Commit d3cfe11

Browse files
authored
[GPU] Set insertion point to last slice index operand in reshape and slice fusion (#19959)
Empty tensor elimination relies on dominance of SSA values when attempting to reuse buffers for slices of init operands. Ideally, empty tensor elimination would itself handle slice index operands that do not dominate the point where the buffer is reused, but that is difficult to fix at that level. For now, this PR avoids creating these dominance issues in the first place by inserting the new index computations right after the last slice index operand. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 99304ff commit d3cfe11

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_collapse_shape_with_forall.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ module attributes {transform.with_named_sequence} {
9898
// CHECK-SAME: {{\[}}[0], [1, 2]] output_shape [%[[SIZE0]], %[[SIZE1]], 8] : tensor<?x?xf32> into tensor<?x?x8xf32>
9999
// CHECK-DAG: %[[SLICE_SIZE_0:.+]] = affine.min #map(%[[IDX0]])[%[[SIZE0]]]
100100
// CHECK-DAG: %[[SLICE_SIZE_1:.+]] = affine.min #map(%[[IDX1]])[%[[SIZE1]]]
101+
// CHECK-DAG: %[[LINEAR_SLICE_IDX:.+]] = affine.linearize_index disjoint [%[[IDX1]], %[[C0]]] by (%[[SIZE1]], 8) : index
102+
// CHECK-DAG: %[[COLLAPSED_SLICE_SIZE:.+]] = affine.apply #[[$MAP1]](%[[SLICE_SIZE_1]])
101103
// CHECK-DAG: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]]
102104
// CHECK-SAME: [%[[IDX0]], %[[IDX1]], 0]{{.*}}[%[[SLICE_SIZE_0]], %[[SLICE_SIZE_1]], 8] [1, 1, 1] : tensor<?x?x8xf32> to tensor<?x?x8xf32>
103105
// CHECK-DAG: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[EXPANDED_BBARG]]
104106
// CHECK-SAME: [%[[IDX0]], %[[IDX1]], 0] [%[[SLICE_SIZE_0]], %[[SLICE_SIZE_1]], 8] [1, 1, 1] : tensor<?x?x8xf32> to tensor<?x?x8xf32>
105107
// CHECK-DAG: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<?x?x8xf32>) outs(%[[OUT_SLICE]] : tensor<?x?x8xf32>) -> tensor<?x?x8xf32>
106-
// CHECK-DAG: %[[LINEAR_SLICE_IDX:.+]] = affine.linearize_index disjoint [%[[IDX1]], %[[C0]]] by (%[[SIZE1]], 8) : index
107-
// CHECK-DAG: %[[COLLAPSED_SLICE_SIZE:.+]] = affine.apply #[[$MAP1]](%[[SLICE_SIZE_1]])
108108
// CHECK-DAG: %[[COLLAPSED_COPY:.+]] = tensor.collapse_shape %[[COPY]] {{\[}}[0], [1, 2]] : tensor<?x?x8xf32> into tensor<?x?xf32>
109109
// CHECK: scf.forall.in_parallel {
110110
// CHECK: tensor.parallel_insert_slice %[[COLLAPSED_COPY]] into %[[COLLAPSED_BBARG]]

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,43 @@ collapsibleSlicePrecondition(RewriterBase &rewriter,
354354
return success();
355355
}
356356

357+
/// Given a tensor.parallel_insert_slice op, find all values that are needed to
358+
/// build an equivalent subset extract_slice, and set the insertion point to the
359+
/// last of these values. This helper is useful in cases where additional index
360+
/// computation must be composed with the current indexing operations for the
361+
/// slice, since we want all index operations for the slice to retain the same
362+
/// level of dominance after composing the new computation.
363+
static Operation *
364+
setInsertionPointAfterLastIndexOperand(RewriterBase &rewriter,
365+
tensor::ParallelInsertSliceOp op) {
366+
DominanceInfo domInfo;
367+
auto subsetOp = cast<SubsetInsertionOpInterface>(op.getOperation());
368+
SmallVector<Value> values = subsetOp.getValuesNeededToBuildSubsetExtraction();
369+
Operation *lastOp = nullptr;
370+
bool setInsertionPointBefore = false;
371+
for (auto val : values) {
372+
auto definingOp = val.getDefiningOp();
373+
if (!definingOp) {
374+
definingOp =
375+
&cast<BlockArgument>(val).getOwner()->getOperations().front();
376+
}
377+
if (!definingOp || (lastOp && domInfo.dominates(definingOp, lastOp)))
378+
continue;
379+
lastOp = definingOp;
380+
381+
// For block arguments we want the insertion point to be at the start of
382+
// the block, so we need to set the insertion point before the first op
383+
// in the block.
384+
setInsertionPointBefore = isa<BlockArgument>(val);
385+
}
386+
if (setInsertionPointBefore) {
387+
rewriter.setInsertionPoint(lastOp);
388+
} else {
389+
rewriter.setInsertionPointAfter(lastOp);
390+
}
391+
return lastOp;
392+
}
393+
357394
/// Collapse all `ops` with the given `reassociations`. All `ops` are expected
358395
/// to have equivalent offsets, sizes, and strides. All strides are expected to
359396
/// be 1. This function assumes that the parallelInsertOp passes the
@@ -363,8 +400,9 @@ collapseParallelInsertOp(RewriterBase &rewriter,
363400
tensor::ParallelInsertSliceOp parallelInsertOp,
364401
SmallVector<ReassociationIndices> reassociations) {
365402
// Compute the collapsed offsets, sizes, and strides.
366-
rewriter.setInsertionPoint(parallelInsertOp.getParallelCombiningParent());
367-
Location loc = parallelInsertOp->getLoc();
403+
Operation *lastOp =
404+
setInsertionPointAfterLastIndexOperand(rewriter, parallelInsertOp);
405+
Location loc = lastOp->getLoc();
368406
int64_t resultIdx = parallelInsertOp.getTiedOpResult().getResultNumber();
369407
auto forallOp = parallelInsertOp->getParentOfType<scf::ForallOp>();
370408
Value loopInit = forallOp.getOutputs()[resultIdx];
@@ -555,8 +593,9 @@ clampParallelInsertSliceOp(RewriterBase &rewriter,
555593
tensor::ParallelInsertSliceOp parallelInsertOp,
556594
SmallVector<OpFoldResult> upperBoundSizes) {
557595
OpBuilder::InsertionGuard g(rewriter);
558-
rewriter.setInsertionPoint(parallelInsertOp.getParallelCombiningParent());
559-
Location loc = parallelInsertOp.getParallelCombiningParent()->getLoc();
596+
Operation *lastOp =
597+
setInsertionPointAfterLastIndexOperand(rewriter, parallelInsertOp);
598+
Location loc = lastOp->getLoc();
560599

561600
// Clamp the parallel_insert_slice sizes to fit within the full result tensor.
562601
SmallVector<OpFoldResult> offsets = parallelInsertOp.getMixedOffsets();

0 commit comments

Comments
 (0)