File tree Expand file tree Collapse file tree 5 files changed +51
-10
lines changed
Expand file tree Collapse file tree 5 files changed +51
-10
lines changed Original file line number Diff line number Diff line change @@ -254,8 +254,9 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
254254}
255255
256256// Default constructor
257- Fusion::Fusion () : ir_container_(std::make_unique <IrContainer>()) {
257+ Fusion::Fusion () : ir_container_(std::make_shared <IrContainer>()) {
258258 ir_container_->parent_ = this ;
259+ ir_container_->addFusion (this );
259260}
260261
261262// Copy constructor
@@ -287,6 +288,9 @@ Fusion& Fusion::operator=(Fusion&& other) noexcept {
287288
288289Fusion::~Fusion () {
289290 clear ();
291+ if (ir_container_) {
292+ ir_container_->removeFusion (this );
293+ }
290294}
291295
292296void Fusion::clear () noexcept {
@@ -350,9 +354,7 @@ void Fusion::removeExpr(Expr* expr) {
350354 auto expr_in_deque = std::find_if (
351355 c->exprs_up_ .begin (),
352356 c->exprs_up_ .end (),
353- [expr](std::unique_ptr<Expr>& expr_up) {
354- return expr_up.get () == expr;
355- });
357+ [expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get () == expr; });
356358 NVF_ERROR (
357359 expr_in_deque != c->exprs_up_ .end (),
358360 " Wanted to remove an expression but its unique ptr is missing." );
Original file line number Diff line number Diff line change @@ -148,7 +148,6 @@ class NVF_API Fusion : public PolymorphicBase {
148148 typedef std::unordered_map<int , std::vector<int64_t >> PermutationMap;
149149
150150 protected:
151- // Direct access to underlying container
152151 IrContainer* ir_container () {
153152 NVF_ERROR (
154153 ir_container_.get () != nullptr ,
@@ -163,6 +162,10 @@ class NVF_API Fusion : public PolymorphicBase {
163162 return ir_container_.get ();
164163 }
165164
165+ std::shared_ptr<IrContainer> ir_container_ptr () const {
166+ return ir_container_;
167+ }
168+
166169 public:
167170 // Registration (public API with passkey)
168171 virtual void registerStmt (IrBuilderPasskey, Statement* stmt) {
@@ -635,7 +638,7 @@ class NVF_API Fusion : public PolymorphicBase {
635638 std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr ;
636639
637640 inline static const std::string exact_mappings_key = " exact_mappings" ;
638- std::unique_ptr <IrContainer> ir_container_;
641+ std::shared_ptr <IrContainer> ir_container_;
639642
640643 Val* zero_val_ = nullptr ;
641644 Val* one_val_ = nullptr ;
Original file line number Diff line number Diff line change @@ -157,4 +157,29 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
157157 return true ;
158158}
159159
160+ void IrContainer::addFusion (Fusion* fusion) {
161+ sharing_fusions_.insert (fusion);
162+ }
163+
164+ void IrContainer::removeFusion (Fusion* fusion) {
165+ sharing_fusions_.erase (fusion);
166+ }
167+
168+ void IrContainer::transferFusion (Fusion* from, Fusion* to) {
169+ sharing_fusions_.erase (from);
170+ sharing_fusions_.insert (to);
171+ }
172+
173+ size_t IrContainer::sharingCount () const {
174+ return sharing_fusions_.size ();
175+ }
176+
177+ bool IrContainer::hasMultipleFusions () const {
178+ return sharing_fusions_.size () > 1 ;
179+ }
180+
181+ const std::unordered_set<Fusion*>& IrContainer::sharingFusions () const {
182+ return sharing_fusions_;
183+ }
184+
160185} // namespace nvfuser
Original file line number Diff line number Diff line change @@ -133,10 +133,16 @@ class IrContainer {
133133 return parent_;
134134 }
135135
136+ void addFusion (Fusion* fusion);
137+ void removeFusion (Fusion* fusion);
138+ void transferFusion (Fusion* from, Fusion* to);
139+ size_t sharingCount () const ;
140+ bool hasMultipleFusions () const ;
141+ const std::unordered_set<Fusion*>& sharingFusions () const ;
142+
136143 private:
137- // Parent Fusion that owns this container (for pure composition pattern)
138- // Used by Statement::fusion() to navigate back to owning Fusion
139144 Fusion* parent_ = nullptr ;
145+ std::unordered_set<Fusion*> sharing_fusions_;
140146};
141147
142148} // namespace nvfuser
Original file line number Diff line number Diff line change 2828
2929namespace nvfuser {
3030
31+ // TODO: Remove when std::shared_mutex is added to IrContainer.
32+ constexpr bool kPhase2DisableParallelCompile = true ;
33+
3134namespace {
3235// Replace CUDA tensor with Meta tensor because storing tensors can cause
3336// out-of-memory issues. Other arguments are returned as-is.
@@ -454,7 +457,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
454457 try {
455458 for (const auto & [group_to_run, group_runtime_inputs] :
456459 zip (runtime_workspace_.group_run_order , all_runtime_inputs)) {
457- if (num_groups == 1 || isOptionDisabled (DisableOption::ParallelCompile)) {
460+ if (num_groups == 1 || kPhase2DisableParallelCompile ||
461+ isOptionDisabled (DisableOption::ParallelCompile)) {
458462 compileKernel (group_runtime_inputs, group_to_run);
459463 } else {
460464 // launch compileKernel thread here
@@ -488,7 +492,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
488492 throw ;
489493 }
490494
491- if (num_groups != 1 && !isOptionDisabled (DisableOption::ParallelCompile)) {
495+ if (num_groups != 1 && !kPhase2DisableParallelCompile &&
496+ !isOptionDisabled (DisableOption::ParallelCompile)) {
492497 // Wait until all segments finish compiling
493498 getThreadPool ()->waitWorkComplete ();
494499 NVF_ERROR (
You can’t perform that action at this time.
0 commit comments