@@ -25,6 +25,7 @@ void LowerBlock(torch::jit::Block* b) {
25
25
}
26
26
27
27
void LowerGraph (std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
28
+ passes::MarkNodesForFallback (g, false );
28
29
passes::UnpackHardSwish (g);
29
30
torch::jit::EliminateRedundantGuards (g);
30
31
torch::jit::RemoveListMutation (g);
@@ -56,22 +57,27 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
56
57
passes::AliasOperators (g);
57
58
passes::SiluToSigmoidMultipication (g);
58
59
torch::jit::EliminateDeadCode (g);
60
+ passes::MarkNodesForFallback (g, true );
59
61
LOG_GRAPH (*g);
60
62
}
61
63
62
- torch::jit::Module LowerModule (const torch::jit::script::Module& mod) {
63
- LOG_DEBUG (" Input module is being frozen by torch::jit::freeze_module" );
64
+ torch::jit::Module LowerModule (const torch::jit::Module& mod, std::string method_name, std::unordered_set<std::string> forced_fallback_modules) {
65
+ passes::NotateModuleForFallback (mod, " " , method_name, forced_fallback_modules);
66
+ LOG_GRAPH (" After MLF notation pass: " << *mod.get_method (method_name).graph ());
64
67
auto mod_ = torch::jit::freeze_module (mod);
68
+ LOG_GRAPH (" After freeze: " << *mod_.get_method (method_name).graph ());
65
69
return mod_;
66
70
}
67
71
68
72
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower (
69
- const torch::jit::script::Module& mod,
70
- std::string method_name,
71
- LowerInfo lower_info) {
72
- auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule (mod);
73
+ const torch::jit::Module& mod,
74
+ std::string method_name, const LowerInfo& lower_info) {
75
+ LOG_DEBUG (lower_info);
76
+ LOG_GRAPH (" Before lowering: " << *mod.get_method (method_name).graph ());
77
+ std::unordered_set<std::string> forced_fallback_modules (
78
+ lower_info.forced_fallback_modules .begin (), lower_info.forced_fallback_modules .end ());
79
+ auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule (mod, method_name, forced_fallback_modules);
73
80
auto g = lowered_mod.get_method (method_name).graph ();
74
- LOG_GRAPH (*g);
75
81
76
82
LOG_GRAPH (" LibTorch Lowering" );
77
83
auto graph_and_ivalues = torch::jit::LowerGraph (*g, lowered_mod._ivalue ());
0 commit comments