Skip to content

Commit 0bfea77

Browse files
committed
Per-Fusion statement tracking and ownership-filtered accessors
Add per_fusion_vals_ / per_fusion_exprs_ maps to IrContainer so each Fusion can efficiently query only its own statements in a shared container. Fusion forwarding methods (vals(), unordered_exprs(), deterministic_vals(), etc.) now return per-Fusion filtered results. Fusion::clear() uses removeStatementsOwnedBy(this) instead of ir_container()->clear().
1 parent dcc99d7 commit 0bfea77

File tree

4 files changed

+173
-22
lines changed

4 files changed

+173
-22
lines changed

csrc/fusion.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
114114

115115
// After swapping container contents, update Statement::ir_container_
116116
// pointers so each Statement points to the Fusion whose container now
117-
// holds it.
117+
// holds it. Also fix per-Fusion tracking keys since a's container had
118+
// b's entries and vice versa.
119+
a.ir_container()->transferStatementOwnership(&b, &a);
120+
b.ir_container()->transferStatementOwnership(&a, &b);
121+
118122
if (a.ir_container_) {
119123
for (auto val : a.vals()) {
120124
val->ir_container_ = &a;
@@ -295,10 +299,9 @@ void Fusion::clear() noexcept {
295299
// constructor of Trace, which could throw an exception.
296300
// FUSER_PERF_SCOPE("Fusion clear");
297301

298-
// Clear container contents instead of destroying it
299-
// This preserves the container object so Statement pointers don't become
300-
// dangling
301-
ir_container()->clear();
302+
if (ir_container_) {
303+
ir_container_->removeStatementsOwnedBy(this);
304+
}
302305

303306
inputs_.clear();
304307
outputs_.clear();
@@ -308,8 +311,6 @@ void Fusion::clear() noexcept {
308311
managed_data_.clear();
309312
managed_named_data_.clear();
310313

311-
// Reset per-Fusion special value caches (the vals themselves are owned by
312-
// ir_container and were already destroyed by ir_container()->clear() above).
313314
zero_val_ = nullptr;
314315
one_val_ = nullptr;
315316
true_val_ = nullptr;
@@ -353,6 +354,7 @@ void Fusion::removeExpr(Expr* expr) {
353354
NVF_ERROR(
354355
expr_in_deque != c->exprs_up_.end(),
355356
"Wanted to remove an expression but its unique ptr is missing.");
357+
c->per_fusion_exprs_[this].erase(expr);
356358
c->exprs_.erase(expr);
357359
c->exprs_up_.erase(expr_in_deque);
358360
}
@@ -412,6 +414,7 @@ void Fusion::removeVal(Val* val) {
412414
NVF_ERROR(
413415
val_in_deque != c->vals_up_.end(),
414416
"Wanted to remove a value but its unique ptr is missing.");
417+
c->per_fusion_vals_[this].erase(val);
415418
c->vals_.erase(val);
416419
c->vals_up_.erase(val_in_deque);
417420

@@ -444,13 +447,11 @@ void Fusion::removeStatementsCreatedAfter(
444447
for (Val* in : e->inputs()) {
445448
in->removeUse(e);
446449
}
450+
c->per_fusion_exprs_[this].erase(e);
447451
c->exprs_.erase(e);
448452
c->exprs_up_.pop_back();
449453
}
450454

451-
// Null out any special value caches that point to vals about to be destroyed.
452-
// This prevents dangling pointers when special vals are lazily created inside
453-
// a StatementGuard scope.
454455
while (std::ssize(c->vals_up_) > num_vals_before) {
455456
Val* v = c->vals_up_.back().get();
456457
if (v == zero_val_) {
@@ -464,6 +465,7 @@ void Fusion::removeStatementsCreatedAfter(
464465
} else if (v == magic_zero_val_) {
465466
magic_zero_val_ = nullptr;
466467
}
468+
c->per_fusion_vals_[this].erase(v);
467469
c->vals_.erase(v);
468470
c->vals_up_.pop_back();
469471
}
@@ -927,6 +929,7 @@ void Fusion::registerVal(Val* val) {
927929
auto* c = ir_container();
928930
c->vals_up_.emplace_back(val);
929931
c->vals_.insert(val);
932+
c->per_fusion_vals_[this].insert(val);
930933
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
931934
}
932935

@@ -943,6 +946,7 @@ void Fusion::registerExpr(Expr* expr) {
943946
auto* c = ir_container();
944947
c->exprs_up_.emplace_back(expr);
945948
c->exprs_.insert(expr);
949+
c->per_fusion_exprs_[this].insert(expr);
946950
expr->setName(IrContainerPasskey(), c->getExprName());
947951

948952
for (Val* input : expr->inputs()) {

csrc/fusion.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -522,31 +522,29 @@ class NVF_API Fusion : public PolymorphicBase {
522522
}
523523

524524
// Collections access (return values in insertion order)
525-
const std::deque<Val*> deterministic_vals() const noexcept {
526-
return ir_container()->deterministic_vals();
525+
std::deque<Val*> deterministic_vals() const noexcept {
526+
return ir_container()->deterministicValsOwnedBy(this);
527527
}
528528

529-
const std::deque<Expr*> deterministic_exprs() const noexcept {
530-
return ir_container()->deterministic_exprs();
529+
std::deque<Expr*> deterministic_exprs() const noexcept {
530+
return ir_container()->deterministicExprsOwnedBy(this);
531531
}
532532

533-
const std::unordered_map<Val*, int64_t> deterministic_vals_map()
534-
const noexcept {
535-
return ir_container()->deterministic_vals_map();
533+
std::unordered_map<Val*, int64_t> deterministic_vals_map() const noexcept {
534+
return ir_container()->deterministicValsMapOwnedBy(this);
536535
}
537536

538-
const std::unordered_map<Expr*, int64_t> deterministic_exprs_map()
539-
const noexcept {
540-
return ir_container()->deterministic_exprs_map();
537+
std::unordered_map<Expr*, int64_t> deterministic_exprs_map() const noexcept {
538+
return ir_container()->deterministicExprsMapOwnedBy(this);
541539
}
542540

543541
// Collections access (unordered sets)
544542
const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
545-
return ir_container()->unordered_exprs();
543+
return ir_container()->exprsOwnedBy(this);
546544
}
547545

548546
const std::unordered_set<Val*>& vals() const noexcept {
549-
return ir_container()->vals();
547+
return ir_container()->valsOwnedBy(this);
550548
}
551549

552550
// Count queries

csrc/ir/container.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
7575

7676
std::swap(a.val_type_name_map_, b.val_type_name_map_);
7777
std::swap(a.expr_name_counter_, b.expr_name_counter_);
78+
79+
std::swap(a.per_fusion_vals_, b.per_fusion_vals_);
80+
std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_);
7881
}
7982

8083
IrCloner IrContainer::copy(
@@ -119,6 +122,8 @@ void IrContainer::clear() noexcept {
119122
exprs_up_.clear();
120123
val_type_name_map_.clear();
121124
expr_name_counter_ = 0;
125+
per_fusion_vals_.clear();
126+
per_fusion_exprs_.clear();
122127
}
123128

124129
bool IrContainer::inContainer(const Statement* const_stmt) const {
@@ -178,4 +183,130 @@ const std::unordered_set<Fusion*>& IrContainer::sharingFusions() const {
178183
return sharing_fusions_;
179184
}
180185

186+
const std::unordered_set<Val*>& IrContainer::valsOwnedBy(
187+
const Fusion* fusion) const {
188+
static const std::unordered_set<Val*> empty;
189+
auto it = per_fusion_vals_.find(fusion);
190+
return it != per_fusion_vals_.end() ? it->second : empty;
191+
}
192+
193+
const std::unordered_set<Expr*>& IrContainer::exprsOwnedBy(
194+
const Fusion* fusion) const {
195+
static const std::unordered_set<Expr*> empty;
196+
auto it = per_fusion_exprs_.find(fusion);
197+
return it != per_fusion_exprs_.end() ? it->second : empty;
198+
}
199+
200+
void IrContainer::transferStatementOwnership(
201+
const Fusion* from,
202+
const Fusion* to) {
203+
auto vals_it = per_fusion_vals_.find(from);
204+
if (vals_it != per_fusion_vals_.end()) {
205+
auto& to_vals = per_fusion_vals_[to];
206+
to_vals.insert(vals_it->second.begin(), vals_it->second.end());
207+
per_fusion_vals_.erase(vals_it);
208+
}
209+
210+
auto exprs_it = per_fusion_exprs_.find(from);
211+
if (exprs_it != per_fusion_exprs_.end()) {
212+
auto& to_exprs = per_fusion_exprs_[to];
213+
to_exprs.insert(exprs_it->second.begin(), exprs_it->second.end());
214+
per_fusion_exprs_.erase(exprs_it);
215+
}
216+
}
217+
218+
void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
219+
auto vals_it = per_fusion_vals_.find(fusion);
220+
if (vals_it != per_fusion_vals_.end()) {
221+
for (auto it = vals_up_.begin(); it != vals_up_.end();) {
222+
if (vals_it->second.count(it->get()) > 0) {
223+
vals_.erase(it->get());
224+
it = vals_up_.erase(it);
225+
} else {
226+
++it;
227+
}
228+
}
229+
per_fusion_vals_.erase(vals_it);
230+
}
231+
232+
auto exprs_it = per_fusion_exprs_.find(fusion);
233+
if (exprs_it != per_fusion_exprs_.end()) {
234+
for (auto it = exprs_up_.begin(); it != exprs_up_.end();) {
235+
if (exprs_it->second.count(it->get()) > 0) {
236+
exprs_.erase(it->get());
237+
it = exprs_up_.erase(it);
238+
} else {
239+
++it;
240+
}
241+
}
242+
per_fusion_exprs_.erase(exprs_it);
243+
}
244+
}
245+
246+
std::deque<Val*> IrContainer::deterministicValsOwnedBy(
247+
const Fusion* fusion) const noexcept {
248+
std::deque<Val*> result;
249+
auto it = per_fusion_vals_.find(fusion);
250+
if (it == per_fusion_vals_.end()) {
251+
return result;
252+
}
253+
const auto& owned = it->second;
254+
for (const auto& val_up : vals_up_) {
255+
if (owned.count(val_up.get()) > 0) {
256+
result.push_back(val_up.get());
257+
}
258+
}
259+
return result;
260+
}
261+
262+
std::deque<Expr*> IrContainer::deterministicExprsOwnedBy(
263+
const Fusion* fusion) const noexcept {
264+
std::deque<Expr*> result;
265+
auto it = per_fusion_exprs_.find(fusion);
266+
if (it == per_fusion_exprs_.end()) {
267+
return result;
268+
}
269+
const auto& owned = it->second;
270+
for (const auto& expr_up : exprs_up_) {
271+
if (owned.count(expr_up.get()) > 0) {
272+
result.push_back(expr_up.get());
273+
}
274+
}
275+
return result;
276+
}
277+
278+
std::unordered_map<Val*, int64_t> IrContainer::deterministicValsMapOwnedBy(
279+
const Fusion* fusion) const noexcept {
280+
std::unordered_map<Val*, int64_t> result;
281+
auto it = per_fusion_vals_.find(fusion);
282+
if (it == per_fusion_vals_.end()) {
283+
return result;
284+
}
285+
const auto& owned = it->second;
286+
int64_t count = 0;
287+
for (const auto& val_up : vals_up_) {
288+
if (owned.count(val_up.get()) > 0) {
289+
result[val_up.get()] = count++;
290+
}
291+
}
292+
return result;
293+
}
294+
295+
std::unordered_map<Expr*, int64_t> IrContainer::deterministicExprsMapOwnedBy(
296+
const Fusion* fusion) const noexcept {
297+
std::unordered_map<Expr*, int64_t> result;
298+
auto it = per_fusion_exprs_.find(fusion);
299+
if (it == per_fusion_exprs_.end()) {
300+
return result;
301+
}
302+
const auto& owned = it->second;
303+
int64_t count = 0;
304+
for (const auto& expr_up : exprs_up_) {
305+
if (owned.count(expr_up.get()) > 0) {
306+
result[expr_up.get()] = count++;
307+
}
308+
}
309+
return result;
310+
}
311+
181312
} // namespace nvfuser

csrc/ir/container.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,26 @@ class IrContainer {
137137
bool hasMultipleFusions() const;
138138
const std::unordered_set<Fusion*>& sharingFusions() const;
139139

140+
NVF_API const std::unordered_set<Val*>& valsOwnedBy(
141+
const Fusion* fusion) const;
142+
const std::unordered_set<Expr*>& exprsOwnedBy(const Fusion* fusion) const;
143+
void transferStatementOwnership(const Fusion* from, const Fusion* to);
144+
void removeStatementsOwnedBy(const Fusion* fusion);
145+
146+
std::deque<Val*> deterministicValsOwnedBy(
147+
const Fusion* fusion) const noexcept;
148+
std::deque<Expr*> deterministicExprsOwnedBy(
149+
const Fusion* fusion) const noexcept;
150+
std::unordered_map<Val*, int64_t> deterministicValsMapOwnedBy(
151+
const Fusion* fusion) const noexcept;
152+
std::unordered_map<Expr*, int64_t> deterministicExprsMapOwnedBy(
153+
const Fusion* fusion) const noexcept;
154+
140155
private:
141156
std::unordered_set<Fusion*> sharing_fusions_;
157+
std::unordered_map<const Fusion*, std::unordered_set<Val*>> per_fusion_vals_;
158+
std::unordered_map<const Fusion*, std::unordered_set<Expr*>>
159+
per_fusion_exprs_;
142160
};
143161

144162
} // namespace nvfuser

0 commit comments

Comments
 (0)