@@ -21,20 +21,21 @@ std::string unmangle_cls_name(const std::string& name) {
21
21
if (mangle_pos != std::string::npos) {
22
22
unmangled.erase (mangle_pos, 21 );
23
23
}
24
+
24
25
return unmangled;
25
26
}
26
27
27
- void NotateModuleForFallback (const torch::jit::Module& mod, std::string mod_name, const std::string& method_name, std::unordered_set<std::string> forced_fallback_modules) {
28
+ void NotateModuleForFallback (const torch::jit::Module& mod, std::string mod_name, std::string method_name, std::unordered_set<std::string> forced_fallback_modules) {
28
29
auto cls_name = unmangle_cls_name (mod.type ()->name ()->qualifiedName ());
29
- auto g = mod.get_method (method_name).graph ();
30
30
31
+ auto g = mod.get_method (method_name).graph ();
31
32
auto nodes = g->block ()->nodes ();
32
33
bool changed_mod = false ;
33
34
for (const auto n : nodes) {
34
35
if (n->kind () == torch::jit::prim::GetAttr) {
35
36
auto out_type = unmangle_cls_name (c10::toString (n->output (0 )->type ()));
36
37
if (forced_fallback_modules.find (out_type) != forced_fallback_modules.end ()) {
37
- LOG_DEBUG (" Marking module for fallback: " << n->s (c10::attr::name) << " (" << out_type << " ) [owner: " << mod_name << " (" << cls_name << " )]" );
38
+ LOG_DEBUG (" Notating module for fallback: " << n->s (c10::attr::name) << " (" << out_type << " ) [owner: " << mod_name << " (" << cls_name << " )]" );
38
39
auto uses = n->output (0 )->uses ();
39
40
for (const auto u : uses) {
40
41
auto user = u.user ;
@@ -52,7 +53,7 @@ void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name
52
53
}
53
54
54
55
if (changed_mod) {
55
- LOG_DEBUG (*g);
56
+ LOG_DEBUG (" Notated graph: " << *g);
56
57
}
57
58
58
59
for (const auto sub_mod : mod.named_children ()) {
@@ -68,24 +69,24 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g) {
68
69
auto n = *it;
69
70
if (!mark.top () && n->kind () == torch::jit::prim::Enter && n->hasAttributeS (" compilation_edge" )) {
70
71
if (n->s (c10::Symbol::attr (" compilation_edge" )) == " start" ) {
71
- LOG_DEBUG (" Starting to mark new segmented targeted for torch" );
72
- mark.push (true );
73
- it.destroyCurrent ();
72
+ LOG_DEBUG (" Starting to mark new segmented block targeted for torch" );
73
+ mark.push (true );
74
+ it.destroyCurrent ();
74
75
}
75
76
} else if (mark.top () && n->kind () == torch::jit::prim::Enter && n->hasAttributeS (" compilation_edge" )) {
76
- if (n->s (c10::Symbol::attr (" compilation_edge" )) == " start" ) {
77
+ if (n->s (c10::Symbol::attr (" compilation_edge" )) == " start" ) {
77
78
LOG_DEBUG (" Found the start of another segmented block targeted for torch while actively marking a block" );
78
79
mark.push (true );
79
80
it.destroyCurrent ();
80
81
}
81
82
} else if (mark.top () && n->kind () == torch::jit::prim::Exit && n->hasAttributeS (" compilation_edge" )) {
82
- if (n->s (c10::Symbol::attr (" compilation_edge" )) == " end" ) {
83
+ if (n->s (c10::Symbol::attr (" compilation_edge" )) == " end" ) {
83
84
LOG_DEBUG (" Found the end of segmented block targeted for torch while actively marking a block" );
84
85
mark.pop ();
85
86
it.destroyCurrent ();
86
87
}
87
88
} else if (!mark.top () && n->kind () == torch::jit::prim::Exit && n->hasAttributeS (" compilation_edge" )) {
88
- if (n->s (c10::Symbol::attr (" compilation_edge" )) == " end" ) {
89
+ if (n->s (c10::Symbol::attr (" compilation_edge" )) == " end" ) {
89
90
LOG_WARNING (" Found the end of segmented block targeted for torch while not actively marking a block" );
90
91
}
91
92
} else if (mark.top ()) {
@@ -94,10 +95,10 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g) {
94
95
}
95
96
}
96
97
97
- LOG_GRAPH ( " Post marking ops for pytorch execution : " << *g);
98
+ LOG_DEBUG ( " After marking operations for torch fallback : " << *g);
98
99
}
99
100
100
- } // Namespace passes
101
+ } // namespace passes
101
102
} // namespace lowering
102
103
} // namespace core
103
- } // namespace trtorch
104
+ } // namespace trtorch
0 commit comments