Skip to content

Commit 96829a4

Browse files
authored
Temporary automatic reference counting(ish) pass for inserting async deallocations. (#20765)
This performs local analysis only and bails on almost any case other than the most trivial (calls, control flow (cf/scf), etc). It does seem to work well for current programs that require it, though. The intent is that the ARC pass would grow to cover all programs via global analysis and handle program boundary cases around I/O using the new `stream.resource.retain`/`stream.resource.release` ops. Currently those cases are ignored. This opens the opportunity for some in-compiler reuse analysis that can happen after the ARC pass runs to take dealloca -> (join) -> alloca sequences of the same affinity and size and reuse them. In extremely fragmented programs (like sharded tensor-level parallel models) this could eliminate nearly all allocations within the program. A simple local ReuseAllocationsPass was added to handle the basic cases and in TP 405B that reduces the total number of allocations from 36k to 4k (which is still way too high). Future global passes that track timelines better or options that allow users to aggressively reuse allocations that may be non-temporally adjacent could drop that number significantly now that the deallocations are modeled. A few canonicalization patterns were added for common cases that we want to eagerly fix in the IR, such as erasing unused allocations and flattening deallocation chains. There are some timepoint patterns required but more testing is required to know whether they are performance positive/neutral. On the TP 405b model we now end up with perfect pairings of deallocas to allocas in all but the boundary cases: https://gist.github.com/benvanik/6f7a8abdca4fca389955882e6e98cf9d After the new ReuseAllocationsPass the model is able to drop most transients moving between devices: ![image](https://github.com/user-attachments/assets/1c8be76f-8370-4d86-a400-3f2025f080fb)
1 parent 3470dbb commit 96829a4

File tree

12 files changed

+1531
-5
lines changed

12 files changed

+1531
-5
lines changed

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp

Lines changed: 155 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,159 @@ void ResourceAllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
507507
// stream.resource.alloca
508508
//===----------------------------------------------------------------------===//
509509

510+
namespace {
511+
512+
// Elides transient allocations that have no uses of their resource.
513+
// This sometimes arises when operations that were using the resource are
514+
// DCEd by other patterns or passes. The ElideAllocaDeallocaOp pattern will be
515+
// used after deallocations have been inserted but prior to that point this
516+
// pattern allows for more eager removal of unused allocations.
517+
struct ElideUnusedAllocaOp : public OpRewritePattern<ResourceAllocaOp> {
518+
using OpRewritePattern::OpRewritePattern;
519+
LogicalResult matchAndRewrite(ResourceAllocaOp allocaOp,
520+
PatternRewriter &rewriter) const override {
521+
if (!allocaOp.getResult().use_empty()) {
522+
return failure(); // >= 1 user
523+
}
524+
Value newTimepoint = allocaOp.getAwaitTimepoint();
525+
if (!newTimepoint) {
526+
newTimepoint = rewriter.create<IREE::Stream::TimepointImmediateOp>(
527+
allocaOp.getLoc());
528+
}
529+
rewriter.replaceAllUsesWith(allocaOp.getResultTimepoint(), newTimepoint);
530+
rewriter.eraseOp(allocaOp);
531+
return success();
532+
}
533+
};
534+
535+
// Elides transient allocations that are only used by deallocations.
536+
// This sometimes arises when operations that were using the resource are
537+
// DCEd by other patterns or passes.
538+
//
539+
// Example:
540+
// %resource, %alloca_t = stream.resource.alloca
541+
// %dealloca_t = stream.resource.dealloca await(%alloca_t) %resource
542+
struct ElideAllocaDeallocaOp : public OpRewritePattern<ResourceAllocaOp> {
543+
using OpRewritePattern::OpRewritePattern;
544+
LogicalResult matchAndRewrite(ResourceAllocaOp allocaOp,
545+
PatternRewriter &rewriter) const override {
546+
if (!allocaOp.getResult().hasOneUse()) {
547+
return failure(); // more than one user
548+
}
549+
auto user = *allocaOp.getResult().getUsers().begin();
550+
auto deallocaOp = dyn_cast<IREE::Stream::ResourceDeallocaOp>(user);
551+
if (!deallocaOp) {
552+
return failure(); // not used by a dealloca
553+
}
554+
555+
// Replace waiters on the alloca and dealloca.
556+
// Note that the dealloca may be using the timepoint of the alloca so we
557+
// replace that first.
558+
Value newAllocaTimepoint = allocaOp.getAwaitTimepoint();
559+
if (!newAllocaTimepoint) {
560+
newAllocaTimepoint = rewriter.create<IREE::Stream::TimepointImmediateOp>(
561+
allocaOp.getLoc());
562+
}
563+
rewriter.replaceAllUsesWith(allocaOp.getResultTimepoint(),
564+
newAllocaTimepoint);
565+
Value newDeallocaTimepoint = deallocaOp.getAwaitTimepoint();
566+
if (!newDeallocaTimepoint) {
567+
newDeallocaTimepoint =
568+
rewriter.create<IREE::Stream::TimepointImmediateOp>(
569+
deallocaOp.getLoc());
570+
}
571+
rewriter.replaceAllUsesWith(deallocaOp.getResultTimepoint(),
572+
newDeallocaTimepoint);
573+
574+
// Erase the deallocation first (its the only user of the allocated
575+
// resource).
576+
rewriter.eraseOp(deallocaOp);
577+
rewriter.eraseOp(allocaOp);
578+
579+
return success();
580+
}
581+
};
582+
583+
// Finds sequences of chained allocas/deallocas and rewrites them to batch as
584+
// many as possible on a single timepoint. This is done as a canonicalization as
585+
// it is always intended that allocations and deallocations do not wait and we
586+
// can repeatedly optimize when run as part of a larger canonicalization pass
587+
// that cleans up timepoints with other patterns as we modify them here.
588+
//
589+
// Example:
590+
// %d0 = dealloca await(%t)
591+
// %d1 = dealloca await(%d0)
592+
// %d2 = dealloca await(%d1)
593+
// %d3 = dealloca await(%d2)
594+
// ... await(%d3)
595+
// ->
596+
// %d0 = dealloca await(%t)
597+
// %d1 = dealloca await(%t)
598+
// %d2 = dealloca await(%t)
599+
// %d3 = dealloca await(%t)
600+
// %j = join %d0, %d1, %d2, %d3
601+
// ... await(%j)
602+
template <typename OpT>
603+
struct BatchAllocaOps : public OpRewritePattern<OpT> {
604+
using OpRewritePattern<OpT>::OpRewritePattern;
605+
LogicalResult matchAndRewrite(OpT op,
606+
PatternRewriter &rewriter) const override {
607+
// Gather alloca ops chained on timepoints starting from this op.
608+
SmallVector<OpT> allocaOps;
609+
OpT nextOp = op;
610+
while (nextOp) {
611+
Value resultTimepoint = nextOp.getResultTimepoint();
612+
allocaOps.push_back(nextOp);
613+
if (!resultTimepoint.hasOneUse()) {
614+
break;
615+
}
616+
nextOp = dyn_cast<OpT>(*resultTimepoint.user_begin());
617+
}
618+
if (allocaOps.size() <= 1) {
619+
return failure(); // no-op if only one op
620+
}
621+
622+
// Gather the result timepoints of all alloca ops so we can join on them.
623+
// We'll issue all of them concurrently and only join after all
624+
// deallocations complete.
625+
SmallVector<Location> allocaLocs;
626+
SmallVector<Value> allocaTimepoints;
627+
for (auto allocaOp : allocaOps) {
628+
allocaLocs.push_back(allocaOp.getLoc());
629+
allocaTimepoints.push_back(allocaOp.getResultTimepoint());
630+
}
631+
rewriter.setInsertionPointAfter(allocaOps.back());
632+
auto joinOp = rewriter.create<IREE::Stream::TimepointJoinOp>(
633+
rewriter.getFusedLoc(allocaLocs),
634+
rewriter.getType<IREE::Stream::TimepointType>(), allocaTimepoints);
635+
636+
// Make all alloca ops wait on the earliest timepoint so they can proceed
637+
// together. Note that the origin op may be waiting on an immediate
638+
// timepoint and be nullptr.
639+
Value awaitTimepoint = op.getAwaitTimepoint();
640+
for (auto allocaOp : allocaOps) {
641+
rewriter.modifyOpInPlace(allocaOp, [&]() {
642+
allocaOp.getAwaitTimepointMutable().assign(awaitTimepoint);
643+
});
644+
}
645+
646+
// Replace the tail timepoint in the alloca chain with the join result so
647+
// subsequent waiters are waiting on the batch.
648+
rewriter.replaceAllUsesExcept(allocaOps.back().getResultTimepoint(),
649+
joinOp.getResultTimepoint(), joinOp);
650+
651+
return success();
652+
}
653+
};
654+
655+
} // namespace
656+
510657
void ResourceAllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
511658
MLIRContext *context) {
512659
// TODO(benvanik): sink to first user.
513-
// TODO(benvanik): elide if only user is dealloc.
660+
results.insert<ElideUnusedAllocaOp>(context);
661+
results.insert<ElideAllocaDeallocaOp>(context);
662+
results.insert<BatchAllocaOps<ResourceAllocaOp>>(context);
514663
results.insert<ElideImmediateTimepointWait<ResourceAllocaOp>>(context);
515664
}
516665

