Skip to content

Commit 9e8fb38

Browse files
committed
Merge branch 'batch_norm_alt' into pytorch_1.5.0
Closes: #31 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
2 parents 353f2d2 + fad4a10 commit 9e8fb38

File tree

3 files changed

+37
-28
lines changed

3 files changed

+37
-28
lines changed

core/compiler.cpp

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,47 +71,29 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
7171

7272
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
7373
std::string method_name) {
74-
auto g = mod.get_method(method_name).graph();
75-
// Go through PyTorch Lowering to simplify graph and extract weight parameters
76-
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
74+
// Go through Lowering to simplify graph and extract weight parameters
75+
auto graph_and_parameters = lowering::Lower(mod, method_name);
7776

78-
g = graph_and_parameters.first;
79-
80-
// Go through TRTorch Lowering to reformat graph to be conversion friendly
81-
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
82-
lowering::LowerGraph(g);
83-
84-
auto params = graph_and_parameters.second;
85-
auto named_params = conversion::get_named_params(g->inputs(), params);
77+
auto g = graph_and_parameters.first;
8678
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
8779

88-
// Is this necessary?
89-
lowering::LowerBlock(g->block());
90-
9180
return conversion::VerifyConverterSupportForBlock(g->block());
9281
}
9382

9483
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
9584
std::string method_name,
9685
ExtraInfo cfg) {
97-
auto convert_cfg = std::move(cfg.convert_info);
98-
99-
auto g = mod.get_method(method_name).graph();
100-
// Go through PyTorch Lowering to simplify graph and extract weight parameters
101-
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
102-
103-
g = graph_and_parameters.first;
10486

105-
// Go through TRTorch Lowering to reformat graph to be conversion friendly
106-
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
107-
lowering::LowerGraph(g);
87+
// Go through Lowering to simplify graph and extract weight parameters
88+
auto graph_and_parameters = lowering::Lower(mod, method_name);
10889

90+
auto convert_cfg = std::move(cfg.convert_info);
91+
auto g = graph_and_parameters.first;
10992
auto params = graph_and_parameters.second;
11093
auto named_params = conversion::get_named_params(g->inputs(), params);
94+
11195
LOG_INFO(*g << "(CompileGraph)\n");
11296

113-
// Is this necessary?
114-
lowering::LowerBlock(g->block());
11597
auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params);
11698
return std::move(engine);
11799
}

core/lowering/lowering.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
#include "torch/csrc/jit/passes/fuse_linear.h"
21
#include "torch/csrc/jit/passes/dead_code_elimination.h"
2+
#include "torch/csrc/jit/passes/fuse_linear.h"
3+
#include "torch/csrc/jit/passes/lower_graph.h"
4+
#include "torch/csrc/jit/passes/quantization.h"
35

46
#include "core/lowering/lowering.h"
57
#include "core/lowering/irfusers/irfusers.h"
@@ -22,7 +24,29 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
2224
//irfusers::UnpackBatchNorm(g);
2325
//torch::jit::EliminateDeadCode(g);
2426
}
25-
27+
28+
void LowerModule(const torch::jit::script::Module& mod) {
29+
torch::jit::FoldConvBatchNorm2d(mod);
30+
}
31+
32+
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
33+
std::string method_name) {
34+
LowerModule(mod);
35+
auto g = mod.get_method(method_name).graph();
36+
// Go through PyTorch Lowering to simplify graph and extract weight parameters
37+
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
38+
39+
g = graph_and_parameters.first;
40+
41+
// Go through TRTorch Lowering to reformat graph to be conversion friendly
42+
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
43+
lowering::LowerGraph(g);
44+
// Is this necessary?
45+
lowering::LowerBlock(g->block());
46+
return graph_and_parameters;
47+
}
48+
49+
2650
} // namespace lowering
2751
} // namespace core
2852
} // namespace trtorch

core/lowering/lowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ namespace lowering {
88

99
void LowerBlock(torch::jit::Block* b);
1010
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
11+
void LowerModule(const torch::jit::script::Module& mod);
12+
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
13+
std::string method_name);
1114

1215
} // namespace lowering
1316
} // namespace core

0 commit comments

Comments
 (0)