Skip to content

Commit 318520d

Browse files
ArvindSridharnarendasan
authored andcommitted
Add second marking pass
Signed-off-by: Arvind Sridhar <[email protected]>
1 parent 57f342d commit 318520d

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

core/lowering/passes/module_fallback.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name
6161
}
6262
}
6363

64-
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g) {
64+
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims) {
6565
auto b = g->block();
6666

6767
std::stack<bool> mark = std::stack<bool>({false});
@@ -71,19 +71,25 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g) {
7171
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
7272
LOG_DEBUG("Starting to mark new segmented block targeted for torch");
7373
mark.push(true);
74-
it.destroyCurrent();
74+
if (delete_delims) {
75+
it.destroyCurrent();
76+
}
7577
}
7678
} else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
7779
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
7880
LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
7981
mark.push(true);
80-
it.destroyCurrent();
82+
if (delete_delims) {
83+
it.destroyCurrent();
84+
}
8185
}
8286
} else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
8387
if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
8488
LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
8589
mark.pop();
86-
it.destroyCurrent();
90+
if (delete_delims) {
91+
it.destroyCurrent();
92+
}
8793
}
8894
} else if (!mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
8995
if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {

core/lowering/passes/passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
1414
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
1515
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
1616
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
17-
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g);
17+
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
1818
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
1919
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
2020
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);

0 commit comments

Comments
 (0)