Skip to content

Commit 7007bf9

Browse files
committed
shared_ptr<IrContainer> transition and Fusion tracking infrastructure
Change Fusion::ir_container_ from unique_ptr to shared_ptr to enable future container sharing between Fusions. Add Fusion tracking API to IrContainer (addFusion/removeFusion/transferFusion/sharingCount). Remove IrContainer::parent_ since the 1:1 relationship no longer holds. Disable parallel compilation during the shared_ptr transition.
1 parent 0a16eaa commit 7007bf9

File tree

5 files changed

+62
-28
lines changed

5 files changed

+62
-28
lines changed

csrc/fusion.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
117117
// update the parent backpointers in those containers to point to their new
118118
// owners
119119
if (a.ir_container_) {
120-
// Also update all Statement ir_container_ pointers to point to new owner
121-
a.ir_container()->parent_ = &a;
122120
for (auto val : a.vals()) {
123121
val->ir_container_ = &a;
124122
}
@@ -127,8 +125,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
127125
}
128126
}
129127
if (b.ir_container_) {
130-
// Also update all Statement ir_container_ pointers to point to new owner
131-
b.ir_container()->parent_ = &b;
132128
for (auto val : b.vals()) {
133129
val->ir_container_ = &b;
134130
}
@@ -162,7 +158,8 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
162158
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
163159
to->clear();
164160

165-
auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container());
161+
auto ir_cloner =
162+
IrContainer::copy(from->ir_container(), to->ir_container(), to);
166163

167164
// Remap cached special val pointers through the cloner
168165
if (from->zero_val_) {
@@ -255,8 +252,8 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
255252
}
256253

257254
// Default constructor
258-
Fusion::Fusion() : ir_container_(std::make_unique<IrContainer>()) {
259-
ir_container_->parent_ = this;
255+
Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
256+
ir_container_->addFusion(this);
260257
}
261258

