@@ -108,48 +108,89 @@ bool Fusion::sameDefinition(const Fusion& other) const {
108108void Fusion::swap (Fusion& a, Fusion& b) noexcept {
109109 FUSER_PERF_SCOPE (" Fusion swap" );
110110
111- // We need to be careful to call IrContainer swap not unique_ptr swap, which
112- // will only swap the ptrs NOT the contents.
113- IrContainer::swap (*(a. ir_container ()), *(b. ir_container ()));
111+ if (&a == &b) {
112+ return ;
113+ }
114114
115- // After swapping container contents, update Statement::ir_container_
116- // pointers so each Statement points to the Fusion whose container now
117- // holds it. Also fix per-Fusion tracking keys since a's container had
118- // b's entries and vice versa.
119- a.ir_container ()->transferStatementOwnership (&b, &a);
120- b.ir_container ()->transferStatementOwnership (&a, &b);
115+ // Collect statements owned by each Fusion BEFORE swap so we can update
116+ // Statement::ir_container_ pointers afterward.
117+ std::vector<Val*> a_owned_vals, b_owned_vals;
118+ std::vector<Expr*> a_owned_exprs, b_owned_exprs;
121119
122120 if (a.ir_container_ ) {
123- for (auto val : a.vals ()) {
124- val->ir_container_ = &a;
125- }
126- for (auto expr : a.deterministic_exprs ()) {
127- expr->ir_container_ = &a;
128- }
121+ const auto & av = a.ir_container_ ->valsOwnedBy (&a);
122+ const auto & ae = a.ir_container_ ->exprsOwnedBy (&a);
123+ a_owned_vals.assign (av.begin (), av.end ());
124+ a_owned_exprs.assign (ae.begin (), ae.end ());
129125 }
130126 if (b.ir_container_ ) {
131- for (auto val : b.vals ()) {
132- val->ir_container_ = &b;
133- }
134- for (auto expr : b.deterministic_exprs ()) {
135- expr->ir_container_ = &b;
136- }
127+ const auto & bv = b.ir_container_ ->valsOwnedBy (&b);
128+ const auto & be = b.ir_container_ ->exprsOwnedBy (&b);
129+ b_owned_vals.assign (bv.begin (), bv.end ());
130+ b_owned_exprs.assign (be.begin (), be.end ());
131+ }
132+
133+ // Transfer Fusion registrations between containers before pointer swap.
134+ // After swap, a will own b's container and b will own a's container.
135+ if (a.ir_container_ && b.ir_container_ &&
136+ a.ir_container_ .get () != b.ir_container_ .get ()) {
137+ a.ir_container_ ->transferFusion (&a, &b);
138+ b.ir_container_ ->transferFusion (&b, &a);
137139 }
138140
141+ // Swap container pointers
142+ std::swap (a.ir_container_ , b.ir_container_ );
143+
144+ // Swap all Fusion-level members
139145 std::swap (a.inputs_ , b.inputs_ );
140146 std::swap (a.outputs_ , b.outputs_ );
141-
142147 std::swap (a.io_alias_ , b.io_alias_ );
143-
144- // Swap per-Fusion special values (Phase 2)
148+ std::swap (a.all_tv_uses_valid_ , b.all_tv_uses_valid_ );
149+ std::swap (a.is_during_update_uses_ , b.is_during_update_uses_ );
150+ std::swap (a.managed_data_ , b.managed_data_ );
151+ std::swap (a.managed_named_data_ , b.managed_named_data_ );
152+ std::swap (a.expected_dynamic_smem_bytes_ , b.expected_dynamic_smem_bytes_ );
153+ std::swap (a.all_tvs_ptr_ , b.all_tvs_ptr_ );
145154 std::swap (a.zero_val_ , b.zero_val_ );
146155 std::swap (a.one_val_ , b.one_val_ );
147156 std::swap (a.true_val_ , b.true_val_ );
148157 std::swap (a.false_val_ , b.false_val_ );
149158 std::swap (a.magic_zero_val_ , b.magic_zero_val_ );
150-
151159 std::swap (a.axioms_ , b.axioms_ );
152160 std::swap (a.metadata_ , b.metadata_ );
161+
162+ // Update Statement::ir_container_ pointers: a's old statements now belong
163+ // to b, and b's old statements now belong to a
164+ for (auto * val : a_owned_vals) {
165+ val->ir_container_ = &b;
166+ }
167+ for (auto * expr : a_owned_exprs) {
168+ expr->ir_container_ = &b;
169+ }
170+ for (auto * val : b_owned_vals) {
171+ val->ir_container_ = &a;
172+ }
173+ for (auto * expr : b_owned_exprs) {
174+ expr->ir_container_ = &a;
175+ }
176+
177+ // Update per-Fusion tracking keys in containers
178+ if (a.ir_container_ && b.ir_container_ ) {
179+ if (a.ir_container_ .get () == b.ir_container_ .get ()) {
180+ // Same container: directly swap per-Fusion tracking entries
181+ auto * c = a.ir_container_ .get ();
182+ std::swap (c->per_fusion_vals_ [&a], c->per_fusion_vals_ [&b]);
183+ std::swap (c->per_fusion_exprs_ [&a], c->per_fusion_exprs_ [&b]);
184+ } else {
185+ // Different containers: rename tracking keys to match new owners
186+ a.ir_container_ ->transferStatementOwnership (&b, &a);
187+ b.ir_container_ ->transferStatementOwnership (&a, &b);
188+ }
189+ } else if (a.ir_container_ ) {
190+ a.ir_container_ ->transferStatementOwnership (&b, &a);
191+ } else if (b.ir_container_ ) {
192+ b.ir_container_ ->transferStatementOwnership (&a, &b);
193+ }
153194}
154195
155196std::unique_ptr<SegmentedFusion> Fusion::segment (
@@ -161,10 +202,20 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
161202IrCloner Fusion::copy (const Fusion* from, Fusion* to) {
162203 to->clear ();
163204
164- auto ir_cloner =
165- IrContainer::copy (from->ir_container (), to->ir_container (), to);
205+ IrCloner ir_cloner (to);
206+
207+ // Clone from's vals in insertion order
208+ for (auto val : from->deterministic_vals ()) {
209+ ir_cloner.clone (val);
210+ }
166211
167- // Remap cached special val pointers through the cloner
212+ // Wire up definitions and uses on cloned vals
213+ for (auto val : from->vals ()) {
214+ ir_cloner.clone (val)->setDefinition (ir_cloner.clone (val->definition_ ));
215+ ir_cloner.clone (val)->setUses (ir_cloner.clone (val->uses_ ));
216+ }
217+
218+ // Remap cached special val pointers
168219 if (from->zero_val_ ) {
169220 to->zero_val_ = ir_cloner.clone (from->zero_val_ );
170221 }
@@ -182,11 +233,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
182233 ir_cloner.clone (from->magic_zero_val_ )->as <NamedScalar>();
183234 }
184235
185- for (auto val : from->vals ()) {
186- ir_cloner.clone (val)->setDefinition (ir_cloner.clone (val->definition_ ));
187- ir_cloner.clone (val)->setUses (ir_cloner.clone (val->uses_ ));
188- }
189-
190236 to->inputs_ = ir_cloner.clone (from->inputs_ );
191237 to->outputs_ = ir_cloner.clone (from->outputs_ );
192238 for (auto inp : to->inputs_ ) {
@@ -196,7 +242,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
196242 out->setIsFusionOutput (true );
197243 }
198244
199- // TODO: put this into ir_cloner instead
200245 for (Val* out : from->outputs_ ) {
201246 const AliasInfo& alias = from->io_alias_ .get (out);
202247 if (alias.type == AllocationType::New) {
@@ -209,14 +254,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
209254 }
210255
211256 to->all_tv_uses_valid_ = from->all_tv_uses_valid_ ;
212- // This should never be true on copy, but copying for completeness.
213257 to->is_during_update_uses_ = from->is_during_update_uses_ ;
214258
215259 for (const auto & i : from->managed_data_ ) {
216260 if (i.first .has_value ()) {
217261 to->managed_data_ .emplace_back (i.second (ir_cloner, i.first ), i.second );
218262 } else {
219- // Don't clone managed data if it has been reset
220263 to->managed_data_ .emplace_back (i.first , i.second );
221264 }
222265 }
@@ -259,9 +302,10 @@ Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
259302 ir_container_->addFusion (this );
260303}
261304
262- // Copy constructor
263- Fusion::Fusion (const Fusion& other) : Fusion( ) {
305+ // Copy constructor -- shares the source's container
306+ Fusion::Fusion (const Fusion& other) : ir_container_(other.ir_container_ ) {
264307 FUSER_PERF_SCOPE (" Fusion copy" );
308+ ir_container_->addFusion (this );
265309 Fusion::copy (&other, this );
266310}
267311
@@ -281,6 +325,9 @@ Fusion& Fusion::operator=(const Fusion& other) {
281325
282326Fusion& Fusion::operator =(Fusion&& other) noexcept {
283327 FUSER_PERF_SCOPE (" Fusion move assign" );
328+ if (this == &other) {
329+ return *this ;
330+ }
284331 clear ();
285332 swap (*this , other);
286333 return *this ;
0 commit comments