Skip to content

Commit bc595c5

Browse files
authored
PR #5961 Review Fixes — Per-Fusion Statement Tracking (#6015)
## Summary

Review fixes for PR #5961 (Per-Fusion statement tracking):

- **O(n²) → O(n)**: Optimize `removeStatementsOwnedBy` with `std::erase_if`
- **Per-Fusion counts**: Convert `numExprs()`/`numVals()` to return per-Fusion counts instead of global counts
- **StatementGuard fixes**: Snapshot and compare per-Fusion counts for correct LIFO rollback in shared containers
- **LIFO assertions**: Verify tail elements belong to this Fusion before popping

## Tests

All tests pass:

- ✅ StatementGuardTest.ExecuteAfterGuard
- ✅ StatementGuardTest.LazySpecialValsNotDangling
- ✅ FusionCopy_CUDA
- ✅ FusionMove_CUDA
1 parent b8d202d commit bc595c5

File tree

4 files changed

+42
-35
lines changed

4 files changed

+42
-35
lines changed

csrc/fusion.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -427,24 +427,14 @@ void Fusion::removeStatementsCreatedAfter(
427427
int64_t num_vals_before) {
428428
auto* c = ir_container();
429429

430-
NVF_ERROR(
431-
c->exprs_up_.size() == c->exprs_.size(),
432-
"exprs_up_ (size ",
433-
c->exprs_up_.size(),
434-
") and exprs_ (size ",
435-
c->exprs_.size(),
436-
") are out of sync.");
437-
NVF_ERROR(
438-
std::ssize(c->exprs_up_) >= num_exprs_before,
439-
"exprs_up_ size (",
440-
std::ssize(c->exprs_up_),
441-
") is less than num_exprs_before (",
442-
num_exprs_before,
443-
").");
444-
445430
// Remove expressions before values because we need to change Val::uses_.
446-
while (std::ssize(c->exprs_up_) > num_exprs_before) {
431+
while (std::ssize(c->exprsOwnedBy(this)) > num_exprs_before) {
432+
// Pop from global deque back — statements created by this Fusion during
433+
// the guard scope are at the tail (LIFO invariant).
447434
Expr* e = c->exprs_up_.back().get();
435+
NVF_ERROR(
436+
c->per_fusion_exprs_[this].count(e) > 0,
437+
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
448438
for (Val* in : e->inputs()) {
449439
in->removeUse(e);
450440
}
@@ -453,8 +443,12 @@ void Fusion::removeStatementsCreatedAfter(
453443
c->exprs_up_.pop_back();
454444
}
455445

456-
while (std::ssize(c->vals_up_) > num_vals_before) {
446+
while (numValsExcludingShortcuts() > num_vals_before) {
457447
Val* v = c->vals_up_.back().get();
448+
NVF_ERROR(
449+
c->per_fusion_vals_[this].count(v) > 0,
450+
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
451+
// Null out shortcut caches if they point to vals about to be destroyed
458452
if (v == zero_val_) {
459453
zero_val_ = nullptr;
460454
} else if (v == one_val_) {

csrc/fusion.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -546,13 +546,26 @@ class NVF_API Fusion : public PolymorphicBase {
546546
return ir_container()->valsOwnedBy(this);
547547
}
548548

549-
// Count queries
549+
// Count queries (per-Fusion: only counts statements owned by this Fusion)
550550
int64_t numExprs() const noexcept {
551-
return ir_container()->numExprs();
551+
return std::ssize(ir_container()->exprsOwnedBy(this));
552552
}
553553

554554
int64_t numVals() const noexcept {
555-
return ir_container()->numVals();
555+
return std::ssize(ir_container()->valsOwnedBy(this));
556+
}
557+
558+
//! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.).
559+
//! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but
560+
//! since they're singletons that should persist across StatementGuard scopes,
561+
//! this count excludes them so the LIFO pop-back in
562+
//! removeStatementsCreatedAfter correctly skips over them.
563+
int64_t numValsExcludingShortcuts() const noexcept {
564+
int64_t count = std::ssize(ir_container()->valsOwnedBy(this));
565+
count -= (zero_val_ != nullptr) + (one_val_ != nullptr) +
566+
(true_val_ != nullptr) + (false_val_ != nullptr) +
567+
(magic_zero_val_ != nullptr);
568+
return count;
556569
}
557570

558571
// Shortcut values (frequently used constants)

csrc/ir/container.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -223,27 +223,27 @@ void IrContainer::transferStatementOwnership(
223223
void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
224224
auto vals_it = per_fusion_vals_.find(fusion);
225225
if (vals_it != per_fusion_vals_.end()) {
226-
for (auto it = vals_up_.begin(); it != vals_up_.end();) {
227-
if (vals_it->second.count(it->get()) > 0) {
228-
vals_.erase(it->get());
229-
it = vals_up_.erase(it);
230-
} else {
231-
++it;
226+
const auto& owned = vals_it->second;
227+
std::erase_if(vals_up_, [&](const std::unique_ptr<Val>& v) {
228+
if (owned.count(v.get()) > 0) {
229+
vals_.erase(v.get());
230+
return true;
232231
}
233-
}
232+
return false;
233+
});
234234
per_fusion_vals_.erase(vals_it);
235235
}
236236

237237
auto exprs_it = per_fusion_exprs_.find(fusion);
238238
if (exprs_it != per_fusion_exprs_.end()) {
239-
for (auto it = exprs_up_.begin(); it != exprs_up_.end();) {
240-
if (exprs_it->second.count(it->get()) > 0) {
241-
exprs_.erase(it->get());
242-
it = exprs_up_.erase(it);
243-
} else {
244-
++it;
239+
const auto& owned = exprs_it->second;
240+
std::erase_if(exprs_up_, [&](const std::unique_ptr<Expr>& e) {
241+
if (owned.count(e.get()) > 0) {
242+
exprs_.erase(e.get());
243+
return true;
245244
}
246-
}
245+
return false;
246+
});
247247
per_fusion_exprs_.erase(exprs_it);
248248
}
249249
}

csrc/statement_guard.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion)
2020
return fusion;
2121
}()),
2222
prev_num_exprs_(fusion_->numExprs()),
23-
prev_num_vals_(fusion_->numVals()) {}
23+
prev_num_vals_(fusion_->numValsExcludingShortcuts()) {}
2424

2525
StatementGuard::~StatementGuard() {
2626
fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);

0 commit comments

Comments
 (0)