Skip to content

Commit 9143881

Browse files
committed
Per-Fusion name counters fix duplicate TV names after copy
Move val/expr name counters from IrContainer to Fusion so each Fusion independently tracks name assignment. This fixes CI failures where Fusion::copy left the dest counter at N (number of cloned vals) instead of max(name)+1 when source names were non-sequential, causing newly created TVs to collide with existing names. The fix adds val_type_name_map_ and expr_name_counter_ to Fusion, and updates registerVal/registerExpr to use the Fusion-level counters. Fusion::copy syncs counters from source to dest after cloning. Fusion::swap exchanges counters. Fusion::clear resets them.
1 parent 9ae372a commit 9143881

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

csrc/fusion.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
158158
std::swap(a.magic_zero_val_, b.magic_zero_val_);
159159
std::swap(a.axioms_, b.axioms_);
160160
std::swap(a.metadata_, b.metadata_);
161+
std::swap(a.val_type_name_map_, b.val_type_name_map_);
162+
std::swap(a.expr_name_counter_, b.expr_name_counter_);
161163

162164
// Update Statement::ir_container_ pointers: a's old statements now belong
163165
// to b, and b's old statements now belong to a
@@ -209,6 +211,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
209211
ir_cloner.clone(val);
210212
}
211213

214+
// Sync per-Fusion name counters from source to dest.
215+
// During cloning, registerVal increments the dest Fusion's counter for each
216+
// val, then IrBuilder::clone overrides the name with setName(src->name()).
217+
// If source names are non-sequential (e.g., {0..10, 22..27} from segmenter
218+
// creating intermediate TVs), the dest counter ends up at N (number of vals)
219+
// instead of max(name)+1. Copying the source's counter state ensures new
220+
// vals created post-copy won't collide with existing names.
221+
to->val_type_name_map_ = from->val_type_name_map_;
222+
to->expr_name_counter_ = from->expr_name_counter_;
223+
212224
// Wire up definitions and uses on cloned vals
213225
for (auto val : from->vals()) {
214226
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
@@ -367,6 +379,9 @@ void Fusion::clear() noexcept {
367379
axioms_.reset();
368380
metadata_.clear();
369381

382+
val_type_name_map_.clear();
383+
expr_name_counter_ = 0;
384+
370385
invalidateTvsAndUses();
371386

372387
is_during_update_uses_ = false;
@@ -972,7 +987,7 @@ void Fusion::registerVal(Val* val) {
972987
c->vals_up_.emplace_back(val);
973988
c->vals_.insert(val);
974989
c->per_fusion_vals_[this].insert(val);
975-
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
990+
val->setName(IrContainerPasskey(), getValName(val->vtype()));
976991
}
977992

978993
void Fusion::registerExpr(Expr* expr) {
@@ -989,7 +1004,7 @@ void Fusion::registerExpr(Expr* expr) {
9891004
c->exprs_up_.emplace_back(expr);
9901005
c->exprs_.insert(expr);
9911006
c->per_fusion_exprs_[this].insert(expr);
992-
expr->setName(IrContainerPasskey(), c->getExprName());
1007+
expr->setName(IrContainerPasskey(), getExprName());
9931008

9941009
for (Val* input : expr->inputs()) {
9951010
assertInContainer(input, "Input to expr is invalid, ");

csrc/fusion.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,24 @@ class NVF_API Fusion : public PolymorphicBase {
661661
std::unique_ptr<std::vector<Val*>> axioms_;
662662

663663
std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
664+
665+
// Per-Fusion name counters. Each Fusion independently tracks name assignment
666+
// so that cloned Fusions get matching names (T0→T0) regardless of whether
667+
// they share an IrContainer. This is required by downstream consumers that
668+
// use tv->name() as a map key (alias_memory, GreedyParams, etc.).
669+
std::unordered_map<ValType, StmtNameType> val_type_name_map_;
670+
StmtNameType expr_name_counter_ = 0;
671+
672+
StmtNameType getValName(ValType vtype) {
673+
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
674+
val_type_name_map_[vtype] = 0;
675+
}
676+
return val_type_name_map_[vtype]++;
677+
}
678+
679+
StmtNameType getExprName() {
680+
return expr_name_counter_++;
681+
}
664682
};
665683

666684
// Template implementations for Fusion::manage<T>() that use IrCloner

0 commit comments

Comments
 (0)