Skip to content

Commit 2918675

Browse files
authored
Merge pull request #619 from NVIDIA/fix_mod_fallback_methods
fix(module_fallback): Catching recursive search if method doesnt exist
2 parents 0e3532b + f94ae8f commit 2918675

File tree

2 files changed

+40
-18
lines changed

2 files changed

+40
-18
lines changed

core/lowering/lowering.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
3737
torch::jit::EliminateCommonSubexpression(g);
3838
}
3939
torch::jit::EliminateDeadCode(g);
40-
passes::MarkNodesForFallback(g, true);
40+
if (lower_info.forced_fallback_modules.size() > 0) {
41+
passes::MarkNodesForFallback(g, true);
42+
}
4143
passes::UnpackHardSwish(g);
4244
passes::EliminateExceptionOrPassPattern(g);
4345
passes::ReduceToOperation(g);
@@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
6062
LOG_GRAPH(*g);
6163
}
6264

63-
torch::jit::Module LowerModule(
64-
const torch::jit::Module& mod,
65-
std::string method_name,
66-
std::unordered_set<std::string> forced_fallback_modules) {
67-
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
68-
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
65+
torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
66+
std::unordered_set<std::string> forced_fallback_modules(
67+
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
68+
if (forced_fallback_modules.size() > 0) {
69+
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
70+
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
71+
}
6972
auto mod_ = torch::jit::freeze_module(mod);
7073
LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph());
7174
return mod_;
@@ -77,9 +80,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
7780
const LowerInfo& lower_info) {
7881
LOG_DEBUG(lower_info);
7982
LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph());
80-
std::unordered_set<std::string> forced_fallback_modules(
81-
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
82-
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules);
83+
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, lower_info);
8384
auto g = lowered_mod.get_method(method_name).graph();
8485

8586
LOG_GRAPH("LibTorch Lowering");

core/lowering/passes/module_fallback.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void NotateModuleForFallback(
3939
if (n->kind() == torch::jit::prim::GetAttr) {
4040
auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
4141
if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
42-
LOG_DEBUG(
42+
LOG_GRAPH(
4343
"Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name
4444
<< " (" << cls_name << ")]");
4545
auto uses = n->output(0)->uses();
@@ -58,11 +58,32 @@ void NotateModuleForFallback(
5858
}
5959

6060
if (changed_mod) {
61-
LOG_DEBUG("Notated graph: " << *g);
61+
LOG_GRAPH("Notated graph: " << *g);
6262
}
6363

64-
for (const auto sub_mod : mod.named_children()) {
65-
NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules);
64+
if (mod.named_children().size() > 0) {
65+
for (const auto n : nodes) {
66+
std::string sub_method_name = "";
67+
if (n->kind() == torch::jit::prim::CallMethod) {
68+
sub_method_name = n->s(c10::Symbol::attr("name"));
69+
auto sub_mod_val = n->input(0);
70+
auto sub_mod_src_n = sub_mod_val->node();
71+
if (!sub_mod_src_n->hasAttributeS("name")) {
72+
LOG_GRAPH("Node: " << util::node_info(sub_mod_src_n) << " manages a module with no name, skipping");
73+
break;
74+
}
75+
auto sub_mod_name = sub_mod_src_n->s(c10::Symbol::attr("name"));
76+
for (const auto sub_mod : mod.named_children()) {
77+
// Theres probably a way to directly access the module we care about
78+
if (sub_mod.name == sub_mod_name) {
79+
LOG_GRAPH(
80+
"Looking at <module>.<method>() next: " << sub_mod_name << "." << sub_method_name
81+
<< "() (lowering.passes.NotateModuleForFallback)");
82+
NotateModuleForFallback(sub_mod.value, sub_mod.name, sub_method_name, forced_fallback_modules);
83+
}
84+
}
85+
}
86+
}
6687
}
6788
}
6889

@@ -74,23 +95,23 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
7495
auto n = *it;
7596
if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
7697
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
77-
LOG_DEBUG("Starting to mark new segmented block targeted for torch");
98+
LOG_GRAPH("Starting to mark new segmented block targeted for torch");
7899
mark.push(true);
79100
if (delete_delims) {
80101
it.destroyCurrent();
81102
}
82103
}
83104
} else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
84105
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
85-
LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
106+
LOG_GRAPH("Found the start of another segmented block targeted for torch while actively marking a block");
86107
mark.push(true);
87108
if (delete_delims) {
88109
it.destroyCurrent();
89110
}
90111
}
91112
} else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
92113
if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
93-
LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
114+
LOG_GRAPH("Found the end of segmented block targeted for torch while actively marking a block");
94115
mark.pop();
95116
if (delete_delims) {
96117
it.destroyCurrent();
@@ -106,7 +127,7 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
106127
}
107128
}
108129

109-
LOG_DEBUG("After marking operations for torch fallback: " << *g);
130+
LOG_GRAPH("After marking operations for torch fallback: " << *g);
110131
}
111132

112133
} // namespace passes

0 commit comments

Comments
 (0)