Skip to content

Commit 3a199c8

Browse files
committed
shared_ptr<IrContainer> transition and Fusion tracking infrastructure
Change Fusion::ir_container_ from unique_ptr to shared_ptr to enable future container sharing between Fusions. Add Fusion tracking API to IrContainer (addFusion/removeFusion/transferFusion/sharingCount). Disable parallel compilation during the shared_ptr transition.
1 parent 074c814 commit 3a199c8

File tree

5 files changed

+51
-10
lines changed

5 files changed

+51
-10
lines changed

csrc/fusion.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,9 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
254254
}
255255

256256
// Default constructor
257-
Fusion::Fusion() : ir_container_(std::make_unique<IrContainer>()) {
257+
Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
258258
ir_container_->parent_ = this;
259+
ir_container_->addFusion(this);
259260
}
260261

261262
// Copy constructor
@@ -287,6 +288,9 @@ Fusion& Fusion::operator=(Fusion&& other) noexcept {
287288

288289
Fusion::~Fusion() {
289290
clear();
291+
if (ir_container_) {
292+
ir_container_->removeFusion(this);
293+
}
290294
}
291295

292296
void Fusion::clear() noexcept {
@@ -350,9 +354,7 @@ void Fusion::removeExpr(Expr* expr) {
350354
auto expr_in_deque = std::find_if(
351355
c->exprs_up_.begin(),
352356
c->exprs_up_.end(),
353-
[expr](std::unique_ptr<Expr>& expr_up) {
354-
return expr_up.get() == expr;
355-
});
357+
[expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; });
356358
NVF_ERROR(
357359
expr_in_deque != c->exprs_up_.end(),
358360
"Wanted to remove an expression but its unique ptr is missing.");

csrc/fusion.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ class NVF_API Fusion : public PolymorphicBase {
148148
typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;
149149

150150
protected:
151-
// Direct access to underlying container
152151
IrContainer* ir_container() {
153152
NVF_ERROR(
154153
ir_container_.get() != nullptr,
@@ -163,6 +162,10 @@ class NVF_API Fusion : public PolymorphicBase {
163162
return ir_container_.get();
164163
}
165164

165+
std::shared_ptr<IrContainer> ir_container_ptr() const {
166+
return ir_container_;
167+
}
168+
166169
public:
167170
// Registration (public API with passkey)
168171
virtual void registerStmt(IrBuilderPasskey, Statement* stmt) {
@@ -635,7 +638,7 @@ class NVF_API Fusion : public PolymorphicBase {
635638
std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr;
636639

637640
inline static const std::string exact_mappings_key = "exact_mappings";
638-
std::unique_ptr<IrContainer> ir_container_;
641+
std::shared_ptr<IrContainer> ir_container_;
639642

640643
Val* zero_val_ = nullptr;
641644
Val* one_val_ = nullptr;

csrc/ir/container.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,29 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
157157
return true;
158158
}
159159

160+
void IrContainer::addFusion(Fusion* fusion) {
161+
sharing_fusions_.insert(fusion);
162+
}
163+
164+
void IrContainer::removeFusion(Fusion* fusion) {
165+
sharing_fusions_.erase(fusion);
166+
}
167+
168+
void IrContainer::transferFusion(Fusion* from, Fusion* to) {
169+
sharing_fusions_.erase(from);
170+
sharing_fusions_.insert(to);
171+
}
172+
173+
size_t IrContainer::sharingCount() const {
174+
return sharing_fusions_.size();
175+
}
176+
177+
bool IrContainer::hasMultipleFusions() const {
178+
return sharing_fusions_.size() > 1;
179+
}
180+
181+
const std::unordered_set<Fusion*>& IrContainer::sharingFusions() const {
182+
return sharing_fusions_;
183+
}
184+
160185
} // namespace nvfuser

csrc/ir/container.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,16 @@ class IrContainer {
133133
return parent_;
134134
}
135135

136+
void addFusion(Fusion* fusion);
137+
void removeFusion(Fusion* fusion);
138+
void transferFusion(Fusion* from, Fusion* to);
139+
size_t sharingCount() const;
140+
bool hasMultipleFusions() const;
141+
const std::unordered_set<Fusion*>& sharingFusions() const;
142+
136143
private:
137-
// Parent Fusion that owns this container (for pure composition pattern)
138-
// Used by Statement::fusion() to navigate back to owning Fusion
139144
Fusion* parent_ = nullptr;
145+
std::unordered_set<Fusion*> sharing_fusions_;
140146
};
141147

142148
} // namespace nvfuser

csrc/runtime/fusion_kernel_runtime.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
namespace nvfuser {
3030

31+
// TODO: Remove when std::shared_mutex is added to IrContainer.
32+
constexpr bool kPhase2DisableParallelCompile = true;
33+
3134
namespace {
3235
// Replace CUDA tensor with Meta tensor because storing tensors can cause
3336
// out-of-memory issues. Other arguments are returned as-is.
@@ -454,7 +457,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
454457
try {
455458
for (const auto& [group_to_run, group_runtime_inputs] :
456459
zip(runtime_workspace_.group_run_order, all_runtime_inputs)) {
457-
if (num_groups == 1 || isOptionDisabled(DisableOption::ParallelCompile)) {
460+
if (num_groups == 1 || kPhase2DisableParallelCompile ||
461+
isOptionDisabled(DisableOption::ParallelCompile)) {
458462
compileKernel(group_runtime_inputs, group_to_run);
459463
} else {
460464
// launch compileKernel thread here
@@ -488,7 +492,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
488492
throw;
489493
}
490494

491-
if (num_groups != 1 && !isOptionDisabled(DisableOption::ParallelCompile)) {
495+
if (num_groups != 1 && !kPhase2DisableParallelCompile &&
496+
!isOptionDisabled(DisableOption::ParallelCompile)) {
492497
// Wait until all segments finish compiling
493498
getThreadPool()->waitWorkComplete();
494499
NVF_ERROR(

0 commit comments

Comments
 (0)