1212#include < unordered_map>
1313#include < unordered_set>
1414#include < vector>
15+ #include " base.h"
1516
1617#include < ATen/core/ivalue.h>
1718
@@ -142,11 +143,36 @@ class AliasInfoMap {
142143// ! The Fusion owns the whole IR graph (Vals and Exprs)
143144// !
144145// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
145- class NVF_API Fusion : public IrContainer {
146+ class NVF_API Fusion : public PolymorphicBase {
146147 typedef std::unordered_map<int , std::vector<int64_t >> PermutationMap;
147148
149+ protected:
150+ // Direct access to underlying container
151+ IrContainer* ir_container () {
152+ NVF_ERROR (
153+ ir_container_.get () != nullptr ,
154+ " Accessing an uninitialized IrContainer!." )
155+ return ir_container_.get ();
156+ }
157+
158+ const IrContainer* ir_container () const {
159+ NVF_ERROR (
160+ ir_container_.get () != nullptr ,
161+ " Accessing an uninitialized IrContainer!." )
162+ return ir_container_.get ();
163+ }
164+
148165 public:
149- Fusion () = default ;
166+ // Registration (public API with passkey)
167+ virtual void registerStmt (IrBuilderPasskey, Statement* stmt) {
168+ if (stmt->isVal ()) {
169+ registerVal (stmt->asVal ());
170+ } else {
171+ registerExpr (stmt->asExpr ());
172+ }
173+ }
174+
175+ Fusion ();
150176
151177 Fusion (const Fusion& other);
152178 Fusion (Fusion&& other) noexcept ;
@@ -168,11 +194,11 @@ class NVF_API Fusion : public IrContainer {
168194
169195 // ! Break dependency chains associated with Expr, remove references to expr
170196 // ! delete expr
171- void removeExpr (Expr* expr) override ;
197+ virtual void removeExpr (Expr* expr);
172198
173199 // ! Completely remove val from the fusion, break all dependencies associated
174200 // ! with it
175- void removeVal (Val* val) override ;
201+ virtual void removeVal (Val* val);
176202
177203 // ! Register input as an input of the fusion
178204 void addInput (Val* input);
@@ -477,25 +503,126 @@ class NVF_API Fusion : public IrContainer {
477503
478504 void resetExactMappings ();
479505
506+ // ===================================================================
507+ // IrContainer API Forwarding (Public Methods)
508+ // ===================================================================
509+
510+ // Container queries
511+ bool inContainer (const Statement* stmt) const {
512+ return ir_container ()->inContainer (stmt);
513+ }
514+
515+ void assertInContainer (const Statement* stmt, const std::string& msg) const {
516+ ir_container ()->assertInContainer (stmt, msg);
517+ }
518+
519+ // Collections access (return values in insertion order)
520+ const std::deque<Val*> deterministic_vals () const noexcept {
521+ return ir_container ()->deterministic_vals ();
522+ }
523+
524+ const std::deque<Expr*> deterministic_exprs () const noexcept {
525+ return ir_container ()->deterministic_exprs ();
526+ }
527+
528+ const std::unordered_map<Val*, int64_t > deterministic_vals_map ()
529+ const noexcept {
530+ return ir_container ()->deterministic_vals_map ();
531+ }
532+
533+ const std::unordered_map<Expr*, int64_t > deterministic_exprs_map ()
534+ const noexcept {
535+ return ir_container ()->deterministic_exprs_map ();
536+ }
537+
538+ // Collections access (unordered sets)
539+ const std::unordered_set<Expr*>& unordered_exprs () const noexcept {
540+ return ir_container ()->unordered_exprs ();
541+ }
542+
543+ const std::unordered_set<Val*>& vals () const noexcept {
544+ return ir_container ()->vals ();
545+ }
546+
547+ // Count queries
548+ int64_t numExprs () const noexcept {
549+ return ir_container ()->numExprs ();
550+ }
551+
552+ int64_t numVals (bool include_shortcuts) const noexcept {
553+ return ir_container ()->numVals (include_shortcuts);
554+ }
555+
556+ // Shortcut values (frequently used constants)
557+ Val* zeroVal () {
558+ return ir_container ()->zeroVal ();
559+ }
560+
561+ Val* oneVal () {
562+ return ir_container ()->oneVal ();
563+ }
564+
565+ Val* falseVal () {
566+ return ir_container ()->falseVal ();
567+ }
568+
569+ Val* trueVal () {
570+ return ir_container ()->trueVal ();
571+ }
572+
573+ NamedScalar* magicZeroVal () {
574+ return ir_container ()->magicZeroVal ();
575+ }
576+
577+ Val* zeroVal (DataType dtype) {
578+ return ir_container ()->zeroVal (dtype);
579+ }
580+
581+ Val* oneVal (DataType dtype) {
582+ return ir_container ()->oneVal (dtype);
583+ }
584+
585+ Val* metadataOf (Val* val) {
586+ return ir_container ()->metadataOf (val);
587+ }
588+
589+ // Axioms (CUDA programming assumptions)
590+ const std::vector<Val*>& axioms () {
591+ return ir_container ()->axioms ();
592+ }
593+
594+ void assumePositive (Val* val) {
595+ ir_container ()->assumePositive (val);
596+ }
597+
598+ void assumeNonNegative (Val* val) {
599+ ir_container ()->assumeNonNegative (val);
600+ }
601+
602+ // Statement removal
603+ void removeStatementsCreatedAfter (
604+ int64_t num_exprs_before,
605+ int64_t num_vals_before) {
606+ ir_container ()->removeStatementsCreatedAfter (
607+ num_exprs_before, num_vals_before);
608+ }
609+
480610 protected:
481611 friend SegmentCandidateFinder;
482612 friend SegmentedFusion;
483613 friend class TranslateApplicableWelford ;
484614 friend Val;
485615
486- using IrContainer::registerExpr;
487- using IrContainer::registerVal;
488-
489616 // ! Register the Val with this fusion
490- void registerVal (Val* val) override ;
617+ virtual void registerVal (Val* val);
491618
492619 // ! Register expr with this fusion.
493620 // ! When we register an expression, we want to update the dependency tracking
494621 // ! of Vals. If this container is a not a Kernel, it will remove previous
495622 // ! definitions of outputs and register this Expr as the definition. Otherwise
496623 // ! will update definition if not previously set, but will not remove old
497624 // ! definitions.
498- void registerExpr (Expr* expr) override ;
625+ virtual void registerExpr (Expr* expr);
499626
500627 // ! Clear Expr's from TV uses that are not required to produce outputs from
501628 // ! inputs. Only other place this is used (other than Fusion) is in
@@ -539,22 +666,18 @@ class NVF_API Fusion : public IrContainer {
539666 std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr ;
540667
541668 inline static const std::string exact_mappings_key = " exact_mappings" ;
669+ std::unique_ptr<IrContainer> ir_container_;
542670};
543671
672+ // Template implementations for Fusion::manage<T>() that use IrCloner
544673template <typename T>
545674std::any defaultCloneFunction (IrCloner& cloner, std::any data) {
546675 auto cloned_data = cloner.clone (std::any_cast<T>(data));
547- // Adding a static_assert to improve error message. Without this
548- // static_assert, the following cast will still fail, but the error message
549- // will be unreadable.
550676 static_assert (
551677 std::is_convertible_v<decltype (cloned_data), T>,
552678 " IrCloner::clone returns a data type that is not compatible with the "
553679 " original managed data type. "
554680 " Likely you will need to check IrCloner::clone for your data type." );
555- // Convert the result of the clone back to T before assigning to std::any.
556- // This ensures the type of the std::any does not change over the clone of
557- // fusion.
558681 return std::any ((T)cloned_data);
559682}
560683
@@ -568,4 +691,41 @@ void Fusion::manage(std::string key, T data) {
568691 return manage (key, std::any (data), defaultCloneFunction<T>);
569692}
570693
694+ // Template implementations for IrBuilder that require Fusion to be fully
695+ // defined
696+ template <class T , class ... Args>
697+ T* IrBuilder::createInContainer (Fusion* container, Args&&... args) {
698+ NVF_ERROR (container != nullptr , " Need an active container to build IR." );
699+ T* node = new T (IrBuilderPasskey (container), std::forward<Args>(args)...);
700+ container->registerStmt (IrBuilderPasskey (container), node);
701+ return node;
702+ }
703+
704+ template <class T >
705+ T* IrBuilder::clone (const T* src, IrCloner* ir_cloner) {
706+ NVF_ERROR (
707+ ir_cloner != nullptr ,
708+ " Cannot use create when a cloner object is set. Use clone." );
709+ NVF_ERROR (
710+ ir_cloner->container () != nullptr ,
711+ " Cloner doesn't have a valid container to store cloned object." );
712+
713+ T* dest = new T (src, ir_cloner);
714+ const auto * src_stmt = dynamic_cast <const Statement*>(src);
715+ auto * dest_stmt = dynamic_cast <Statement*>(dest);
716+
717+ auto dest_container = ir_cloner->container ();
718+ auto src_container = src_stmt->container ();
719+
720+ dest_container->registerStmt (IrBuilderPasskey (dest_container), dest_stmt);
721+
722+ if (src_container != dest_container) {
723+ dest_stmt->setName (IrBuilderPasskey (dest_container), src_stmt->name ());
724+ }
725+
726+ ir_cloner->registerClone (src_stmt, dest_stmt);
727+
728+ return dest;
729+ }
730+
571731} // namespace nvfuser
0 commit comments