Skip to content

Commit b8d202d

Browse files
committed
Per-Fusion statement tracking and ownership-filtered accessors
Add per_fusion_vals_ / per_fusion_exprs_ maps to IrContainer so each Fusion can efficiently query only its own statements in a shared container. Fusion forwarding methods (vals(), unordered_exprs(), deterministic_vals(), etc.) now return per-Fusion filtered results. Fusion::clear() uses removeStatementsOwnedBy(this) instead of ir_container()->clear().
1 parent c965408 commit b8d202d

File tree

4 files changed

+173
-22
lines changed

4 files changed

+173
-22
lines changed

csrc/fusion.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,11 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
113113

114114
// After swapping container contents, update Statement::ir_container_
115115
// pointers so each Statement points to the Fusion whose container now
116-
// holds it.
116+
// holds it. Also fix per-Fusion tracking keys since a's container had
117+
// b's entries and vice versa.
118+
a.ir_container()->transferStatementOwnership(&b, &a);
119+
b.ir_container()->transferStatementOwnership(&a, &b);
120+
117121
if (a.ir_container_) {
118122
for (auto val : a.vals()) {
119123
val->ir_container_ = &a;
@@ -294,10 +298,9 @@ void Fusion::clear() noexcept {
294298
// constructor of Trace, which could throw an exception.
295299
// FUSER_PERF_SCOPE("Fusion clear");
296300

297-
// Clear container contents instead of destroying it
298-
// This preserves the container object so Statement pointers don't become
299-
// dangling
300-
ir_container()->clear();
301+
if (ir_container_) {
302+
ir_container_->removeStatementsOwnedBy(this);
303+
}
301304

302305
inputs_.clear();
303306
outputs_.clear();
@@ -307,8 +310,6 @@ void Fusion::clear() noexcept {
307310
managed_data_.clear();
308311
managed_named_data_.clear();
309312

310-
// Reset per-Fusion special value caches (the vals themselves are owned by
311-
// ir_container and were already destroyed by ir_container()->clear() above).
312313
zero_val_ = nullptr;
313314
one_val_ = nullptr;
314315
true_val_ = nullptr;
@@ -353,6 +354,7 @@ void Fusion::removeExpr(Expr* expr) {
353354
NVF_ERROR(
354355
expr_in_deque != c->exprs_up_.end(),
355356
"Wanted to remove an expression but its unique ptr is missing.");
357+
c->per_fusion_exprs_[this].erase(expr);
356358
c->exprs_.erase(expr);
357359
c->exprs_up_.erase(expr_in_deque);
358360
}
@@ -413,6 +415,7 @@ void Fusion::removeVal(Val* val) {
413415
NVF_ERROR(
414416
val_in_deque != c->vals_up_.end(),
415417
"Wanted to remove a value but its unique ptr is missing.");
418+
c->per_fusion_vals_[this].erase(val);
416419
c->vals_.erase(val);
417420
c->vals_up_.erase(val_in_deque);
418421

@@ -445,13 +448,11 @@ void Fusion::removeStatementsCreatedAfter(
445448
for (Val* in : e->inputs()) {
446449
in->removeUse(e);
447450
}
451+
c->per_fusion_exprs_[this].erase(e);
448452
c->exprs_.erase(e);
449453
c->exprs_up_.pop_back();
450454
}
451455

452-
// Null out any special value caches that point to vals about to be destroyed.
453-
// This prevents dangling pointers when special vals are lazily created inside
454-
// a StatementGuard scope.
455456
while (std::ssize(c->vals_up_) > num_vals_before) {
456457
Val* v = c->vals_up_.back().get();
457458
if (v == zero_val_) {
@@ -465,6 +466,7 @@ void Fusion::removeStatementsCreatedAfter(
465466
} else if (v == magic_zero_val_) {
466467
magic_zero_val_ = nullptr;
467468
}
469+
c->per_fusion_vals_[this].erase(v);
468470
c->vals_.erase(v);
469471
c->vals_up_.pop_back();
470472
}
@@ -931,6 +933,7 @@ void Fusion::registerVal(Val* val) {
931933
auto* c = ir_container();
932934
c->vals_up_.emplace_back(val);
933935
c->vals_.insert(val);
936+
c->per_fusion_vals_[this].insert(val);
934937
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
935938
}
936939

@@ -947,6 +950,7 @@ void Fusion::registerExpr(Expr* expr) {
947950
auto* c = ir_container();
948951
c->exprs_up_.emplace_back(expr);
949952
c->exprs_.insert(expr);
953+
c->per_fusion_exprs_[this].insert(expr);
950954
expr->setName(IrContainerPasskey(), c->getExprName());
951955

952956
for (Val* input : expr->inputs()) {

csrc/fusion.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -521,31 +521,29 @@ class NVF_API Fusion : public PolymorphicBase {
521521
}
522522

523523
// Collections access (return values in insertion order)
524-
const std::deque<Val*> deterministic_vals() const noexcept {
525-
return ir_container()->deterministic_vals();
524+
std::deque<Val*> deterministic_vals() const noexcept {
525+
return ir_container()->deterministicValsOwnedBy(this);
526526
}
527527

528-
const std::deque<Expr*> deterministic_exprs() const noexcept {
529-
return ir_container()->deterministic_exprs();
528+
std::deque<Expr*> deterministic_exprs() const noexcept {
529+
return ir_container()->deterministicExprsOwnedBy(this);
530530
}
531531

532-
const std::unordered_map<Val*, int64_t> deterministic_vals_map()
533-
const noexcept {
534-
return ir_container()->deterministic_vals_map();
532+
std::unordered_map<Val*, int64_t> deterministic_vals_map() const noexcept {
533+
return ir_container()->deterministicValsMapOwnedBy(this);
535534
}
536535

537-
const std::unordered_map<Expr*, int64_t> deterministic_exprs_map()
538-
const noexcept {
539-
return ir_container()->deterministic_exprs_map();
536+
std::unordered_map<Expr*, int64_t> deterministic_exprs_map() const noexcept {
537+
return ir_container()->deterministicExprsMapOwnedBy(this);
540538
}
541539

542540
// Collections access (unordered sets)
543541
const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
544-
return ir_container()->unordered_exprs();
542+
return ir_container()->exprsOwnedBy(this);
545543
}
546544

547545
const std::unordered_set<Val*>& vals() const noexcept {
548-
return ir_container()->vals();
546+
return ir_container()->valsOwnedBy(this);
549547
}
550548

551549
// Count queries

csrc/ir/container.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
8080

8181
std::swap(a.val_type_name_map_, b.val_type_name_map_);
8282
std::swap(a.expr_name_counter_, b.expr_name_counter_);
83+
84+
std::swap(a.per_fusion_vals_, b.per_fusion_vals_);
85+
std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_);
8386
}
8487

8588
IrCloner IrContainer::copy(
@@ -124,6 +127,8 @@ void IrContainer::clear() noexcept {
124127
exprs_up_.clear();
125128
val_type_name_map_.clear();
126129
expr_name_counter_ = 0;
130+
per_fusion_vals_.clear();
131+
per_fusion_exprs_.clear();
127132
}
128133

129134
bool IrContainer::inContainer(const Statement* const_stmt) const {
@@ -183,4 +188,130 @@ const std::unordered_set<Fusion*>& IrContainer::sharingFusions() const {
183188
return sharing_fusions_;
184189
}
185190

191+
const std::unordered_set<Val*>& IrContainer::valsOwnedBy(
192+
const Fusion* fusion) const {
193+
static const std::unordered_set<Val*> empty;
194+
auto it = per_fusion_vals_.find(fusion);
195+
return it != per_fusion_vals_.end() ? it->second : empty;
196+
}
197+
198+
const std::unordered_set<Expr*>& IrContainer::exprsOwnedBy(
199+
const Fusion* fusion) const {
200+
static const std::unordered_set<Expr*> empty;
201+
auto it = per_fusion_exprs_.find(fusion);
202+
return it != per_fusion_exprs_.end() ? it->second : empty;
203+
}
204+
205+
void IrContainer::transferStatementOwnership(
206+
const Fusion* from,
207+
const Fusion* to) {
208+
auto vals_it = per_fusion_vals_.find(from);
209+
if (vals_it != per_fusion_vals_.end()) {
210+
auto& to_vals = per_fusion_vals_[to];
211+
to_vals.insert(vals_it->second.begin(), vals_it->second.end());
212+
per_fusion_vals_.erase(vals_it);
213+
}
214+
215+
auto exprs_it = per_fusion_exprs_.find(from);
216+
if (exprs_it != per_fusion_exprs_.end()) {
217+
auto& to_exprs = per_fusion_exprs_[to];
218+
to_exprs.insert(exprs_it->second.begin(), exprs_it->second.end());
219+
per_fusion_exprs_.erase(exprs_it);
220+
}
221+
}
222+
223+
void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
224+
auto vals_it = per_fusion_vals_.find(fusion);
225+
if (vals_it != per_fusion_vals_.end()) {
226+
for (auto it = vals_up_.begin(); it != vals_up_.end();) {
227+
if (vals_it->second.count(it->get()) > 0) {
228+
vals_.erase(it->get());
229+
it = vals_up_.erase(it);
230+
} else {
231+
++it;
232+
}
233+
}
234+
per_fusion_vals_.erase(vals_it);
235+
}
236+
237+
auto exprs_it = per_fusion_exprs_.find(fusion);
238+
if (exprs_it != per_fusion_exprs_.end()) {
239+
for (auto it = exprs_up_.begin(); it != exprs_up_.end();) {
240+
if (exprs_it->second.count(it->get()) > 0) {
241+
exprs_.erase(it->get());
242+
it = exprs_up_.erase(it);
243+
} else {
244+
++it;
245+
}
246+
}
247+
per_fusion_exprs_.erase(exprs_it);
248+
}
249+
}
250+
251+
std::deque<Val*> IrContainer::deterministicValsOwnedBy(
252+
const Fusion* fusion) const noexcept {
253+
std::deque<Val*> result;
254+
auto it = per_fusion_vals_.find(fusion);
255+
if (it == per_fusion_vals_.end()) {
256+
return result;
257+
}
258+
const auto& owned = it->second;
259+
for (const auto& val_up : vals_up_) {
260+
if (owned.count(val_up.get()) > 0) {
261+
result.push_back(val_up.get());
262+
}
263+
}
264+
return result;
265+
}
266+
267+
std::deque<Expr*> IrContainer::deterministicExprsOwnedBy(
268+
const Fusion* fusion) const noexcept {
269+
std::deque<Expr*> result;
270+
auto it = per_fusion_exprs_.find(fusion);
271+
if (it == per_fusion_exprs_.end()) {
272+
return result;
273+
}
274+
const auto& owned = it->second;
275+
for (const auto& expr_up : exprs_up_) {
276+
if (owned.count(expr_up.get()) > 0) {
277+
result.push_back(expr_up.get());
278+
}
279+
}
280+
return result;
281+
}
282+
283+
std::unordered_map<Val*, int64_t> IrContainer::deterministicValsMapOwnedBy(
284+
const Fusion* fusion) const noexcept {
285+
std::unordered_map<Val*, int64_t> result;
286+
auto it = per_fusion_vals_.find(fusion);
287+
if (it == per_fusion_vals_.end()) {
288+
return result;
289+
}
290+
const auto& owned = it->second;
291+
int64_t count = 0;
292+
for (const auto& val_up : vals_up_) {
293+
if (owned.count(val_up.get()) > 0) {
294+
result[val_up.get()] = count++;
295+
}
296+
}
297+
return result;
298+
}
299+
300+
std::unordered_map<Expr*, int64_t> IrContainer::deterministicExprsMapOwnedBy(
301+
const Fusion* fusion) const noexcept {
302+
std::unordered_map<Expr*, int64_t> result;
303+
auto it = per_fusion_exprs_.find(fusion);
304+
if (it == per_fusion_exprs_.end()) {
305+
return result;
306+
}
307+
const auto& owned = it->second;
308+
int64_t count = 0;
309+
for (const auto& expr_up : exprs_up_) {
310+
if (owned.count(expr_up.get()) > 0) {
311+
result[expr_up.get()] = count++;
312+
}
313+
}
314+
return result;
315+
}
316+
186317
} // namespace nvfuser

csrc/ir/container.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,26 @@ class IrContainer {
137137
bool hasMultipleFusions() const;
138138
const std::unordered_set<Fusion*>& sharingFusions() const;
139139

140+
NVF_API const std::unordered_set<Val*>& valsOwnedBy(
141+
const Fusion* fusion) const;
142+
const std::unordered_set<Expr*>& exprsOwnedBy(const Fusion* fusion) const;
143+
void transferStatementOwnership(const Fusion* from, const Fusion* to);
144+
void removeStatementsOwnedBy(const Fusion* fusion);
145+
146+
std::deque<Val*> deterministicValsOwnedBy(
147+
const Fusion* fusion) const noexcept;
148+
std::deque<Expr*> deterministicExprsOwnedBy(
149+
const Fusion* fusion) const noexcept;
150+
std::unordered_map<Val*, int64_t> deterministicValsMapOwnedBy(
151+
const Fusion* fusion) const noexcept;
152+
std::unordered_map<Expr*, int64_t> deterministicExprsMapOwnedBy(
153+
const Fusion* fusion) const noexcept;
154+
140155
private:
141156
std::unordered_set<Fusion*> sharing_fusions_;
157+
std::unordered_map<const Fusion*, std::unordered_set<Val*>> per_fusion_vals_;
158+
std::unordered_map<const Fusion*, std::unordered_set<Expr*>>
159+
per_fusion_exprs_;
142160
};
143161

144162
} // namespace nvfuser

0 commit comments

Comments
 (0)