Skip to content

Commit 8b162d9

Browse files
committed
Per-Fusion statement tracking and ownership-filtered accessors
Add per_fusion_vals_ / per_fusion_exprs_ maps to IrContainer so each Fusion can efficiently query only its own statements in a shared container. Fusion forwarding methods (vals(), unordered_exprs(), deterministic_vals(), etc.) now return per-Fusion filtered results. Fusion::clear() uses removeStatementsOwnedBy(this) instead of ir_container()->clear().
1 parent f8ff364 commit 8b162d9

File tree

4 files changed

+173
-25
lines changed

4 files changed

+173
-25
lines changed

csrc/fusion.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,11 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
111111
// will only swap the ptrs NOT the contents.
112112
IrContainer::swap(*(a.ir_container()), *(b.ir_container()));
113113

114-
// Fix parent pointers after swapping containers
115-
// After swap, each Fusion owns a different IrContainer, so we must
116-
// update the parent backpointers in those containers to point to their new
117-
// owners
114+
// After swapping container contents, per-Fusion tracking keys point to the
115+
// wrong Fusions. Rename: a's container had b's entries, b's had a's.
116+
a.ir_container()->transferStatementOwnership(&b, &a);
117+
b.ir_container()->transferStatementOwnership(&a, &b);
118+
118119
if (a.ir_container_) {
119120
for (auto val : a.vals()) {
120121
val->ir_container_ = &a;
@@ -295,10 +296,9 @@ void Fusion::clear() noexcept {
295296
// constructor of Trace, which could throw an exception.
296297
// FUSER_PERF_SCOPE("Fusion clear");
297298

298-
// Clear container contents instead of destroying it
299-
// This preserves the container object so Statement pointers don't become
300-
// dangling
301-
ir_container()->clear();
299+
if (ir_container_) {
300+
ir_container_->removeStatementsOwnedBy(this);
301+
}
302302

303303
inputs_.clear();
304304
outputs_.clear();
@@ -308,8 +308,6 @@ void Fusion::clear() noexcept {
308308
managed_data_.clear();
309309
managed_named_data_.clear();
310310

311-
// Reset per-Fusion special value caches (the vals themselves are owned by
312-
// ir_container and were already destroyed by ir_container()->clear() above).
313311
zero_val_ = nullptr;
314312
one_val_ = nullptr;
315313
true_val_ = nullptr;
@@ -354,6 +352,7 @@ void Fusion::removeExpr(Expr* expr) {
354352
NVF_ERROR(
355353
expr_in_deque != c->exprs_up_.end(),
356354
"Wanted to remove an expression but its unique ptr is missing.");
355+
c->per_fusion_exprs_[this].erase(expr);
357356
c->exprs_.erase(expr);
358357
c->exprs_up_.erase(expr_in_deque);
359358
}
@@ -414,6 +413,7 @@ void Fusion::removeVal(Val* val) {
414413
NVF_ERROR(
415414
val_in_deque != c->vals_up_.end(),
416415
"Wanted to remove a value but its unique ptr is missing.");
416+
c->per_fusion_vals_[this].erase(val);
417417
c->vals_.erase(val);
418418
c->vals_up_.erase(val_in_deque);
419419

@@ -446,13 +446,11 @@ void Fusion::removeStatementsCreatedAfter(
446446
for (Val* in : e->inputs()) {
447447
in->removeUse(e);
448448
}
449+
c->per_fusion_exprs_[this].erase(e);
449450
c->exprs_.erase(e);
450451
c->exprs_up_.pop_back();
451452
}
452453

453-
// Null out any special value caches that point to vals about to be destroyed.
454-
// This prevents dangling pointers when special vals are lazily created inside
455-
// a StatementGuard scope.
456454
while (std::ssize(c->vals_up_) > num_vals_before) {
457455
Val* v = c->vals_up_.back().get();
458456
if (v == zero_val_) {
@@ -466,6 +464,7 @@ void Fusion::removeStatementsCreatedAfter(
466464
} else if (v == magic_zero_val_) {
467465
magic_zero_val_ = nullptr;
468466
}
467+
c->per_fusion_vals_[this].erase(v);
469468
c->vals_.erase(v);
470469
c->vals_up_.pop_back();
471470
}
@@ -932,6 +931,7 @@ void Fusion::registerVal(Val* val) {
932931
auto* c = ir_container();
933932
c->vals_up_.emplace_back(val);
934933
c->vals_.insert(val);
934+
c->per_fusion_vals_[this].insert(val);
935935
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
936936
}
937937

@@ -948,6 +948,7 @@ void Fusion::registerExpr(Expr* expr) {
948948
auto* c = ir_container();
949949
c->exprs_up_.emplace_back(expr);
950950
c->exprs_.insert(expr);
951+
c->per_fusion_exprs_[this].insert(expr);
951952
expr->setName(IrContainerPasskey(), c->getExprName());
952953

953954
for (Val* input : expr->inputs()) {

csrc/fusion.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -521,31 +521,29 @@ class NVF_API Fusion : public PolymorphicBase {
521521
}
522522

523523
// Collections access (return values in insertion order)
524-
const std::deque<Val*> deterministic_vals() const noexcept {
525-
return ir_container()->deterministic_vals();
524+
std::deque<Val*> deterministic_vals() const noexcept {
525+
return ir_container()->deterministicValsOwnedBy(this);
526526
}
527527

528-
const std::deque<Expr*> deterministic_exprs() const noexcept {
529-
return ir_container()->deterministic_exprs();
528+
std::deque<Expr*> deterministic_exprs() const noexcept {
529+
return ir_container()->deterministicExprsOwnedBy(this);
530530
}
531531

532-
const std::unordered_map<Val*, int64_t> deterministic_vals_map()
533-
const noexcept {
534-
return ir_container()->deterministic_vals_map();
532+
std::unordered_map<Val*, int64_t> deterministic_vals_map() const noexcept {
533+
return ir_container()->deterministicValsMapOwnedBy(this);
535534
}
536535

537-
const std::unordered_map<Expr*, int64_t> deterministic_exprs_map()
538-
const noexcept {
539-
return ir_container()->deterministic_exprs_map();
536+
std::unordered_map<Expr*, int64_t> deterministic_exprs_map() const noexcept {
537+
return ir_container()->deterministicExprsMapOwnedBy(this);
540538
}
541539

542540
// Collections access (unordered sets)
543541
const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
544-
return ir_container()->unordered_exprs();
542+
return ir_container()->exprsOwnedBy(this);
545543
}
546544

547545
const std::unordered_set<Val*>& vals() const noexcept {
548-
return ir_container()->vals();
546+
return ir_container()->valsOwnedBy(this);
549547
}
550548

551549
// Count queries

csrc/ir/container.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
8080

8181
std::swap(a.val_type_name_map_, b.val_type_name_map_);
8282
std::swap(a.expr_name_counter_, b.expr_name_counter_);
83+
84+
std::swap(a.per_fusion_vals_, b.per_fusion_vals_);
85+
std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_);
8386
}
8487

8588
IrCloner IrContainer::copy(
@@ -124,6 +127,8 @@ void IrContainer::clear() noexcept {
124127
exprs_up_.clear();
125128
val_type_name_map_.clear();
126129
expr_name_counter_ = 0;
130+
per_fusion_vals_.clear();
131+
per_fusion_exprs_.clear();
127132
}
128133

129134
bool IrContainer::inContainer(const Statement* const_stmt) const {
@@ -183,4 +188,130 @@ const std::unordered_set<Fusion*>& IrContainer::sharingFusions() const {
183188
return sharing_fusions_;
184189
}
185190

191+
const std::unordered_set<Val*>& IrContainer::valsOwnedBy(
192+
const Fusion* fusion) const {
193+
static const std::unordered_set<Val*> empty;
194+
auto it = per_fusion_vals_.find(fusion);
195+
return it != per_fusion_vals_.end() ? it->second : empty;
196+
}
197+
198+
const std::unordered_set<Expr*>& IrContainer::exprsOwnedBy(
199+
const Fusion* fusion) const {
200+
static const std::unordered_set<Expr*> empty;
201+
auto it = per_fusion_exprs_.find(fusion);
202+
return it != per_fusion_exprs_.end() ? it->second : empty;
203+
}
204+
205+
void IrContainer::transferStatementOwnership(
206+
const Fusion* from,
207+
const Fusion* to) {
208+
auto vals_it = per_fusion_vals_.find(from);
209+
if (vals_it != per_fusion_vals_.end()) {
210+
auto& to_vals = per_fusion_vals_[to];
211+
to_vals.insert(vals_it->second.begin(), vals_it->second.end());
212+
per_fusion_vals_.erase(vals_it);
213+
}
214+
215+
auto exprs_it = per_fusion_exprs_.find(from);
216+
if (exprs_it != per_fusion_exprs_.end()) {
217+
auto& to_exprs = per_fusion_exprs_[to];
218+
to_exprs.insert(exprs_it->second.begin(), exprs_it->second.end());
219+
per_fusion_exprs_.erase(exprs_it);
220+
}
221+
}
222+
223+
void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
224+
auto vals_it = per_fusion_vals_.find(fusion);
225+
if (vals_it != per_fusion_vals_.end()) {
226+
for (auto it = vals_up_.begin(); it != vals_up_.end();) {
227+
if (vals_it->second.count(it->get()) > 0) {
228+
vals_.erase(it->get());
229+
it = vals_up_.erase(it);
230+
} else {
231+
++it;
232+
}
233+
}
234+
per_fusion_vals_.erase(vals_it);
235+
}
236+
237+
auto exprs_it = per_fusion_exprs_.find(fusion);
238+
if (exprs_it != per_fusion_exprs_.end()) {
239+
for (auto it = exprs_up_.begin(); it != exprs_up_.end();) {
240+
if (exprs_it->second.count(it->get()) > 0) {
241+
exprs_.erase(it->get());
242+
it = exprs_up_.erase(it);
243+
} else {
244+
++it;
245+
}
246+
}
247+
per_fusion_exprs_.erase(exprs_it);
248+
}
249+
}
250+
251+
std::deque<Val*> IrContainer::deterministicValsOwnedBy(
252+
const Fusion* fusion) const noexcept {
253+
std::deque<Val*> result;
254+
auto it = per_fusion_vals_.find(fusion);
255+
if (it == per_fusion_vals_.end()) {
256+
return result;
257+
}
258+
const auto& owned = it->second;
259+
for (const auto& val_up : vals_up_) {
260+
if (owned.count(val_up.get()) > 0) {
261+
result.push_back(val_up.get());
262+
}
263+
}
264+
return result;
265+
}
266+
267+
std::deque<Expr*> IrContainer::deterministicExprsOwnedBy(
268+
const Fusion* fusion) const noexcept {
269+
std::deque<Expr*> result;
270+
auto it = per_fusion_exprs_.find(fusion);
271+
if (it == per_fusion_exprs_.end()) {
272+
return result;
273+
}
274+
const auto& owned = it->second;
275+
for (const auto& expr_up : exprs_up_) {
276+
if (owned.count(expr_up.get()) > 0) {
277+
result.push_back(expr_up.get());
278+
}
279+
}
280+
return result;
281+
}
282+
283+
std::unordered_map<Val*, int64_t> IrContainer::deterministicValsMapOwnedBy(
284+
const Fusion* fusion) const noexcept {
285+
std::unordered_map<Val*, int64_t> result;
286+
auto it = per_fusion_vals_.find(fusion);
287+
if (it == per_fusion_vals_.end()) {
288+
return result;
289+
}
290+
const auto& owned = it->second;
291+
int64_t count = 0;
292+
for (const auto& val_up : vals_up_) {
293+
if (owned.count(val_up.get()) > 0) {
294+
result[val_up.get()] = count++;
295+
}
296+
}
297+
return result;
298+
}
299+
300+
std::unordered_map<Expr*, int64_t> IrContainer::deterministicExprsMapOwnedBy(
301+
const Fusion* fusion) const noexcept {
302+
std::unordered_map<Expr*, int64_t> result;
303+
auto it = per_fusion_exprs_.find(fusion);
304+
if (it == per_fusion_exprs_.end()) {
305+
return result;
306+
}
307+
const auto& owned = it->second;
308+
int64_t count = 0;
309+
for (const auto& expr_up : exprs_up_) {
310+
if (owned.count(expr_up.get()) > 0) {
311+
result[expr_up.get()] = count++;
312+
}
313+
}
314+
return result;
315+
}
316+
186317
} // namespace nvfuser

csrc/ir/container.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,26 @@ class IrContainer {
137137
bool hasMultipleFusions() const;
138138
const std::unordered_set<Fusion*>& sharingFusions() const;
139139

140+
NVF_API const std::unordered_set<Val*>& valsOwnedBy(
141+
const Fusion* fusion) const;
142+
const std::unordered_set<Expr*>& exprsOwnedBy(const Fusion* fusion) const;
143+
void transferStatementOwnership(const Fusion* from, const Fusion* to);
144+
void removeStatementsOwnedBy(const Fusion* fusion);
145+
146+
std::deque<Val*> deterministicValsOwnedBy(
147+
const Fusion* fusion) const noexcept;
148+
std::deque<Expr*> deterministicExprsOwnedBy(
149+
const Fusion* fusion) const noexcept;
150+
std::unordered_map<Val*, int64_t> deterministicValsMapOwnedBy(
151+
const Fusion* fusion) const noexcept;
152+
std::unordered_map<Expr*, int64_t> deterministicExprsMapOwnedBy(
153+
const Fusion* fusion) const noexcept;
154+
140155
private:
141156
std::unordered_set<Fusion*> sharing_fusions_;
157+
std::unordered_map<const Fusion*, std::unordered_set<Val*>> per_fusion_vals_;
158+
std::unordered_map<const Fusion*, std::unordered_set<Expr*>>
159+
per_fusion_exprs_;
142160
};
143161

144162
} // namespace nvfuser

0 commit comments

Comments
 (0)