Skip to content

Commit 57f342d

Browse files
ArvindSridharnarendasan
authored andcommitted
Minor changes
Signed-off-by: Arvind Sridhar <[email protected]>
1 parent 2e04ce5 commit 57f342d

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

core/lowering/passes/module_fallback.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,21 @@ std::string unmangle_cls_name(const std::string& name) {
2121
if (mangle_pos != std::string::npos) {
2222
unmangled.erase(mangle_pos, 21);
2323
}
24+
2425
return unmangled;
2526
}
2627

27-
void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, const std::string& method_name, std::unordered_set<std::string> forced_fallback_modules) {
28+
void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, std::string method_name, std::unordered_set<std::string> forced_fallback_modules) {
2829
auto cls_name = unmangle_cls_name(mod.type()->name()->qualifiedName());
29-
auto g = mod.get_method(method_name).graph();
3030

31+
auto g = mod.get_method(method_name).graph();
3132
auto nodes = g->block()->nodes();
3233
bool changed_mod = false;
3334
for (const auto n : nodes) {
3435
if (n->kind() == torch::jit::prim::GetAttr) {
3536
auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
3637
if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
37-
LOG_DEBUG("Marking module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name << " (" << cls_name << ")]");
38+
LOG_DEBUG("Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name << " (" << cls_name << ")]");
3839
auto uses = n->output(0)->uses();
3940
for (const auto u : uses) {
4041
auto user = u.user;
@@ -52,7 +53,7 @@ void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name
5253
}
5354

5455
if (changed_mod) {
55-
LOG_DEBUG(*g);
56+
LOG_DEBUG("Notated graph: " << *g);
5657
}
5758

5859
for (const auto sub_mod : mod.named_children()) {
@@ -68,24 +69,24 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g) {
6869
auto n = *it;
6970
if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
7071
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
71-
LOG_DEBUG("Starting to mark new segmented targeted for torch");
72-
mark.push(true);
73-
it.destroyCurrent();
72+
LOG_DEBUG("Starting to mark new segmented block targeted for torch");
73+
mark.push(true);
74+
it.destroyCurrent();
7475
}
7576
} else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
76-
if(n->s(c10::Symbol::attr("compilation_edge")) == "start") {
77+
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
7778
LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
7879
mark.push(true);
7980
it.destroyCurrent();
8081
}
8182
} else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
82-
if(n->s(c10::Symbol::attr("compilation_edge")) == "end") {
83+
if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
8384
LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
8485
mark.pop();
8586
it.destroyCurrent();
8687
}
8788
} else if (!mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
88-
if(n->s(c10::Symbol::attr("compilation_edge")) == "end") {
89+
if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
8990
LOG_WARNING("Found the end of segmented block targeted for torch while not actively marking a block");
9091
}
9192
} else if (mark.top()) {
@@ -94,10 +95,10 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g) {
9495
}
9596
}
9697

97-
LOG_GRAPH("Post marking ops for pytorch execution: " << *g);
98+
LOG_DEBUG("After marking operations for torch fallback: " << *g);
9899
}
99100

100-
} // Namespace passes
101+
} // namespace passes
101102
} // namespace lowering
102103
} // namespace core
103-
} // namespace trtorch
104+
} // namespace trtorch

core/lowering/passes/passes.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ namespace core {
77
namespace lowering {
88
namespace passes {
99

10-
void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, const std::string& method_name, std::unordered_set<std::string> forced_fallback_modules);
11-
10+
void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, std::string method_name, std::unordered_set<std::string> forced_fallback_modules);
1211
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1312
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1413
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);

0 commit comments

Comments
 (0)