Skip to content

Commit 27c6334

Browse files
authored
[IR Refactor] Fusion Base Type (#5902)
## Summary This PR removes the intermediate layer between Fusion and the underlying container introduced in #5865. For review see also #5905 (separated for better diffs) that focusses on file re-name that should be merged into this first. ### Open Question for Reviewers With this PR `nvfuser::Fusion` has over 90 function definitions in it's header (>500 lines long now). This PR adds ~20 functions which are mostly forwarding calls for `IrContainer` methods. Is it really best to populate the Fusion interface directly with all of these functions? Perhaps an interface class that handles the ownership of the `IrContainer` is better practice for proper encapsulation.
1 parent 901f30a commit 27c6334

File tree

16 files changed

+739
-880
lines changed

16 files changed

+739
-880
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ list(APPEND NVFUSER_SRCS
216216
${NVFUSER_SRCS_DIR}/ir/builder.cpp
217217
${NVFUSER_SRCS_DIR}/ir/cloner.cpp
218218
${NVFUSER_SRCS_DIR}/ir/container.cpp
219-
${NVFUSER_SRCS_DIR}/ir/storage.cpp
220219
${NVFUSER_SRCS_DIR}/ir/graphviz.cpp
221220
${NVFUSER_SRCS_DIR}/ir/iostream.cpp
222221
${NVFUSER_SRCS_DIR}/ir/internal_base_nodes.cpp

csrc/dispatch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ class NVF_API OptOutMutator : public PolymorphicBase {
343343
}
344344

345345
protected:
346-
virtual void removeExpr(IrContainer*, Expr*) const;
346+
virtual void removeExpr(Fusion*, Expr*) const;
347347
virtual void registerNewExpr(Expr*) {}
348348

349349
private:

csrc/fusion.cpp

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,34 @@ bool Fusion::sameDefinition(const Fusion& other) const {
104104
void Fusion::swap(Fusion& a, Fusion& b) noexcept {
105105
FUSER_PERF_SCOPE("Fusion swap");
106106

107-
// Swap IrContainer base class (contains IrStorage)
108-
IrContainer::swap(static_cast<IrContainer&>(a), static_cast<IrContainer&>(b));
107+
// We need to be careful to call IrContainer swap not unique_ptr swap, which
108+
// will only swap the ptrs NOT the contents.
109+
IrContainer::swap(*(a.ir_container()), *(b.ir_container()));
110+
111+
// Fix parent pointers after swapping containers
112+
// After swap, each Fusion owns a different IrContainer, so we must
113+
// update the parent backpointers in those containers to point to their new
114+
// owners
115+
if (a.ir_container_) {
116+
// Also update all Statement ir_container_ pointers to point to new owner
117+
a.ir_container()->parent_ = &a;
118+
for (auto val : a.vals()) {
119+
val->ir_container_ = &a;
120+
}
121+
for (auto expr : a.deterministic_exprs()) {
122+
expr->ir_container_ = &a;
123+
}
124+
}
125+
if (b.ir_container_) {
126+
// Also update all Statement ir_container_ pointers to point to new owner
127+
b.ir_container()->parent_ = &b;
128+
for (auto val : b.vals()) {
129+
val->ir_container_ = &b;
130+
}
131+
for (auto expr : b.deterministic_exprs()) {
132+
expr->ir_container_ = &b;
133+
}
134+
}
109135

110136
std::swap(a.inputs_, b.inputs_);
111137
std::swap(a.outputs_, b.outputs_);
@@ -122,7 +148,7 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
122148
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
123149
to->clear();
124150

125-
auto ir_cloner = IrContainer::copy(from, to);
151+
auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container());
126152

127153
for (auto val : from->vals()) {
128154
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
@@ -183,14 +209,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
183209
return ir_cloner;
184210
}
185211

212+
// Default constructor
213+
Fusion::Fusion() : ir_container_(std::make_unique<IrContainer>()) {
214+
ir_container_->parent_ = this;
215+
}
216+
186217
// Copy constructor
187-
Fusion::Fusion(const Fusion& other) {
218+
Fusion::Fusion(const Fusion& other) : Fusion() {
188219
FUSER_PERF_SCOPE("Fusion copy");
189220
Fusion::copy(&other, this);
190221
}
191222

192223
// Move constructor
193-
Fusion::Fusion(Fusion&& other) noexcept {
224+
Fusion::Fusion(Fusion&& other) noexcept : Fusion() {
194225
FUSER_PERF_SCOPE("Fusion move");
195226
swap(*this, other);
196227
}
@@ -223,7 +254,7 @@ void Fusion::clear() noexcept {
223254
// Clear container contents instead of destroying it
224255
// This preserves the container object so Statement pointers don't become
225256
// dangling
226-
ir_storage()->clear();
257+
ir_container()->clear();
227258

228259
inputs_.clear();
229260
outputs_.clear();
@@ -260,7 +291,7 @@ void Fusion::removeExpr(Expr* expr) {
260291
}
261292
}
262293

263-
IrContainer::removeExpr(expr);
294+
ir_container()->removeExpr(expr);
264295
}
265296

266297
void Fusion::removeVal(Val* val) {
@@ -304,7 +335,7 @@ void Fusion::removeVal(Val* val) {
304335
for (auto e : exprs_to_remove) {
305336
removeExpr(e);
306337
}
307-
IrContainer::removeVal(val);
338+
ir_container()->removeVal(val);
308339

309340
invalidateTvsAndUses();
310341
}
@@ -668,7 +699,7 @@ void Fusion::registerVal(Val* val) {
668699
val->fusion() == this, val, " was not found in the active fusion.");
669700
}
670701

671-
IrContainer::registerVal(val);
702+
ir_container()->registerVal(val);
672703
}
673704

674705
void Fusion::registerExpr(Expr* expr) {
@@ -681,7 +712,7 @@ void Fusion::registerExpr(Expr* expr) {
681712
expr->fusion() == this, expr, " was not found in the active fusion.");
682713
}
683714

684-
IrContainer::registerExpr(expr);
715+
ir_container()->registerExpr(expr);
685716

686717
for (Val* input : expr->inputs()) {
687718
assertInContainer(input, "Input to expr is invalid, ");

csrc/fusion.h

Lines changed: 175 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <unordered_map>
1313
#include <unordered_set>
1414
#include <vector>
15+
#include "base.h"
1516

1617
#include <ATen/core/ivalue.h>
1718

@@ -142,11 +143,36 @@ class AliasInfoMap {
142143
//! The Fusion owns the whole IR graph (Vals and Exprs)
143144
//!
144145
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
145-
class NVF_API Fusion : public IrContainer {
146+
class NVF_API Fusion : public PolymorphicBase {
146147
typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;
147148

149+
protected:
150+
// Direct access to underlying container
151+
IrContainer* ir_container() {
152+
NVF_ERROR(
153+
ir_container_.get() != nullptr,
154+
"Accessing an uninitialized IrContainer!.")
155+
return ir_container_.get();
156+
}
157+
158+
const IrContainer* ir_container() const {
159+
NVF_ERROR(
160+
ir_container_.get() != nullptr,
161+
"Accessing an uninitialized IrContainer!.")
162+
return ir_container_.get();
163+
}
164+
148165
public:
149-
Fusion() = default;
166+
// Registration (public API with passkey)
167+
virtual void registerStmt(IrBuilderPasskey, Statement* stmt) {
168+
if (stmt->isVal()) {
169+
registerVal(stmt->asVal());
170+
} else {
171+
registerExpr(stmt->asExpr());
172+
}
173+
}
174+
175+
Fusion();
150176

151177
Fusion(const Fusion& other);
152178
Fusion(Fusion&& other) noexcept;
@@ -168,11 +194,11 @@ class NVF_API Fusion : public IrContainer {
168194

169195
//! Break dependency chains associated with Expr, remove references to expr
170196
//! delete expr
171-
void removeExpr(Expr* expr) override;
197+
virtual void removeExpr(Expr* expr);
172198

173199
//! Completely remove val from the fusion, break all dependencies associated
174200
//! with it
175-
void removeVal(Val* val) override;
201+
virtual void removeVal(Val* val);
176202

177203
//! Register input as an input of the fusion
178204
void addInput(Val* input);
@@ -477,25 +503,126 @@ class NVF_API Fusion : public IrContainer {
477503

478504
void resetExactMappings();
479505

506+
//===================================================================
507+
// IrContainer API Forwarding (Public Methods)
508+
//===================================================================
509+
510+
// Container queries
511+
bool inContainer(const Statement* stmt) const {
512+
return ir_container()->inContainer(stmt);
513+
}
514+
515+
void assertInContainer(const Statement* stmt, const std::string& msg) const {
516+
ir_container()->assertInContainer(stmt, msg);
517+
}
518+
519+
// Collections access (return values in insertion order)
520+
const std::deque<Val*> deterministic_vals() const noexcept {
521+
return ir_container()->deterministic_vals();
522+
}
523+
524+
const std::deque<Expr*> deterministic_exprs() const noexcept {
525+
return ir_container()->deterministic_exprs();
526+
}
527+
528+
const std::unordered_map<Val*, int64_t> deterministic_vals_map()
529+
const noexcept {
530+
return ir_container()->deterministic_vals_map();
531+
}
532+
533+
const std::unordered_map<Expr*, int64_t> deterministic_exprs_map()
534+
const noexcept {
535+
return ir_container()->deterministic_exprs_map();
536+
}
537+
538+
// Collections access (unordered sets)
539+
const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
540+
return ir_container()->unordered_exprs();
541+
}
542+
543+
const std::unordered_set<Val*>& vals() const noexcept {
544+
return ir_container()->vals();
545+
}
546+
547+
// Count queries
548+
int64_t numExprs() const noexcept {
549+
return ir_container()->numExprs();
550+
}
551+
552+
int64_t numVals(bool include_shortcuts) const noexcept {
553+
return ir_container()->numVals(include_shortcuts);
554+
}
555+
556+
// Shortcut values (frequently used constants)
557+
Val* zeroVal() {
558+
return ir_container()->zeroVal();
559+
}
560+
561+
Val* oneVal() {
562+
return ir_container()->oneVal();
563+
}
564+
565+
Val* falseVal() {
566+
return ir_container()->falseVal();
567+
}
568+
569+
Val* trueVal() {
570+
return ir_container()->trueVal();
571+
}
572+
573+
NamedScalar* magicZeroVal() {
574+
return ir_container()->magicZeroVal();
575+
}
576+
577+
Val* zeroVal(DataType dtype) {
578+
return ir_container()->zeroVal(dtype);
579+
}
580+
581+
Val* oneVal(DataType dtype) {
582+
return ir_container()->oneVal(dtype);
583+
}
584+
585+
Val* metadataOf(Val* val) {
586+
return ir_container()->metadataOf(val);
587+
}
588+
589+
// Axioms (CUDA programming assumptions)
590+
const std::vector<Val*>& axioms() {
591+
return ir_container()->axioms();
592+
}
593+
594+
void assumePositive(Val* val) {
595+
ir_container()->assumePositive(val);
596+
}
597+
598+
void assumeNonNegative(Val* val) {
599+
ir_container()->assumeNonNegative(val);
600+
}
601+
602+
// Statement removal
603+
void removeStatementsCreatedAfter(
604+
int64_t num_exprs_before,
605+
int64_t num_vals_before) {
606+
ir_container()->removeStatementsCreatedAfter(
607+
num_exprs_before, num_vals_before);
608+
}
609+
480610
protected:
481611
friend SegmentCandidateFinder;
482612
friend SegmentedFusion;
483613
friend class TranslateApplicableWelford;
484614
friend Val;
485615

486-
using IrContainer::registerExpr;
487-
using IrContainer::registerVal;
488-
489616
//! Register the Val with this fusion
490-
void registerVal(Val* val) override;
617+
virtual void registerVal(Val* val);
491618

492619
//! Register expr with this fusion.
493620
//! When we register an expression, we want to update the dependency tracking
494621
//! of Vals. If this container is a not a Kernel, it will remove previous
495622
//! definitions of outputs and register this Expr as the definition. Otherwise
496623
//! will update definition if not previously set, but will not remove old
497624
//! definitions.
498-
void registerExpr(Expr* expr) override;
625+
virtual void registerExpr(Expr* expr);
499626

500627
//! Clear Expr's from TV uses that are not required to produce outputs from
501628
//! inputs. Only other place this is used (other than Fusion) is in
@@ -539,22 +666,18 @@ class NVF_API Fusion : public IrContainer {
539666
std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr;
540667

541668
inline static const std::string exact_mappings_key = "exact_mappings";
669+
std::unique_ptr<IrContainer> ir_container_;
542670
};
543671

672+
// Template implementations for Fusion::manage<T>() that use IrCloner
544673
template <typename T>
545674
std::any defaultCloneFunction(IrCloner& cloner, std::any data) {
546675
auto cloned_data = cloner.clone(std::any_cast<T>(data));
547-
// Adding a static_assert to improve error message. Without this
548-
// static_assert, the following cast will still fail, but the error message
549-
// will be unreadable.
550676
static_assert(
551677
std::is_convertible_v<decltype(cloned_data), T>,
552678
"IrCloner::clone returns a data type that is not compatible with the "
553679
"original managed data type. "
554680
"Likely you will need to check IrCloner::clone for your data type.");
555-
// Convert the result of the clone back to T before assigning to std::any.
556-
// This ensures the type of the std::any does not change over the clone of
557-
// fusion.
558681
return std::any((T)cloned_data);
559682
}
560683

@@ -568,4 +691,41 @@ void Fusion::manage(std::string key, T data) {
568691
return manage(key, std::any(data), defaultCloneFunction<T>);
569692
}
570693

694+
// Template implementations for IrBuilder that require Fusion to be fully
695+
// defined
696+
template <class T, class... Args>
697+
T* IrBuilder::createInContainer(Fusion* container, Args&&... args) {
698+
NVF_ERROR(container != nullptr, "Need an active container to build IR.");
699+
T* node = new T(IrBuilderPasskey(container), std::forward<Args>(args)...);
700+
container->registerStmt(IrBuilderPasskey(container), node);
701+
return node;
702+
}
703+
704+
template <class T>
705+
T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) {
706+
NVF_ERROR(
707+
ir_cloner != nullptr,
708+
"Cannot use create when a cloner object is set. Use clone.");
709+
NVF_ERROR(
710+
ir_cloner->container() != nullptr,
711+
"Cloner doesn't have a valid container to store cloned object.");
712+
713+
T* dest = new T(src, ir_cloner);
714+
const auto* src_stmt = dynamic_cast<const Statement*>(src);
715+
auto* dest_stmt = dynamic_cast<Statement*>(dest);
716+
717+
auto dest_container = ir_cloner->container();
718+
auto src_container = src_stmt->container();
719+
720+
dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt);
721+
722+
if (src_container != dest_container) {
723+
dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name());
724+
}
725+
726+
ir_cloner->registerClone(src_stmt, dest_stmt);
727+
728+
return dest;
729+
}
730+
571731
} // namespace nvfuser

0 commit comments

Comments
 (0)