Skip to content

Commit 14eeb8e

Browse files
committed
fix: enforce non-null ir_container_ invariant in Fusion::swap
1 parent 0cd10e8 commit 14eeb8e

File tree

1 file changed

+22
-28
lines changed

1 file changed

+22
-28
lines changed

csrc/fusion.cpp

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -112,28 +112,27 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
112112
return;
113113
}
114114

115+
NVF_ERROR(a.ir_container_ != nullptr, "Fusion::swap: a has null ir_container_");
116+
NVF_ERROR(b.ir_container_ != nullptr, "Fusion::swap: b has null ir_container_");
117+
115118
// Collect statements owned by each Fusion BEFORE swap so we can update
116119
// Statement::ir_container_ pointers afterward.
117120
std::vector<Val*> a_owned_vals, b_owned_vals;
118121
std::vector<Expr*> a_owned_exprs, b_owned_exprs;
119122

120-
if (a.ir_container_) {
121-
const auto& av = a.ir_container_->valsOwnedBy(&a);
122-
const auto& ae = a.ir_container_->exprsOwnedBy(&a);
123-
a_owned_vals.assign(av.begin(), av.end());
124-
a_owned_exprs.assign(ae.begin(), ae.end());
125-
}
126-
if (b.ir_container_) {
127-
const auto& bv = b.ir_container_->valsOwnedBy(&b);
128-
const auto& be = b.ir_container_->exprsOwnedBy(&b);
129-
b_owned_vals.assign(bv.begin(), bv.end());
130-
b_owned_exprs.assign(be.begin(), be.end());
131-
}
123+
const auto& av = a.ir_container_->valsOwnedBy(&a);
124+
const auto& ae = a.ir_container_->exprsOwnedBy(&a);
125+
a_owned_vals.assign(av.begin(), av.end());
126+
a_owned_exprs.assign(ae.begin(), ae.end());
127+
128+
const auto& bv = b.ir_container_->valsOwnedBy(&b);
129+
const auto& be = b.ir_container_->exprsOwnedBy(&b);
130+
b_owned_vals.assign(bv.begin(), bv.end());
131+
b_owned_exprs.assign(be.begin(), be.end());
132132

133133
// Transfer Fusion registrations between containers before pointer swap.
134134
// After swap, a will own b's container and b will own a's container.
135-
if (a.ir_container_ && b.ir_container_ &&
136-
a.ir_container_.get() != b.ir_container_.get()) {
135+
if (a.ir_container_.get() != b.ir_container_.get()) {
137136
a.ir_container_->transferFusion(&a, &b);
138137
b.ir_container_->transferFusion(&b, &a);
139138
}
@@ -176,21 +175,16 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
176175
expr->ir_container_ = &a;
177176
}
178177

179-
// Update per-Fusion tracking keys in containers
180-
if (a.ir_container_ && b.ir_container_) {
181-
if (a.ir_container_.get() == b.ir_container_.get()) {
182-
// Same container: directly swap per-Fusion tracking entries
183-
auto* c = a.ir_container_.get();
184-
std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]);
185-
std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]);
186-
} else {
187-
// Different containers: rename tracking keys to match new owners
188-
a.ir_container_->transferStatementOwnership(&b, &a);
189-
b.ir_container_->transferStatementOwnership(&a, &b);
190-
}
191-
} else if (a.ir_container_) {
178+
// Update per-Fusion tracking keys in containers. At this point, both
179+
// a and b are guaranteed to have non-null ir_container_ (verified above).
180+
if (a.ir_container_.get() == b.ir_container_.get()) {
181+
// Same container: directly swap per-Fusion tracking entries
182+
auto* c = a.ir_container_.get();
183+
std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]);
184+
std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]);
185+
} else {
186+
// Different containers: rename tracking keys to match new owners
192187
a.ir_container_->transferStatementOwnership(&b, &a);
193-
} else if (b.ir_container_) {
194188
b.ir_container_->transferStatementOwnership(&a, &b);
195189
}
196190
}

0 commit comments

Comments
 (0)