Skip to content

Commit 02b23cb

Browse files
committed
Add notation logic
Signed-off-by: Arvind Sridhar <[email protected]>
1 parent 744b417 commit 02b23cb

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

core/lowering/lowering.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ void LowerBlock(torch::jit::Block* b) {
2525
}
2626

2727
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
28+
passes::MarkNodesForFallback(g, false);
2829
passes::UnpackHardSwish(g);
2930
torch::jit::EliminateRedundantGuards(g);
3031
torch::jit::RemoveListMutation(g);
@@ -56,22 +57,27 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
5657
passes::AliasOperators(g);
5758
passes::SiluToSigmoidMultipication(g);
5859
torch::jit::EliminateDeadCode(g);
60+
passes::MarkNodesForFallback(g, true);
5961
LOG_GRAPH(*g);
6062
}
6163

62-
torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
63-
LOG_DEBUG("Input module is being frozen by torch::jit::freeze_module");
64+
torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, std::unordered_set<std::string> forced_fallback_modules) {
65+
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
66+
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
6467
auto mod_ = torch::jit::freeze_module(mod);
68+
LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph());
6569
return mod_;
6670
}
6771

6872
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
69-
const torch::jit::script::Module& mod,
70-
std::string method_name,
71-
LowerInfo lower_info) {
72-
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod);
73+
const torch::jit::Module& mod,
74+
std::string method_name, const LowerInfo& lower_info) {
75+
LOG_DEBUG(lower_info);
76+
LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph());
77+
std::unordered_set<std::string> forced_fallback_modules(
78+
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
79+
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules);
7380
auto g = lowered_mod.get_method(method_name).graph();
74-
LOG_GRAPH(*g);
7581

7682
LOG_GRAPH("LibTorch Lowering");
7783
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());

core/lowering/lowering.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@ struct LowerInfo {
1515
// Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
1616
// pass. Disable this in order to not disturb TensorRT's QAT optimizations.
1717
bool disable_cse = false;
18-
};
18+
std::vector<std::string> forced_fallback_modules;
19+
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
20+
}
1921

2022
void LowerBlock(torch::jit::Block* b);
2123
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
22-
torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
24+
torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, std::unordered_set<std::string> forced_fallback_modules);
2325
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
24-
const torch::jit::script::Module& mod,
25-
std::string method_name,
26-
LowerInfo lower_info);
26+
const torch::jit::Module& mod,
27+
std::string method_name, const LowerInfo& lower_info);
2728

2829
} // namespace lowering
2930
} // namespace core

0 commit comments

Comments
 (0)