Skip to content

Commit 9ae372a

Browse files
committed
Copy/move/swap semantics for shared containers
Copy constructor now shares the source's container pointer instead of creating a new one. Fusion::copy clones directly from per-Fusion filtered vals rather than delegating to IrContainer::copy. Swap changed from content-based (IrContainer::swap) to pointer-based with per-Fusion ownership tracking for both same-container and different-container cases.
1 parent 9f944b4 commit 9ae372a

File tree

1 file changed

+85
-38
lines changed

1 file changed

+85
-38
lines changed

csrc/fusion.cpp

Lines changed: 85 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -108,48 +108,89 @@ bool Fusion::sameDefinition(const Fusion& other) const {
108108
void Fusion::swap(Fusion& a, Fusion& b) noexcept {
109109
FUSER_PERF_SCOPE("Fusion swap");
110110

111-
// We need to be careful to call IrContainer swap not unique_ptr swap, which
112-
// will only swap the ptrs NOT the contents.
113-
IrContainer::swap(*(a.ir_container()), *(b.ir_container()));
111+
if (&a == &b) {
112+
return;
113+
}
114114

115-
// After swapping container contents, update Statement::ir_container_
116-
// pointers so each Statement points to the Fusion whose container now
117-
// holds it. Also fix per-Fusion tracking keys since a's container had
118-
// b's entries and vice versa.
119-
a.ir_container()->transferStatementOwnership(&b, &a);
120-
b.ir_container()->transferStatementOwnership(&a, &b);
115+
// Collect statements owned by each Fusion BEFORE swap so we can update
116+
// Statement::ir_container_ pointers afterward.
117+
std::vector<Val*> a_owned_vals, b_owned_vals;
118+
std::vector<Expr*> a_owned_exprs, b_owned_exprs;
121119

122120
if (a.ir_container_) {
123-
for (auto val : a.vals()) {
124-
val->ir_container_ = &a;
125-
}
126-
for (auto expr : a.deterministic_exprs()) {
127-
expr->ir_container_ = &a;
128-
}
121+
const auto& av = a.ir_container_->valsOwnedBy(&a);
122+
const auto& ae = a.ir_container_->exprsOwnedBy(&a);
123+
a_owned_vals.assign(av.begin(), av.end());
124+
a_owned_exprs.assign(ae.begin(), ae.end());
129125
}
130126
if (b.ir_container_) {
131-
for (auto val : b.vals()) {
132-
val->ir_container_ = &b;
133-
}
134-
for (auto expr : b.deterministic_exprs()) {
135-
expr->ir_container_ = &b;
136-
}
127+
const auto& bv = b.ir_container_->valsOwnedBy(&b);
128+
const auto& be = b.ir_container_->exprsOwnedBy(&b);
129+
b_owned_vals.assign(bv.begin(), bv.end());
130+
b_owned_exprs.assign(be.begin(), be.end());
131+
}
132+
133+
// Transfer Fusion registrations between containers before pointer swap.
134+
// After swap, a will own b's container and b will own a's container.
135+
if (a.ir_container_ && b.ir_container_ &&
136+
a.ir_container_.get() != b.ir_container_.get()) {
137+
a.ir_container_->transferFusion(&a, &b);
138+
b.ir_container_->transferFusion(&b, &a);
137139
}
138140

141+
// Swap container pointers
142+
std::swap(a.ir_container_, b.ir_container_);
143+
144+
// Swap all Fusion-level members
139145
std::swap(a.inputs_, b.inputs_);
140146
std::swap(a.outputs_, b.outputs_);
141-
142147
std::swap(a.io_alias_, b.io_alias_);
143-
144-
// Swap per-Fusion special values (Phase 2)
148+
std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_);
149+
std::swap(a.is_during_update_uses_, b.is_during_update_uses_);
150+
std::swap(a.managed_data_, b.managed_data_);
151+
std::swap(a.managed_named_data_, b.managed_named_data_);
152+
std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_);
153+
std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_);
145154
std::swap(a.zero_val_, b.zero_val_);
146155
std::swap(a.one_val_, b.one_val_);
147156
std::swap(a.true_val_, b.true_val_);
148157
std::swap(a.false_val_, b.false_val_);
149158
std::swap(a.magic_zero_val_, b.magic_zero_val_);
150-
151159
std::swap(a.axioms_, b.axioms_);
152160
std::swap(a.metadata_, b.metadata_);
161+
162+
// Update Statement::ir_container_ pointers: a's old statements now belong
163+
// to b, and b's old statements now belong to a
164+
for (auto* val : a_owned_vals) {
165+
val->ir_container_ = &b;
166+
}
167+
for (auto* expr : a_owned_exprs) {
168+
expr->ir_container_ = &b;
169+
}
170+
for (auto* val : b_owned_vals) {
171+
val->ir_container_ = &a;
172+
}
173+
for (auto* expr : b_owned_exprs) {
174+
expr->ir_container_ = &a;
175+
}
176+
177+
// Update per-Fusion tracking keys in containers
178+
if (a.ir_container_ && b.ir_container_) {
179+
if (a.ir_container_.get() == b.ir_container_.get()) {
180+
// Same container: directly swap per-Fusion tracking entries
181+
auto* c = a.ir_container_.get();
182+
std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]);
183+
std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]);
184+
} else {
185+
// Different containers: rename tracking keys to match new owners
186+
a.ir_container_->transferStatementOwnership(&b, &a);
187+
b.ir_container_->transferStatementOwnership(&a, &b);
188+
}
189+
} else if (a.ir_container_) {
190+
a.ir_container_->transferStatementOwnership(&b, &a);
191+
} else if (b.ir_container_) {
192+
b.ir_container_->transferStatementOwnership(&a, &b);
193+
}
153194
}
154195