262259
// Copy constructor
@@ -288,6 +285,9 @@ Fusion& Fusion::operator=(Fusion&& other) noexcept {
288285

289286
Fusion::~Fusion() {
290287
clear();
288+
if (ir_container_) {
289+
ir_container_->removeFusion(this);
290+
}
291291
}
292292

293293
void Fusion::clear() noexcept {
@@ -351,9 +351,7 @@ void Fusion::removeExpr(Expr* expr) {
351351
auto expr_in_deque = std::find_if(
352352
c->exprs_up_.begin(),
353353
c->exprs_up_.end(),
354-
[expr](std::unique_ptr<Expr>& expr_up) {
355-
return expr_up.get() == expr;
356-
});
354+
[expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; });
357355
NVF_ERROR(
358356
expr_in_deque != c->exprs_up_.end(),
359357
"Wanted to remove an expression but its unique ptr is missing.");

csrc/fusion.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ class NVF_API Fusion : public PolymorphicBase {
149149
typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;
150150

151151
protected:
152-
// Direct access to underlying container
153152
IrContainer* ir_container() {
154153
NVF_ERROR(
155154
ir_container_.get() != nullptr,
@@ -164,6 +163,10 @@ class NVF_API Fusion : public PolymorphicBase {
164163
return ir_container_.get();
165164
}
166165

166+
std::shared_ptr<IrContainer> ir_container_ptr() const {
167+
return ir_container_;
168+
}
169+
167170
public:
168171
// Registration (public API with passkey)
169172
virtual void registerStmt(IrBuilderPasskey, Statement* stmt) {
@@ -636,7 +639,7 @@ class NVF_API Fusion : public PolymorphicBase {
636639
std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr;
637640

638641
inline static const std::string exact_mappings_key = "exact_mappings";
639-
std::unique_ptr<IrContainer> ir_container_;
642+
std::shared_ptr<IrContainer> ir_container_;
640643

641644
Val* zero_val_ = nullptr;
642645
Val* one_val_ = nullptr;

csrc/ir/container.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,15 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
8080

8181
std::swap(a.val_type_name_map_, b.val_type_name_map_);
8282
std::swap(a.expr_name_counter_, b.expr_name_counter_);
83-
84-
std::swap(a.parent_, b.parent_);
8583
}
8684

87-
IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
85+
IrCloner IrContainer::copy(
86+
const IrContainer* from,
87+
IrContainer* to,
88+
Fusion* dest_fusion) {
8889
to->clear();
8990

90-
IrCloner ir_cloner(to->parent());
91+
IrCloner ir_cloner(dest_fusion);
9192

9293
// Copy values in deterministic order
9394
for (auto val : from->deterministic_vals()) {
@@ -138,7 +139,7 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
138139
}
139140

140141
NVF_ERROR(
141-
const_stmt->container() == this->parent(),
142+
sharing_fusions_.count(const_stmt->container()) > 0,
142143
"Container claims to own stmt, but stmt disagrees.");
143144

144145
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@@ -157,4 +158,29 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
157158
return true;
158159
}
159160

161+
void IrContainer::addFusion(Fusion* fusion) {
162+
sharing_fusions_.insert(fusion);
163+
}
164+
165+
void IrContainer::removeFusion(Fusion* fusion) {
166+
sharing_fusions_.erase(fusion);
167+
}
168+
169+
void IrContainer::transferFusion(Fusion* from, Fusion* to) {
170+
sharing_fusions_.erase(from);
171+
sharing_fusions_.insert(to);
172+
}
173+
174+
size_t IrContainer::sharingCount() const {
175+
return sharing_fusions_.size();
176+
}
177+
178+
bool IrContainer::hasMultipleFusions() const {
179+
return sharing_fusions_.size() > 1;
180+
}
181+
182+
const std::unordered_set<Fusion*>& IrContainer::sharingFusions() const {
183+
return sharing_fusions_;
184+
}
185+
160186
} // namespace nvfuser

csrc/ir/container.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ class IrContainer {
8686
}
8787

8888
protected:
89-
static IrCloner copy(const IrContainer* from, IrContainer* to);
89+
static IrCloner copy(
90+
const IrContainer* from,
91+
IrContainer* to,
92+
Fusion* dest_fusion);
9093

9194
static void swap(IrContainer& a, IrContainer& b) noexcept;
9295

@@ -127,16 +130,15 @@ class IrContainer {
127130
StmtNameType expr_name_counter_ = 0;
128131

129132
public:
130-
Fusion* parent() const {
131-
NVF_ERROR(
132-
parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.")
133-
return parent_;
134-
}
133+
void addFusion(Fusion* fusion);
134+
void removeFusion(Fusion* fusion);
135+
void transferFusion(Fusion* from, Fusion* to);
136+
size_t sharingCount() const;
137+
bool hasMultipleFusions() const;
138+
const std::unordered_set<Fusion*>& sharingFusions() const;
135139

136140
private:
137-
// Parent Fusion that owns this container (for pure composition pattern)
138-
// Used by Statement::fusion() to navigate back to owning Fusion
139-
Fusion* parent_ = nullptr;
141+
std::unordered_set<Fusion*> sharing_fusions_;
140142
};
141143

142144
} // namespace nvfuser

csrc/runtime/fusion_kernel_runtime.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626

2727
namespace nvfuser {
2828

29+
// TODO: Remove when std::shared_mutex is added to IrContainer.
30+
constexpr bool kPhase2DisableParallelCompile = true;
31+
2932
namespace {
3033
// Replace CUDA tensor with Meta tensor because storing tensors can cause
3134
// out-of-memory issues. Other arguments are returned as-is.
@@ -436,7 +439,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
436439
try {
437440
for (const auto& [group_to_run, group_runtime_inputs] :
438441
zip(runtime_workspace_.group_run_order, all_runtime_inputs)) {
439-
if (num_groups == 1 || isOptionDisabled(DisableOption::ParallelCompile)) {
442+
if (num_groups == 1 || kPhase2DisableParallelCompile ||
443+
isOptionDisabled(DisableOption::ParallelCompile)) {
440444
compileKernel(group_runtime_inputs, group_to_run);
441445
} else {
442446
// launch compileKernel thread here
@@ -470,7 +474,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
470474
throw;
471475
}
472476

473-
if (num_groups != 1 && !isOptionDisabled(DisableOption::ParallelCompile)) {
477+
if (num_groups != 1 && !kPhase2DisableParallelCompile &&
478+
!isOptionDisabled(DisableOption::ParallelCompile)) {
474479
// Wait until all segments finish compiling
475480
getThreadPool()->waitWorkComplete();
476481
NVF_ERROR(

0 commit comments

Comments
 (0)