Skip to content

Commit 074c814

Browse files
committed
Move statement registration/removal from IrContainer to Fusion
Inlines registerVal, registerExpr, removeVal, and removeExpr logic directly into Fusion, eliminating the delegation to IrContainer. This consolidates the registration path after per-Fusion special values were moved from IrContainer to Fusion. Also removes vestigial friend class StatementGuard from IrContainer (it only uses public Fusion API) and adds Fusion as a friend of IrContainerPasskey so it can construct passkeys for setName() calls.
1 parent 1588d37 commit 074c814

File tree

3 files changed

+32
-78
lines changed

3 files changed

+32
-78
lines changed

csrc/fusion.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,18 @@ void Fusion::removeExpr(Expr* expr) {
346346
}
347347
}
348348

349-
ir_container()->removeExpr(expr);
349+
auto* c = ir_container();
350+
auto expr_in_deque = std::find_if(
351+
c->exprs_up_.begin(),
352+
c->exprs_up_.end(),
353+
[expr](std::unique_ptr<Expr>& expr_up) {
354+
return expr_up.get() == expr;
355+
});
356+
NVF_ERROR(
357+
expr_in_deque != c->exprs_up_.end(),
358+
"Wanted to remove an expression but its unique ptr is missing.");
359+
c->exprs_.erase(expr);
360+
c->exprs_up_.erase(expr_in_deque);
350361
}
351362

352363
void Fusion::removeVal(Val* val) {
@@ -396,7 +407,17 @@ void Fusion::removeVal(Val* val) {
396407
for (auto e : exprs_to_remove) {
397408
removeExpr(e);
398409
}
399-
ir_container()->removeVal(val);
410+
411+
auto* c = ir_container();
412+
auto val_in_deque = std::find_if(
413+
c->vals_up_.begin(),
414+
c->vals_up_.end(),
415+
[val](std::unique_ptr<Val>& val_up) { return val_up.get() == val; });
416+
NVF_ERROR(
417+
val_in_deque != c->vals_up_.end(),
418+
"Wanted to remove a value but its unique ptr is missing.");
419+
c->vals_.erase(val);
420+
c->vals_up_.erase(val_in_deque);
400421

401422
invalidateTvsAndUses();
402423
}
@@ -910,7 +931,10 @@ void Fusion::registerVal(Val* val) {
910931
val->fusion() == this, val, " was not found in the active fusion.");
911932
}
912933

913-
ir_container()->registerVal(val);
934+
auto* c = ir_container();
935+
c->vals_up_.emplace_back(val);
936+
c->vals_.insert(val);
937+
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
914938
}
915939

916940
void Fusion::registerExpr(Expr* expr) {
@@ -923,7 +947,10 @@ void Fusion::registerExpr(Expr* expr) {
923947
expr->fusion() == this, expr, " was not found in the active fusion.");
924948
}
925949

926-
ir_container()->registerExpr(expr);
950+
auto* c = ir_container();
951+
c->exprs_up_.emplace_back(expr);
952+
c->exprs_.insert(expr);
953+
expr->setName(IrContainerPasskey(), c->getExprName());
927954

928955
for (Val* input : expr->inputs()) {
929956
assertInContainer(input, "Input to expr is invalid, ");

csrc/ir/container.cpp

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -115,66 +115,6 @@ IrContainer::~IrContainer() {
115115
clear();
116116
}
117117

118-
void IrContainer::removeExpr(Expr* expr) {
119-
NVF_ERROR(
120-
exprs_.find(expr) != exprs_.end(),
121-
"Wanted to remove an expression but it doesn't exist in this container.");
122-
auto expr_in_deque = std::find_if(
123-
exprs_up_.begin(),
124-
exprs_up_.end(),
125-
[expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; });
126-
127-
NVF_ERROR(
128-
expr_in_deque != exprs_up_.end(),
129-
"Wanted to remove an expression but its unique ptr is missing.");
130-
131-
exprs_.erase(expr);
132-
exprs_up_.erase(expr_in_deque);
133-
}
134-
135-
//! Completely remove val from the fusion, break all dependencies associated
136-
//! with it
137-
void IrContainer::removeVal(Val* val) {
138-
NVF_ERROR(
139-
vals_.find(val) != vals_.end(),
140-
"Wanted to remove a value but it doesn't exist in this container.");
141-
auto val_in_deque = std::find_if(
142-
vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr<Val>& val_up) {
143-
return val_up.get() == val;
144-
});
145-
146-
NVF_ERROR(
147-
val_in_deque != vals_up_.end(),
148-
"Wanted to remove a value but its unique ptr is missing.");
149-
150-
vals_.erase(val);
151-
vals_up_.erase(val_in_deque);
152-
}
153-
154-
//! Register the Val with this container
155-
void IrContainer::registerVal(Val* val) {
156-
if (inContainer(val)) {
157-
return;
158-
}
159-
160-
// Otherwise handle registration locally
161-
vals_up_.emplace_back(val);
162-
vals_.insert(val);
163-
val->setName(IrContainerPasskey(), getValName(val->vtype()));
164-
}
165-
166-
//! Register expr with this container.
167-
void IrContainer::registerExpr(Expr* expr) {
168-
if (inContainer(expr)) {
169-
return;
170-
}
171-
172-
// Otherwise handle registration locally
173-
exprs_up_.emplace_back(expr);
174-
exprs_.insert(expr);
175-
expr->setName(IrContainerPasskey(), getExprName());
176-
}
177-
178118
void IrContainer::clear() noexcept {
179119
FUSER_PERF_SCOPE("IrContainer clear");
180120
vals_.clear();

csrc/ir/container.h

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace nvfuser {
2121
// Passkey for container to register names with statements
2222
class IrContainerPasskey {
2323
friend class IrContainer;
24+
friend class Fusion;
2425

2526
private:
2627
explicit IrContainerPasskey() = default;
@@ -92,18 +93,6 @@ class IrContainer {
9293
// Let Fusion access IrContainer::clear()
9394
friend class Fusion;
9495

95-
void removeExpr(Expr* expr);
96-
97-
//! Completely remove val from the fusion, break all dependencies associated
98-
//! with it
99-
void removeVal(Val* val);
100-
101-
//! Register the Val with this container
102-
NVF_API void registerVal(Val* val);
103-
104-
//! Register expr with this container.
105-
NVF_API void registerExpr(Expr* expr);
106-
10796
StmtNameType getValName(ValType vtype) {
10897
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
10998
val_type_name_map_[vtype] = 0;
@@ -117,8 +106,6 @@ class IrContainer {
117106

118107
void clear() noexcept;
119108

120-
friend class StatementGuard;
121-
122109
// Deque of unique pointer is the memory owning data structure
123110
std::deque<std::unique_ptr<Val>> vals_up_;
124111

0 commit comments

Comments
 (0)