Skip to content

Commit 05bf80c

Browse files
committed
feat: rewriting param to a Constant if it's a introduced input
Signed-off-by: Bo Wang <[email protected]>
1 parent 9bb0087 commit 05bf80c

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

core/compiler.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "torch/csrc/jit/frontend/function_schema_parser.h"
1313
#include "torch/csrc/jit/ir/ir.h"
14+
#include "torch/csrc/jit/ir/constants.h"
1415
#include "torch/csrc/jit/ir/ir_views.h"
1516
#include "torch/csrc/jit/passes/graph_fuser.h"
1617
#include "torch/csrc/jit/passes/loop_unrolling.h"
@@ -28,6 +29,22 @@
2829
namespace torch_tensorrt {
2930
namespace core {
3031

32+
void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph> g, std::vector<torch::jit::IValue> params) {
33+
auto input_size = g->inputs().size();
34+
auto param_it = params.rbegin();
35+
for (int i = input_size - 1; i >= 0; --i) {
36+
if (g->inputs()[i]->type() != c10::TensorType::get() && g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType &&
37+
g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) {
38+
auto new_constant = torch::jit::tryInsertConstant(*g, *param_it);
39+
++param_it;
40+
if (new_constant) {
41+
g->inputs()[i]->replaceAllUsesWith(*new_constant);
42+
g->eraseInput(i);
43+
}
44+
}
45+
}
46+
}
47+
3148
void AddEngineToGraph(
3249
torch::jit::script::Module mod,
3350
std::shared_ptr<torch::jit::Graph>& g,
@@ -434,6 +451,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
434451
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
435452
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
436453
outputIsCollection)) {
454+
if (!static_params.empty()) {
455+
RewriteInputsWithParams(g, params);
456+
}
437457
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
438458
auto collection_input_ivalues_map =
439459
partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);

0 commit comments

Comments
 (0)