File tree Expand file tree Collapse file tree 6 files changed +15
-6
lines changed
Expand file tree Collapse file tree 6 files changed +15
-6
lines changed Original file line number Diff line number Diff line change @@ -163,11 +163,13 @@ class NVF_API Fusion : public PolymorphicBase {
163163 return ir_container_.get ();
164164 }
165165
166+ public:
167+ // Public for Phase 3: callers that create derived Fusions (HostIrContainer,
168+ // segment copies, etc.) need the shared_ptr to share the IrContainer.
166169 std::shared_ptr<IrContainer> ir_container_ptr () const {
167170 return ir_container_;
168171 }
169172
170- public:
171173 // Registration (public API with passkey)
172174 virtual void registerStmt (IrBuilderPasskey, Statement* stmt) {
173175 if (stmt->isVal ()) {
Original file line number Diff line number Diff line change @@ -2803,8 +2803,10 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent(
28032803 [fusion](WelfordOp* welford) { return welford->fusion () == fusion; }),
28042804 " Welfords in given vector not in the same fusion" );
28052805
2806- // Make initial `in-progress copy`
2807- auto test_copy = std::make_unique<Fusion>();
2806+ // Make initial `in-progress copy` — share the source IrContainer so that
2807+ // traversal via getCurFusion() sees all Vals in the same container.
2808+ auto test_copy =
2809+ std::unique_ptr<Fusion>(new Fusion (fusion->ir_container_ptr ()));
28082810 auto original_to_test_map = Fusion::copy (fusion, test_copy.get ());
28092811
28102812 std::vector<WelfordOp*> copied_welfords;
Original file line number Diff line number Diff line change @@ -23,6 +23,8 @@ namespace nvfuser::hir {
2323class HostIrContainer final : public Fusion {
2424 public:
2525 HostIrContainer () = default ;
26+ explicit HostIrContainer (std::shared_ptr<IrContainer> container)
27+ : Fusion(std::move(container)) {}
2628 HostIrContainer (const HostIrContainer&) = delete ;
2729 HostIrContainer& operator =(const HostIrContainer&) = delete ;
2830
Original file line number Diff line number Diff line change @@ -118,7 +118,8 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
118118 RuntimeWorkSpace workspace = prepareRuntimeOrder (*staged_fusion);
119119 // Create the HostIrContainer representing the host program. Each segment of
120120 // the segmented fusion will be translated to a HostIR
121- auto hic = std::make_unique<hir::HostIrContainer>();
121+ auto hic = std::make_unique<hir::HostIrContainer>(
122+ staged_fusion->completeFusion ()->ir_container_ptr ());
122123 FusionGuard fg (hic.get ());
123124 IrCloner ir_cloner (hic.get ());
124125 auto clone =
Original file line number Diff line number Diff line change @@ -396,7 +396,8 @@ std::unique_ptr<hir::HostIrContainer> lowerSegmentedFusionToHostIr(
396396 const SegmentedFusion& segmented_fusion,
397397 const std::vector<LaunchParams>& launch_params_per_segment,
398398 std::vector<std::unique_ptr<ExecutorAbstract>>& executors) {
399- auto hic = std::make_unique<hir::HostIrContainer>();
399+ auto hic = std::make_unique<hir::HostIrContainer>(
400+ segmented_fusion.completeFusion ()->ir_container_ptr ());
400401 IrCloner ir_cloner =
401402 Fusion::copy (segmented_fusion.completeFusion (), hic.get ());
402403
Original file line number Diff line number Diff line change @@ -54,7 +54,8 @@ void CommunicationExecutor::compile(Fusion* fusion) {
5454 FusionProfiler::segment (group_id_).startCompile ();
5555 }
5656
57- host_ir_container_ = std::make_unique<hir::HostIrContainer>();
57+ host_ir_container_ =
58+ std::make_unique<hir::HostIrContainer>(fusion->ir_container_ptr ());
5859 IrCloner cloner = Fusion::copy (fusion, host_ir_container_.get ());
5960 if (fusion->isA <hir::HostIrContainer>()) {
6061 for (Expr* e : fusion->as <hir::HostIrContainer>()->topLevelExprs ()) {
You can’t perform that action at this time.
0 commit comments