@@ -521,6 +670,7 @@ void ResourceAllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
521670
void ResourceDeallocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
522671
MLIRContext *context) {
523672
// TODO(benvanik): move up to producer of timepoint.
673+
results.insert<BatchAllocaOps<ResourceDeallocaOp>>(context);
524674
results.insert<ElideImmediateTimepointWait<ResourceDeallocaOp>>(context);
525675
}
526676

@@ -2827,10 +2977,10 @@ namespace {
28272977

28282978
// Elides a region-carrying op when the region is empty.
28292979
// Requires no results that need replacement.
2830-
template <typename Op>
2831-
struct ElideEmptyCmdRegionOp : public OpRewritePattern<Op> {
2832-
using OpRewritePattern<Op>::OpRewritePattern;
2833-
LogicalResult matchAndRewrite(Op op,
2980+
template <typename OpT>
2981+
struct ElideEmptyCmdRegionOp : public OpRewritePattern<OpT> {
2982+
using OpRewritePattern<OpT>::OpRewritePattern;
2983+
LogicalResult matchAndRewrite(OpT op,
28342984
PatternRewriter &rewriter) const override {
28352985
auto &entryBlock = op.getBody().front();
28362986
auto yieldOp = getYieldIfOnlyOp(entryBlock);

compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,73 @@ util.func private @SelectResourceSizeOp(%arg0: !stream.resource<staging>, %arg1:
3030

3131
// -----
3232

33+
// Erases allocation ops that have no users of their allocated resource.
34+
35+
// CHECK-LABEL: @ElideUnusedAllocaOp
36+
// CHECK-SAME: (%[[AWAIT_TIMEPOINT:.+]]: !stream.timepoint, %[[SIZE:.+]]: index)
37+
util.func private @ElideUnusedAllocaOp(%await_timepoint: !stream.timepoint, %size: index) -> (!stream.timepoint, !stream.timepoint) {
38+
// CHECK-NOT: stream.resource.alloca
39+
// CHECK: %[[IMMEDIATE_TIMEPOINT:.+]] = stream.timepoint.immediate
40+
%resource0, %immediate_timepoint = stream.resource.alloca uninitialized : !stream.resource<transient>{%size} => !stream.timepoint
41+
// CHECK-NOT: stream.resource.alloca
42+
%resource1, %alloca_timepoint = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint
43+
// CHECK: util.return %[[IMMEDIATE_TIMEPOINT]], %[[AWAIT_TIMEPOINT]]
44+
util.return %immediate_timepoint, %alloca_timepoint : !stream.timepoint, !stream.timepoint
45+
}
46+
47+
// -----
48+
49+
// Erases allocation ops that are only ever used by a deallocation.
50+
51+
// CHECK-LABEL: @ElideAllocaDeallocaOp
52+
// CHECK-SAME: (%[[AWAIT_TIMEPOINT:.+]]: !stream.timepoint, %[[SIZE:.+]]: index)
53+
util.func private @ElideAllocaDeallocaOp(%await_timepoint: !stream.timepoint, %size: index) -> (!stream.timepoint, !stream.timepoint) {
54+
// CHECK-NOT: stream.resource.alloca
55+
%resource, %alloca_timepoint = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint
56+
// CHECK-NOT: stream.resource.dealloca
57+
%dealloca_timepoint = stream.resource.dealloca origin await(%alloca_timepoint) => %resource : !stream.resource<transient>{%size} => !stream.timepoint
58+
// CHECK: util.return %[[AWAIT_TIMEPOINT]], %[[AWAIT_TIMEPOINT]]
59+
util.return %alloca_timepoint, %dealloca_timepoint : !stream.timepoint, !stream.timepoint
60+
}
61+
62+
// -----
63+
64+
// CHECK-LABEL: @BatchAllocaOps
65+
// CHECK-SAME: (%[[AWAIT_TIMEPOINT:.+]]: !stream.timepoint, %[[SIZE:.+]]: index)
66+
util.func private @BatchAllocaOps(%await_timepoint: !stream.timepoint, %size: index) -> (!stream.resource<transient>, !stream.resource<transient>, !stream.resource<transient>, !stream.resource<transient>, !stream.timepoint) {
67+
// CHECK: %[[ALLOCA0:.+]], %[[ALLOCA0_TIMEPOINT:.+]] = stream.resource.alloca uninitialized await(%[[AWAIT_TIMEPOINT]])
68+
%alloca0, %alloca0_timepoint = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint
69+
// CHECK: %[[ALLOCA1:.+]], %[[ALLOCA1_TIMEPOINT:.+]] = stream.resource.alloca uninitialized await(%[[AWAIT_TIMEPOINT]])
70+
%alloca1, %alloca1_timepoint = stream.resource.alloca uninitialized await(%alloca0_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint
71+
// CHECK: %[[ALLOCA2:.+]], %[[ALLOCA2_TIMEPOINT:.+]] = stream.resource.alloca uninitialized await(%[[AWAIT_TIMEPOINT]])
72+
%alloca2, %alloca2_timepoint = stream.resource.alloca uninitialized await(%alloca1_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint
73+
// CHECK: %[[ALLOCA3:.+]], %[[ALLOCA3_TIMEPOINT:.+]] = stream.resource.alloca uninitialized await(%[[AWAIT_TIMEPOINT]])
74+
%alloca3, %alloca3_timepoint = stream.resource.alloca uninitialized await(%alloca2_timepoint) => !stream.resource<transient>{%size} => !stream.timepoint
75+
// CHECK: %[[JOIN_TIMEPOINT:.+]] = stream.timepoint.join max(%[[ALLOCA0_TIMEPOINT]], %[[ALLOCA1_TIMEPOINT]], %[[ALLOCA2_TIMEPOINT]], %[[ALLOCA3_TIMEPOINT]]) => !stream.timepoint
76+
// CHECK: util.return %[[ALLOCA0]], %[[ALLOCA1]], %[[ALLOCA2]], %[[ALLOCA3]], %[[JOIN_TIMEPOINT]]
77+
util.return %alloca0, %alloca1, %alloca2, %alloca3, %alloca3_timepoint : !stream.resource<transient>, !stream.resource<transient>, !stream.resource<transient>, !stream.resource<transient>, !stream.timepoint
78+
}
79+
80+
// -----
81+
82+
// CHECK-LABEL: @BatchDeallocaOps
83+
// CHECK-SAME: (%[[AWAIT_TIMEPOINT:.+]]: !stream.timepoint, %[[RESOURCE0:.+]]: !stream.resource<transient>, %[[RESOURCE1:.+]]: !stream.resource<transient>, %[[RESOURCE2:.+]]: !stream.resource<transient>, %[[RESOURCE3:.+]]: !stream.resource<transient>, %[[SIZE:.+]]: index)
84+
util.func private @BatchDeallocaOps(%await_timepoint: !stream.timepoint, %resource0: !stream.resource<transient>, %resource1: !stream.resource<transient>, %resource2: !stream.resource<transient>, %resource3: !stream.resource<transient>, %size: index) -> !stream.timepoint {
85+
// CHECK: %[[DEALLOCA0_TIMEPOINT:.+]] = stream.resource.dealloca origin await(%[[AWAIT_TIMEPOINT]]) => %[[RESOURCE0]]
86+
%dealloca0_timepoint = stream.resource.dealloca origin await(%await_timepoint) => %resource0 : !stream.resource<transient>{%size} => !stream.timepoint
87+
// CHECK: %[[DEALLOCA1_TIMEPOINT:.+]] = stream.resource.dealloca origin await(%[[AWAIT_TIMEPOINT]]) => %[[RESOURCE1]]
88+
%dealloca1_timepoint = stream.resource.dealloca origin await(%dealloca0_timepoint) => %resource1 : !stream.resource<transient>{%size} => !stream.timepoint
89+
// CHECK: %[[DEALLOCA2_TIMEPOINT:.+]] = stream.resource.dealloca origin await(%[[AWAIT_TIMEPOINT]]) => %[[RESOURCE2]]
90+
%dealloca2_timepoint = stream.resource.dealloca origin await(%dealloca1_timepoint) => %resource2 : !stream.resource<transient>{%size} => !stream.timepoint
91+
// CHECK: %[[DEALLOCA3_TIMEPOINT:.+]] = stream.resource.dealloca origin await(%[[AWAIT_TIMEPOINT]]) => %[[RESOURCE3]]
92+
%dealloca3_timepoint = stream.resource.dealloca origin await(%dealloca2_timepoint) => %resource3 : !stream.resource<transient>{%size} => !stream.timepoint
93+
// CHECK: %[[JOIN_TIMEPOINT:.+]] = stream.timepoint.join max(%[[DEALLOCA0_TIMEPOINT]], %[[DEALLOCA1_TIMEPOINT]], %[[DEALLOCA2_TIMEPOINT]], %[[DEALLOCA3_TIMEPOINT]]) => !stream.timepoint
94+
// CHECK: util.return %[[JOIN_TIMEPOINT]]
95+
util.return %dealloca3_timepoint : !stream.timepoint
96+
}
97+
98+
// -----
99+
33100
// CHECK-LABEL: @FoldSubviewIntoLoadOp
34101
util.func private @FoldSubviewIntoLoadOp(%arg0: !stream.resource<staging>, %arg1: index) -> i32 {
35102
%c64 = arith.constant 64 : index

0 commit comments

Comments
 (0)