
Commit c4b1ce5

Merge pull request #558 from NVIDIA/arvind/module_fallback
Module-Level Fallback
2 parents c759675 + a473bcf commit c4b1ce5

24 files changed (+501, -35 lines)

core/lowering/BUILD  (1 addition, 0 deletions)

@@ -13,6 +13,7 @@ cc_library(
         "drop_unused_nodes.cpp",
         "lowering.cpp",
         "register_trt_placeholder_ops.cpp",
+        "LowerInfo.cpp"
     ],
     hdrs = [
         "lowering.h",

core/lowering/LowerInfo.cpp  (new file, 23 additions)

@@ -0,0 +1,23 @@
+#include <iostream>
+#include <sstream>
+#include <utility>
+
+#include "core/lowering/lowering.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+
+std::ostream& operator<<(std::ostream& os, const LowerInfo& l) {
+  os << "Settings requested for Lowering:" << std::endl;
+  os << "  Forced Fallback Modules: [" << std::endl;
+  for (auto i : l.forced_fallback_modules) {
+    os << "    " << i << std::endl;
+  }
+  os << "  ]";
+  return os;
+}
+
+} // namespace lowering
+} // namespace core
+} // namespace trtorch

core/lowering/lowering.cpp  (21 additions, 12 deletions)

@@ -25,16 +25,21 @@ void LowerBlock(torch::jit::Block* b) {
 }
 
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
-  passes::UnpackHardSwish(g);
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
   torch::jit::RemoveTensorMutation(g);
   torch::jit::CreateFunctionalGraphs(g);
   torch::jit::InlineFunctionalGraphs(g);
   torch::jit::PeepholeOptimize(g, false);
-  passes::EliminateExceptionOrPassPattern(g);
   torch::jit::FuseLinear(g);
   torch::jit::LowerAllTuples(g);
+  if (!lower_info.disable_cse) {
+    torch::jit::EliminateCommonSubexpression(g);
+  }
+  torch::jit::EliminateDeadCode(g);
+  passes::MarkNodesForFallback(g, true);
+  passes::UnpackHardSwish(g);
+  passes::EliminateExceptionOrPassPattern(g);
   passes::ReduceToOperation(g);
   passes::RemoveContiguous(g);
   passes::RemoveDropout(g);
@@ -43,9 +48,6 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
   passes::RemoveBNDimCheck(g);
-  if (!lower_info.disable_cse) {
-    torch::jit::EliminateCommonSubexpression(g);
-  }
   // torch::jit::UnrollLoops(g);
   passes::UnpackAddMM(g);
   // passes::UnpackBatchNorm(g);
@@ -55,23 +57,30 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
   passes::SiluToSigmoidMultipication(g);
-  torch::jit::EliminateDeadCode(g);
   LOG_GRAPH(*g);
 }
 
-torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
-  LOG_DEBUG("Input module is being frozen by torch::jit::freeze_module");
+torch::jit::Module LowerModule(
+    const torch::jit::Module& mod,
+    std::string method_name,
+    std::unordered_set<std::string> forced_fallback_modules) {
+  passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
+  LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
   auto mod_ = torch::jit::freeze_module(mod);
+  LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph());
   return mod_;
 }
 
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
-    const torch::jit::script::Module& mod,
+    const torch::jit::Module& mod,
     std::string method_name,
-    LowerInfo lower_info) {
-  auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod);
+    const LowerInfo& lower_info) {
+  LOG_DEBUG(lower_info);
+  LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph());
+  std::unordered_set<std::string> forced_fallback_modules(
+      lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
+  auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules);
  auto g = lowered_mod.get_method(method_name).graph();
-  LOG_GRAPH(*g);
 
   LOG_GRAPH("LibTorch Lowering");
   auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
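To make the reworked entry point concrete, here is a hypothetical standalone driver (not part of this change) that loads a TorchScript module and runs the new Lower() path; the model path, method name, and fallback class name are all illustrative.

    #include "core/lowering/lowering.h"
    #include "torch/csrc/jit/serialization/import.h"

    // Sketch only: lowers the "forward" method of a saved TorchScript module with one
    // module class forced to run in PyTorch.
    void lower_with_module_fallback() {
      auto mod = torch::jit::load("model.ts"); // hypothetical path
      trtorch::core::lowering::LowerInfo lower_info;
      lower_info.forced_fallback_modules.push_back("torch.nn.modules.activation.ReLU"); // illustrative class name
      // Lower() notates the module, freezes it, runs the lowering passes (including
      // MarkNodesForFallback), and returns the lowered graph plus the associated IValues
      // (the frozen parameters).
      auto graph_and_params = trtorch::core::lowering::Lower(mod, "forward", lower_info);
      auto g = graph_and_params.first;
      auto params = graph_and_params.second;
      (void)g;
      (void)params;
    }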

core/lowering/lowering.h  (8 additions, 3 deletions)

@@ -15,15 +15,20 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  std::vector<std::string> forced_fallback_modules;
+  friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
 };
 
 void LowerBlock(torch::jit::Block* b);
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
-torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
+torch::jit::Module LowerModule(
+    const torch::jit::Module& mod,
+    std::string method_name,
+    std::unordered_set<std::string> forced_fallback_modules);
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
-    const torch::jit::script::Module& mod,
+    const torch::jit::Module& mod,
     std::string method_name,
-    LowerInfo lower_info);
+    const LowerInfo& lower_info);
 
 } // namespace lowering
 } // namespace core

core/lowering/passes/BUILD  (1 addition, 0 deletions)

@@ -15,6 +15,7 @@ cc_library(
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "linear_to_addmm.cpp",
+        "module_fallback.cpp",
         "op_aliasing.cpp",
         "reduce_to.cpp",
         "remove_bn_dim_check.cpp",
core/lowering/passes/module_fallback.cpp  (new file, 115 additions)

@@ -0,0 +1,115 @@
+#include <stack>
+#include <unordered_set>
+
+#include "core/lowering/passes/passes.h"
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+std::string unmangle_cls_name(const std::string& name) {
+  auto unmangled = name;
+
+  std::size_t torch_prefix = unmangled.find("__torch__");
+  if (torch_prefix != std::string::npos) {
+    unmangled.erase(torch_prefix, 10);
+  }
+
+  std::size_t mangle_pos = unmangled.find("___torch_mangle_");
+  if (mangle_pos != std::string::npos) {
+    unmangled.erase(mangle_pos, 21);
+  }
+
+  return unmangled;
+}
+
+void NotateModuleForFallback(
+    const torch::jit::Module& mod,
+    std::string mod_name,
+    std::string method_name,
+    std::unordered_set<std::string> forced_fallback_modules) {
+  auto cls_name = unmangle_cls_name(mod.type()->name()->qualifiedName());
+
+  auto g = mod.get_method(method_name).graph();
+  auto nodes = g->block()->nodes();
+  bool changed_mod = false;
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::GetAttr) {
+      auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
+      if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
+        LOG_DEBUG(
+            "Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name
+                                             << " (" << cls_name << ")]");
+        auto uses = n->output(0)->uses();
+        for (const auto u : uses) {
+          auto user = u.user;
+          auto delim_start_n = g->create(torch::jit::prim::Enter, 0);
+          delim_start_n->s_(c10::Symbol::attr("compilation_edge"), "start");
+          auto delim_end_n = g->create(torch::jit::prim::Exit, 0);
+          delim_end_n->s_(c10::Symbol::attr("compilation_edge"), "end");
+          delim_start_n->insertBefore(user);
+          delim_end_n->insertAfter(user);
+        }
+        changed_mod = true;
+      }
+    }
+  }
+
+  if (changed_mod) {
+    LOG_DEBUG("Notated graph: " << *g);
+  }
+
+  for (const auto sub_mod : mod.named_children()) {
+    NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules);
+  }
+}
+
+void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims) {
+  auto b = g->block();
+
+  std::stack<bool> mark = std::stack<bool>({false});
+  for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
+    auto n = *it;
+    if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
+      if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
+        LOG_DEBUG("Starting to mark new segmented block targeted for torch");
+        mark.push(true);
+        if (delete_delims) {
+          it.destroyCurrent();
+        }
+      }
+    } else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
+      if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
+        LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
+        mark.push(true);
+        if (delete_delims) {
+          it.destroyCurrent();
+        }
+      }
+    } else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
+      if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
+        LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
+        mark.pop();
+        if (delete_delims) {
+          it.destroyCurrent();
+        }
+      }
+    } else if (!mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
+      if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
+        LOG_WARNING("Found the end of segmented block targeted for torch while not actively marking a block");
+      }
+    } else if (mark.top()) {
+      LOG_GRAPH("Marking " << util::node_info(n) << " to run in PyTorch");
+      n->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
+    }
+  }
+
+  LOG_DEBUG("After marking operations for torch fallback: " << *g);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
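The two passes are meant to run on either side of freezing: NotateModuleForFallback brackets every use of a listed submodule with prim::Enter/prim::Exit delimiter nodes carrying a compilation_edge attribute, freezing then inlines the submodule's ops between those delimiters, and MarkNodesForFallback tags everything inside a start/end pair with to_compile = false, optionally deleting the delimiters. Below is a minimal sketch of that sequence, mirroring what LowerModule and LowerGraph do in this change; the method name and fallback class name are illustrative.

    #include <unordered_set>

    #include "core/lowering/passes/passes.h"
    #include "torch/csrc/jit/api/module.h"
    #include "torch/csrc/jit/passes/freeze_module.h"

    namespace lowering = trtorch::core::lowering;

    torch::jit::Module notate_freeze_and_mark(const torch::jit::Module& mod) {
      std::unordered_set<std::string> fallback_mods = {"torch.nn.modules.pooling.MaxPool2d"}; // illustrative class name
      // Bracket every use of the listed submodule classes with compilation_edge delimiters.
      lowering::passes::NotateModuleForFallback(mod, "", "forward", fallback_mods);
      // Freezing inlines the submodule calls, so the delimiters now surround the inlined ops.
      auto frozen = torch::jit::freeze_module(mod);
      auto g = frozen.get_method("forward").graph();
      // Tag everything between a start/end delimiter pair with to_compile = false and drop the delimiters.
      lowering::passes::MarkNodesForFallback(g, /*delete_delims=*/true);
      return frozen;
    }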

core/lowering/passes/passes.h  (6 additions, 0 deletions)

@@ -7,12 +7,18 @@ namespace core {
 namespace lowering {
 namespace passes {
 
+void NotateModuleForFallback(
+    const torch::jit::Module& mod,
+    std::string mod_name,
+    std::string method_name,
+    std::unordered_set<std::string> forced_fallback_modules);
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
+void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
 void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);

core/partitioning/partitioning.cpp  (3 additions, 1 deletion)

@@ -274,7 +274,9 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
   }
 
   std::string node_string(n->kind().toQualString());
-  if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
+  auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
+  if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) &&
+      (!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true)) {
     tensorrt_nodes.push_back(n);
     if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
       segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
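Restated as a standalone predicate, the updated segmentation rule is: a node goes to TensorRT only if conversion supports it, its operator is not in the forced-fallback operator list, and the lowering passes did not tag it with to_compile = false. The helper below is a sketch and does not exist in the codebase; it assumes the forced-fallback operator list is a set-like container, as the .count() call in the diff suggests, and that OpSupported is declared in core/conversion/conversion.h.

    #include <string>
    #include <unordered_set>

    #include "core/conversion/conversion.h"
    #include "torch/csrc/jit/ir/ir.h"

    // Sketch of the eligibility check applied to each node while segmenting the graph.
    bool eligible_for_tensorrt(torch::jit::Node* n, const std::unordered_set<std::string>& forced_fallback_operators) {
      std::string node_string(n->kind().toQualString());
      bool has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
      return trtorch::core::conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) &&
          (!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true);
    }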

cpp/api/include/trtorch/trtorch.h  (3 additions, 0 deletions)

@@ -577,6 +577,9 @@ struct TRTORCH_API CompileSpec {
   /// A list of names of operations that will explicitly run in PyTorch
   std::vector<std::string> forced_fallback_ops;
 
+  /// A list of names of modules that will explicitly run in PyTorch
+  std::vector<std::string> forced_fallback_modules;
+
   /**
    * @brief Construct a default Torch Fallback object, fallback will be off
   */

cpp/api/src/compile_spec.cpp  (1 addition, 0 deletions)

@@ -375,6 +375,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.partition_info.enabled = external.torch_fallback.enabled;
   internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
   internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;
+  internal.lower_info.forced_fallback_modules = external.torch_fallback.forced_fallback_modules;
 
   switch (external.device.device_type) {
     case CompileSpec::Device::DeviceType::kDLA:
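From the user's side, the feature is driven entirely through CompileSpec::torch_fallback. The sketch below shows the intended C++ usage; the class name pushed into forced_fallback_modules is illustrative, and the spec is assumed to have already been constructed with this release's input-spec API.

    #include "torch/script.h"
    #include "trtorch/trtorch.h"

    // Sketch: every submodule whose class matches an entry in forced_fallback_modules
    // is kept in PyTorch; everything else remains eligible for TensorRT conversion.
    torch::jit::Module compile_with_module_fallback(torch::jit::Module& mod, trtorch::CompileSpec spec) {
      spec.torch_fallback.enabled = true;
      spec.torch_fallback.min_block_size = 1;
      spec.torch_fallback.forced_fallback_modules.push_back("torchvision.models.resnet.BasicBlock"); // illustrative
      return trtorch::CompileGraph(mod, spec);
    }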