155196
std::unique_ptr<SegmentedFusion> Fusion::segment(
@@ -161,10 +202,20 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
161202
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
162203
to->clear();
163204

164-
auto ir_cloner =
165-
IrContainer::copy(from->ir_container(), to->ir_container(), to);
205+
IrCloner ir_cloner(to);
206+
207+
// Clone from's vals in insertion order
208+
for (auto val : from->deterministic_vals()) {
209+
ir_cloner.clone(val);
210+
}
166211

167-
// Remap cached special val pointers through the cloner
212+
// Wire up definitions and uses on cloned vals
213+
for (auto val : from->vals()) {
214+
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
215+
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
216+
}
217+
218+
// Remap cached special val pointers
168219
if (from->zero_val_) {
169220
to->zero_val_ = ir_cloner.clone(from->zero_val_);
170221
}
@@ -182,11 +233,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
182233
ir_cloner.clone(from->magic_zero_val_)->as<NamedScalar>();
183234
}
184235

185-
for (auto val : from->vals()) {
186-
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
187-
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
188-
}
189-
190236
to->inputs_ = ir_cloner.clone(from->inputs_);
191237
to->outputs_ = ir_cloner.clone(from->outputs_);
192238
for (auto inp : to->inputs_) {
@@ -196,7 +242,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
196242
out->setIsFusionOutput(true);
197243
}
198244

199-
// TODO: put this into ir_cloner instead
200245
for (Val* out : from->outputs_) {
201246
const AliasInfo& alias = from->io_alias_.get(out);
202247
if (alias.type == AllocationType::New) {
@@ -209,14 +254,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
209254
}
210255

211256
to->all_tv_uses_valid_ = from->all_tv_uses_valid_;
212-
// This should never be true on copy, but copying for completeness.
213257
to->is_during_update_uses_ = from->is_during_update_uses_;
214258

215259
for (const auto& i : from->managed_data_) {
216260
if (i.first.has_value()) {
217261
to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second);
218262
} else {
219-
// Don't clone managed data if it has been reset
220263
to->managed_data_.emplace_back(i.first, i.second);
221264
}
222265
}
@@ -259,9 +302,10 @@ Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
259302
ir_container_->addFusion(this);
260303
}
261304

262-
// Copy constructor
263-
Fusion::Fusion(const Fusion& other) : Fusion() {
305+
// Copy constructor -- shares the source's container
306+
Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) {
264307
FUSER_PERF_SCOPE("Fusion copy");
308+
ir_container_->addFusion(this);
265309
Fusion::copy(&other, this);
266310
}
267311

@@ -281,6 +325,9 @@ Fusion& Fusion::operator=(const Fusion& other) {
281325

282326
Fusion& Fusion::operator=(Fusion&& other) noexcept {
283327
FUSER_PERF_SCOPE("Fusion move assign");
328+
if (this == &other) {
329+
return *this;
330+
}
284331
clear();
285332
swap(*this, other);
286333
return *this;

0 commit comments

Comments
 (0)