Skip to content

Commit 9baa7bf

Browse files
committed
Migrate standalone Fusion/HostIrContainer to shared IrContainer
Four code paths created Fusions with fresh IrContainers (default constructor) instead of sharing the source Fusion's container. This broke the getCurFusion()-based traversal in iter_visitor.cpp which assumes all Vals are in the same shared container. Changes: - fusion_segmenter.cpp: Welford translation test copy now shares the source Fusion's IrContainer - host_ir/container.h: Add shared-container constructor to HostIrContainer (forwarding to Fusion's protected constructor) - communication_executor.cpp, host_ir/lowering.cpp, host_ir/lower.cpp: Use shared-container constructor for HostIrContainer creation
1 parent b8a8862 commit 9baa7bf

File tree

5 files changed

+12
-5
lines changed

5 files changed

+12
-5
lines changed

csrc/fusion_segmenter.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2803,8 +2803,10 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent(
28032803
[fusion](WelfordOp* welford) { return welford->fusion() == fusion; }),
28042804
"Welfords in given vector not in the same fusion");
28052805

2806-
// Make initial `in-progress copy`
2807-
auto test_copy = std::make_unique<Fusion>();
2806+
// Make initial `in-progress copy` — share the source IrContainer so that
2807+
// traversal via getCurFusion() sees all Vals in the same container.
2808+
auto test_copy =
2809+
std::unique_ptr<Fusion>(new Fusion(fusion->ir_container_ptr()));
28082810
auto original_to_test_map = Fusion::copy(fusion, test_copy.get());
28092811

28102812
std::vector<WelfordOp*> copied_welfords;

csrc/host_ir/container.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ namespace nvfuser::hir {
2323
class HostIrContainer final : public Fusion {
2424
public:
2525
HostIrContainer() = default;
26+
explicit HostIrContainer(std::shared_ptr<IrContainer> container)
27+
: Fusion(std::move(container)) {}
2628
HostIrContainer(const HostIrContainer&) = delete;
2729
HostIrContainer& operator=(const HostIrContainer&) = delete;
2830

csrc/host_ir/lower.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
118118
RuntimeWorkSpace workspace = prepareRuntimeOrder(*staged_fusion);
119119
// Create the HostIrContainer representing the host program. Each segment of
120120
// the segmented fusion will be translated to a HostIR
121-
auto hic = std::make_unique<hir::HostIrContainer>();
121+
auto hic = std::make_unique<hir::HostIrContainer>(
122+
staged_fusion->completeFusion()->ir_container_ptr());
122123
FusionGuard fg(hic.get());
123124
IrCloner ir_cloner(hic.get());
124125
auto clone =

csrc/host_ir/lowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,8 @@ std::unique_ptr<hir::HostIrContainer> lowerSegmentedFusionToHostIr(
396396
const SegmentedFusion& segmented_fusion,
397397
const std::vector<LaunchParams>& launch_params_per_segment,
398398
std::vector<std::unique_ptr<ExecutorAbstract>>& executors) {
399-
auto hic = std::make_unique<hir::HostIrContainer>();
399+
auto hic = std::make_unique<hir::HostIrContainer>(
400+
segmented_fusion.completeFusion()->ir_container_ptr());
400401
IrCloner ir_cloner =
401402
Fusion::copy(segmented_fusion.completeFusion(), hic.get());
402403

csrc/runtime/communication_executor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ void CommunicationExecutor::compile(Fusion* fusion) {
5454
FusionProfiler::segment(group_id_).startCompile();
5555
}
5656

57-
host_ir_container_ = std::make_unique<hir::HostIrContainer>();
57+
host_ir_container_ =
58+
std::make_unique<hir::HostIrContainer>(fusion->ir_container_ptr());
5859
IrCloner cloner = Fusion::copy(fusion, host_ir_container_.get());
5960
if (fusion->isA<hir::HostIrContainer>()) {
6061
for (Expr* e : fusion->as<hir::HostIrContainer>()->topLevelExprs()) {

0 commit comments

Comments
 (0)