|
11 | 11 |
|
12 | 12 | #include "torch/csrc/jit/frontend/function_schema_parser.h"
|
13 | 13 | #include "torch/csrc/jit/ir/ir.h"
|
| 14 | +#include "torch/csrc/jit/ir/constants.h" |
14 | 15 | #include "torch/csrc/jit/ir/ir_views.h"
|
15 | 16 | #include "torch/csrc/jit/passes/graph_fuser.h"
|
16 | 17 | #include "torch/csrc/jit/passes/loop_unrolling.h"
|
|
28 | 29 | namespace torch_tensorrt {
|
29 | 30 | namespace core {
|
30 | 31 |
|
| 32 | +void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph> g, std::vector<torch::jit::IValue> params) { |
| 33 | + auto input_size = g->inputs().size(); |
| 34 | + auto param_it = params.rbegin(); |
| 35 | + for (int i = input_size - 1; i >= 0; --i) { |
| 36 | + if (g->inputs()[i]->type() != c10::TensorType::get() && g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType && |
| 37 | + g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) { |
| 38 | + auto new_constant = torch::jit::tryInsertConstant(*g, *param_it); |
| 39 | + ++param_it; |
| 40 | + if (new_constant) { |
| 41 | + g->inputs()[i]->replaceAllUsesWith(*new_constant); |
| 42 | + g->eraseInput(i); |
| 43 | + } |
| 44 | + } |
| 45 | + } |
| 46 | +} |
| 47 | + |
31 | 48 | void AddEngineToGraph(
|
32 | 49 | torch::jit::script::Module mod,
|
33 | 50 | std::shared_ptr<torch::jit::Graph>& g,
|
@@ -434,6 +451,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
|
434 | 451 | (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
|
435 | 452 | cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
|
436 | 453 | outputIsCollection)) {
|
| 454 | + if (!static_params.empty()) { |
| 455 | + RewriteInputsWithParams(g, params); |
| 456 | + } |
437 | 457 | std::unordered_map<torch::jit::Node*, int> fallback_nodes;
|
438 | 458 | auto collection_input_ivalues_map =
|
439 | 459 | partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
|
|
0 commit comments