
Commit f2ef0eb

refactor: convert RewriteInputsWithParams() to a lowering pass

RewriteInputsWithParams() moves out of core/compiler.cpp and becomes a lowering pass in core/lowering/passes, invoked at the end of LowerGraph(). LowerGraph() now receives the lowered graph's params vector so the pass can fold inputs that are not plain Tensor, Tuple, or List types into constants and erase the consumed IValues; the old call site in CompileGraph(), which ran the rewrite only when static_params was non-empty, is removed.

Signed-off-by: Bo Wang <[email protected]>
1 parent ab977f5 · commit f2ef0eb

File tree: 6 files changed, +47 −31 lines

  core/compiler.cpp
  core/lowering/lowering.cpp
  core/lowering/passes/BUILD
  core/lowering/passes/CMakeLists.txt
  core/lowering/passes/passes.h
  core/lowering/passes/rewrite_inputs_with_params.cpp (new)

core/compiler.cpp

Lines changed: 0 additions & 29 deletions

@@ -10,7 +10,6 @@
 #include "ATen/core/jit_type.h"
 
 #include "torch/csrc/jit/frontend/function_schema_parser.h"
-#include "torch/csrc/jit/ir/constants.h"
 #include "torch/csrc/jit/ir/ir.h"
 #include "torch/csrc/jit/ir/ir_views.h"
 #include "torch/csrc/jit/passes/graph_fuser.h"
@@ -29,31 +28,6 @@
 namespace torch_tensorrt {
 namespace core {
 
-void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph> g, std::vector<torch::jit::IValue> params) {
-  auto input_size = g->inputs().size();
-  auto param_it = params.rbegin();
-  for (int i = input_size - 1; i >= 0; --i) {
-    if (g->inputs()[i]->type() != c10::TensorType::get() &&
-        g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType &&
-        g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) {
-      auto val = *param_it;
-      if (val.isTensor()) {
-        at::Tensor val_tensor = val.toTensor();
-        if (val_tensor.requires_grad()) {
-          val_tensor.set_requires_grad(false);
-          val = val_tensor;
-        }
-      }
-      auto new_constant = torch::jit::tryInsertConstant(*g, val);
-      ++param_it;
-      if (new_constant) {
-        g->inputs()[i]->replaceAllUsesWith(*new_constant);
-        g->eraseInput(i);
-      }
-    }
-  }
-}
-
 void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
@@ -460,9 +434,6 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
        outputIsCollection)) {
-    if (!static_params.empty()) {
-      RewriteInputsWithParams(g, params);
-    }
     std::unordered_map<torch::jit::Node*, int> fallback_nodes;
     auto collection_input_ivalues_map =
         partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);

core/lowering/lowering.cpp

Lines changed: 3 additions & 2 deletions

@@ -25,7 +25,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
   torch::jit::RemoveTensorMutation(g);
@@ -66,6 +66,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);
 }
 
@@ -99,7 +100,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
   // In quantization aware trained (QAT) models, weights are passed through quantize and
   // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
   LOG_GRAPH("Torch-TensorRT.TorchScript Graph Lowering");
-  lowering::LowerGraph(graph_and_ivalues.first, lower_info);
+  lowering::LowerGraph(graph_and_ivalues.first, graph_and_ivalues.second, lower_info);
 
   // Is this necessary?
   // lowering::LowerBlock(g->block());
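The signature change above means callers now thread the params vector (produced when the module is lowered to a graph) into LowerGraph(), and the rewrite runs unconditionally during lowering rather than only in the partitioning path of CompileGraph(). A minimal sketch of the new call pattern, assuming a scripted, frozen torch::jit::Module named mod and a populated LowerInfo named lower_info (both names are illustrative, not from the commit):

    // Sketch: invoking the refactored LowerGraph() (names are assumptions).
    auto graph_and_ivalues =
        torch::jit::LowerGraph(*mod.get_method("forward").graph(), mod._ivalue());
    // The params vector rides along so RewriteInputsWithParams() can fold
    // eligible inputs into constants and drop the consumed IValues.
    torch_tensorrt::core::lowering::LowerGraph(
        graph_and_ivalues.first, graph_and_ivalues.second, lower_info);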

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ cc_library(
         "remove_dropout.cpp",
         "remove_nops.cpp",
         "remove_unnecessary_casts.cpp",
+        "rewrite_inputs_with_params.cpp",
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ target_sources(${lib_name}
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/rewrite_inputs_with_params.cpp"
 )
 
 set(HEADER_FILES

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
+void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params);
 
 } // namespace passes
 } // namespace lowering
core/lowering/passes/rewrite_inputs_with_params.cpp (new file)

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+#include "torch/csrc/jit/ir/constants.h"
+#include "core/util/prelude.h"
+
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+
+void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params) {
+  auto input_size = g->inputs().size();
+  auto param_it = params.rbegin();
+  for (int i = input_size - 1; i >= 0; --i) {
+    if (g->inputs()[i]->type() != c10::TensorType::get() &&
+        g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType &&
+        g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) {
+      auto val = *param_it;
+      if (val.isTensor()) {
+        at::Tensor val_tensor = val.toTensor();
+        if (val_tensor.requires_grad()) {
+          val_tensor.set_requires_grad(false);
+          val = val_tensor;
+        }
+      }
+      auto new_constant = torch::jit::tryInsertConstant(*g, val);
+      ++param_it;
+      if (new_constant) {
+        g->inputs()[i]->replaceAllUsesWith(*new_constant);
+        g->eraseInput(i);
+        // erase the consumed param; after ++param_it, param_it.base() points at it (safe for std::vector)
+        params.erase(param_it.base());
+      }
+    }
+  }
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
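To make the pass's effect concrete, here is a minimal sketch of a test harness (an assumption for illustration, not part of the commit): the pass walks the graph inputs from the back, and any input whose type is not plain Tensor, Tuple, or List is replaced with a prim::Constant built from the matching trailing entry of params, which is then erased.

    #include "torch/csrc/jit/ir/ir.h"
    #include "torch/csrc/jit/ir/irparser.h"
    #include "core/lowering/passes/passes.h"

    // Sketch: exercising RewriteInputsWithParams() on a toy graph.
    void RewriteInputsWithParamsExample() {
      auto g = std::make_shared<torch::jit::Graph>();
      // %alpha is int-typed, so it qualifies for rewriting; %x is a plain
      // Tensor input and is left untouched.
      torch::jit::parseIR(R"IR(
        graph(%x : Tensor, %alpha : int):
          %out : Tensor = aten::add(%x, %x, %alpha)
          return (%out))IR", g.get());

      // One trailing param lines up with %alpha.
      std::vector<torch::jit::IValue> params = {torch::jit::IValue(2)};
      torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(g, params);
      // Afterwards the graph has a single Tensor input, aten::add reads a
      // prim::Constant[value=2], and params is empty.
    }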
