Skip to content

Commit 9737ff1

Browse files
committed
Per-Fusion name counters fix duplicate TV names after copy
Move val/expr name counters from IrContainer to Fusion so each Fusion independently tracks name assignment. This fixes CI failures where Fusion::copy left the dest counter at N (number of cloned vals) instead of max(name)+1 when source names were non-sequential, causing newly created TVs to collide with existing names. The fix adds val_type_name_map_ and expr_name_counter_ to Fusion, and updates registerVal/registerExpr to use the Fusion-level counters. Fusion::copy syncs counters from source to dest after cloning. Fusion::swap exchanges counters. Fusion::clear resets them.
1 parent be25eaa commit 9737ff1

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

csrc/fusion.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
157157
std::swap(a.magic_zero_val_, b.magic_zero_val_);
158158
std::swap(a.axioms_, b.axioms_);
159159
std::swap(a.metadata_, b.metadata_);
160+
std::swap(a.val_type_name_map_, b.val_type_name_map_);
161+
std::swap(a.expr_name_counter_, b.expr_name_counter_);
160162

161163
// Update Statement::ir_container_ pointers: a's old statements now belong
162164
// to b, and b's old statements now belong to a
@@ -208,6 +210,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
208210
ir_cloner.clone(val);
209211
}
210212

213+
// Sync per-Fusion name counters from source to dest.
214+
// During cloning, registerVal increments the dest Fusion's counter for each
215+
// val, then IrBuilder::clone overrides the name with setName(src->name()).
216+
// If source names are non-sequential (e.g., {0..10, 22..27} from segmenter
217+
// creating intermediate TVs), the dest counter ends up at N (number of vals)
218+
// instead of max(name)+1. Copying the source's counter state ensures new
219+
// vals created post-copy won't collide with existing names.
220+
to->val_type_name_map_ = from->val_type_name_map_;
221+
to->expr_name_counter_ = from->expr_name_counter_;
222+
211223
// Wire up definitions and uses on cloned vals
212224
for (auto val : from->vals()) {
213225
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
@@ -366,6 +378,9 @@ void Fusion::clear() noexcept {
366378
axioms_.reset();
367379
metadata_.clear();
368380

381+
val_type_name_map_.clear();
382+
expr_name_counter_ = 0;
383+
369384
invalidateTvsAndUses();
370385

371386
is_during_update_uses_ = false;
@@ -975,7 +990,7 @@ void Fusion::registerVal(Val* val) {
975990
c->vals_up_.emplace_back(val);
976991
c->vals_.insert(val);
977992
c->per_fusion_vals_[this].insert(val);
978-
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
993+
val->setName(IrContainerPasskey(), getValName(val->vtype()));
979994
}
980995

981996
void Fusion::registerExpr(Expr* expr) {
@@ -992,7 +1007,7 @@ void Fusion::registerExpr(Expr* expr) {
9921007
c->exprs_up_.emplace_back(expr);
9931008
c->exprs_.insert(expr);
9941009
c->per_fusion_exprs_[this].insert(expr);
995-
expr->setName(IrContainerPasskey(), c->getExprName());
1010+
expr->setName(IrContainerPasskey(), getExprName());
9961011

9971012
for (Val* input : expr->inputs()) {
9981013
assertInContainer(input, "Input to expr is invalid, ");

csrc/fusion.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,24 @@ class NVF_API Fusion : public PolymorphicBase {
660660
std::unique_ptr<std::vector<Val*>> axioms_;
661661

662662
std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
663+
664+
// Per-Fusion name counters. Each Fusion independently tracks name assignment
665+
// so that cloned Fusions get matching names (T0→T0) regardless of whether
666+
// they share an IrContainer. This is required by downstream consumers that
667+
// use tv->name() as a map key (alias_memory, GreedyParams, etc.).
668+
std::unordered_map<ValType, StmtNameType> val_type_name_map_;
669+
StmtNameType expr_name_counter_ = 0;
670+
671+
StmtNameType getValName(ValType vtype) {
672+
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
673+
val_type_name_map_[vtype] = 0;
674+
}
675+
return val_type_name_map_[vtype]++;
676+
}
677+
678+
StmtNameType getExprName() {
679+
return expr_name_counter_++;
680+
}
663681
};
664682

665683
// Template implementations for Fusion::manage<T>() that use IrCloner

0 commit comments

Comments
 (0)