Skip to content

Commit 2e491e9

Browse files
committed
PR #5961 Review Fixes — Per-Fusion Statement Tracking (#6015)
## Summary Review fixes for PR #5961 (Per-Fusion statement tracking): - **O(n²) → O(n)**: Optimize `removeStatementsOwnedBy` with `std::erase_if` - **Per-Fusion counts**: Convert `numExprs()`/`numVals()` to return per-Fusion counts instead of global - **StatementGuard fixes**: Snapshot and compare per-Fusion counts for correct LIFO rollback in shared containers - **LIFO assertions**: Verify tail elements belong to this Fusion before popping ## Tests All tests pass: - ✅ StatementGuardTest.ExecuteAfterGuard - ✅ StatementGuardTest.LazySpecialValsNotDangling - ✅ FusionCopy_CUDA - ✅ FusionMove_CUDA
1 parent fc83f2d commit 2e491e9

File tree

4 files changed

+42
-35
lines changed

4 files changed

+42
-35
lines changed

csrc/fusion.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -426,24 +426,14 @@ void Fusion::removeStatementsCreatedAfter(
426426
int64_t num_vals_before) {
427427
auto* c = ir_container();
428428

429-
NVF_ERROR(
430-
c->exprs_up_.size() == c->exprs_.size(),
431-
"exprs_up_ (size ",
432-
c->exprs_up_.size(),
433-
") and exprs_ (size ",
434-
c->exprs_.size(),
435-
") are out of sync.");
436-
NVF_ERROR(
437-
std::ssize(c->exprs_up_) >= num_exprs_before,
438-
"exprs_up_ size (",
439-
std::ssize(c->exprs_up_),
440-
") is less than num_exprs_before (",
441-
num_exprs_before,
442-
").");
443-
444429
// Remove expressions before values because we need to change Val::uses_.
445-
while (std::ssize(c->exprs_up_) > num_exprs_before) {
430+
while (std::ssize(c->exprsOwnedBy(this)) > num_exprs_before) {
431+
// Pop from global deque back — statements created by this Fusion during
432+
// the guard scope are at the tail (LIFO invariant).
446433
Expr* e = c->exprs_up_.back().get();
434+
NVF_ERROR(
435+
c->per_fusion_exprs_[this].count(e) > 0,
436+
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
447437
for (Val* in : e->inputs()) {
448438
in->removeUse(e);
449439
}
@@ -452,8 +442,12 @@ void Fusion::removeStatementsCreatedAfter(
452442
c->exprs_up_.pop_back();
453443
}
454444

455-
while (std::ssize(c->vals_up_) > num_vals_before) {
445+
while (numValsExcludingShortcuts() > num_vals_before) {
456446
Val* v = c->vals_up_.back().get();
447+
NVF_ERROR(
448+
c->per_fusion_vals_[this].count(v) > 0,
449+
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
450+
// Null out shortcut caches if they point to vals about to be destroyed
457451
if (v == zero_val_) {
458452
zero_val_ = nullptr;
459453
} else if (v == one_val_) {

csrc/fusion.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -547,13 +547,26 @@ class NVF_API Fusion : public PolymorphicBase {
547547
return ir_container()->valsOwnedBy(this);
548548
}
549549

550-
// Count queries
550+
// Count queries (per-Fusion: only counts statements owned by this Fusion)
551551
int64_t numExprs() const noexcept {
552-
return ir_container()->numExprs();
552+
return std::ssize(ir_container()->exprsOwnedBy(this));
553553
}
554554

555555
int64_t numVals() const noexcept {
556-
return ir_container()->numVals();
556+
return std::ssize(ir_container()->valsOwnedBy(this));
557+
}
558+
559+
//! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.).
560+
//! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but
561+
//! since they're singletons that should persist across StatementGuard scopes,
562+
//! this count excludes them so the LIFO pop-back in
563+
//! removeStatementsCreatedAfter correctly skips over them.
564+
int64_t numValsExcludingShortcuts() const noexcept {
565+
int64_t count = std::ssize(ir_container()->valsOwnedBy(this));
566+
count -= (zero_val_ != nullptr) + (one_val_ != nullptr) +
567+
(true_val_ != nullptr) + (false_val_ != nullptr) +
568+
(magic_zero_val_ != nullptr);
569+
return count;
557570
}
558571

559572
// Shortcut values (frequently used constants)

csrc/ir/container.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -218,27 +218,27 @@ void IrContainer::transferStatementOwnership(
218218
void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
219219
auto vals_it = per_fusion_vals_.find(fusion);
220220
if (vals_it != per_fusion_vals_.end()) {
221-
for (auto it = vals_up_.begin(); it != vals_up_.end();) {
222-
if (vals_it->second.count(it->get()) > 0) {
223-
vals_.erase(it->get());
224-
it = vals_up_.erase(it);
225-
} else {
226-
++it;
221+
const auto& owned = vals_it->second;
222+
std::erase_if(vals_up_, [&](const std::unique_ptr<Val>& v) {
223+
if (owned.count(v.get()) > 0) {
224+
vals_.erase(v.get());
225+
return true;
227226
}
228-
}
227+
return false;
228+
});
229229
per_fusion_vals_.erase(vals_it);
230230
}
231231

232232
auto exprs_it = per_fusion_exprs_.find(fusion);
233233
if (exprs_it != per_fusion_exprs_.end()) {
234-
for (auto it = exprs_up_.begin(); it != exprs_up_.end();) {
235-
if (exprs_it->second.count(it->get()) > 0) {
236-
exprs_.erase(it->get());
237-
it = exprs_up_.erase(it);
238-
} else {
239-
++it;
234+
const auto& owned = exprs_it->second;
235+
std::erase_if(exprs_up_, [&](const std::unique_ptr<Expr>& e) {
236+
if (owned.count(e.get()) > 0) {
237+
exprs_.erase(e.get());
238+
return true;
240239
}
241-
}
240+
return false;
241+
});
242242
per_fusion_exprs_.erase(exprs_it);
243243
}
244244
}

csrc/statement_guard.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion)
2020
return fusion;
2121
}()),
2222
prev_num_exprs_(fusion_->numExprs()),
23-
prev_num_vals_(fusion_->numVals()) {}
23+
prev_num_vals_(fusion_->numValsExcludingShortcuts()) {}
2424

2525
StatementGuard::~StatementGuard() {
2626
fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);

0 commit comments

Comments
 (0)