
Commit 20543c6

chore: optimize minor code problems according to PR
Signed-off-by: Bo Wang <[email protected]>
1 parent c67d8f6 commit 20543c6

8 files changed: +39 −31 lines changed


core/compiler.cpp
Lines changed: 6 additions & 5 deletions

@@ -17,7 +17,6 @@
 #include "torch/custom_class.h"

 #include "core/compiler.h"
-#include "core/util/prelude.h"

 #include "core/conversion/conversion.h"
 #include "core/lowering/lowering.h"
@@ -31,7 +30,8 @@ void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
     const std::string& serialized_engine,
-    int engine_id = 0) {
+    int engine_id = 0,
+    bool fallback = false) {
   auto engine_ptr =
       c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + std::to_string(engine_id), serialized_engine);
   // Get required metadata about the engine out
@@ -96,7 +96,7 @@ void AddEngineToGraph(

   // If there are multiple output tensors from TensorRT we wrap them in a tuple
   // to return, convert to tuple only when we only have 1 segmented graph
-  if (!engine_id && unpack_node->outputs().size() > 1) {
+  if (!fallback && unpack_node->outputs().size() > 1) {
     // Creates prim::TupleConstruct(<output tensors>) using outputs of the
     // unpack node
     auto return_tuple_node = g->createTuple(unpack_node->outputs());
@@ -196,10 +196,11 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
   // segment the graph and convert segmented TensorRT block
   auto segmented_blocks = partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
   if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
+    LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
     return mod;
   }

-  int trt_engine_id = 1;
+  int trt_engine_id = 0;
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
   // add global graph's input to old_to_new_g mapping
   for (auto input : g->inputs()) {
@@ -216,7 +217,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
     convert_cfg.input_ranges = input_ranges;
     auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
     auto temp_g = std::make_shared<torch::jit::Graph>();
-    AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);
+    AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++, true);

     seg_block.update_graph(temp_g);
     AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
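Note on this file: AddEngineToGraph no longer reuses engine_id to decide whether it is building a fallback sub-graph; a dedicated fallback flag carries that intent, which is also what lets trt_engine_id start at 0. The snippet below is a minimal standalone sketch of that control flow, not the project's code; add_engine and its printed messages are hypothetical stand-ins.

// Standalone sketch: a separate boolean flag instead of an overloaded integer ID.
// With the old `!engine_id` test, a fallback engine with id 0 would have been
// treated like the whole-graph path; the explicit flag removes that coupling.
#include <cstdio>

// Hypothetical stand-in for AddEngineToGraph's tuple-wrapping decision.
void add_engine(int engine_id, bool fallback) {
  if (!fallback) {
    std::printf("engine %d: wrapping outputs in a tuple\n", engine_id);
  } else {
    std::printf("engine %d: leaving outputs unpacked for the fallback graph\n", engine_id);
  }
}

int main() {
  add_engine(0, /*fallback=*/false); // whole-graph compilation path
  int trt_engine_id = 0;             // fallback path: IDs now start at 0
  add_engine(trt_engine_id++, /*fallback=*/true);
  add_engine(trt_engine_id++, /*fallback=*/true);
}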

core/partitioning/SegmentedBlock.cpp
Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
   if (node->kind() == torch::jit::prim::Constant) {
     auto new_const = g_->createClone(node, {nullptr});
     g_->block()->prependNode(new_const);
+    old_to_new_[old_value] = new_const->output();
     return new_const->output();
   }
   auto new_value = g_->block()->addInput();
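The added line caches the cloned constant in the old_to_new_ map, so a later lookup for the same value reuses the existing clone instead of prepending a duplicate node. A tiny self-contained sketch of the same memoize-on-clone pattern, with placeholder types (std::string instead of torch::jit::Value*):

// Sketch only: record the clone in the old-to-new map the first time, reuse it after.
#include <iostream>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, std::string> old_to_new;

const std::string& get_or_clone(const std::string& old_value) {
  auto it = old_to_new.find(old_value);
  if (it != old_to_new.end()) {
    return it->second;                    // already cloned: reuse it
  }
  std::string cloned = "clone_of_" + old_value; // stand-in for g_->createClone(...)
  return old_to_new[old_value] = cloned;        // remember it, as the new line does
}

int main() {
  std::cout << get_or_clone("const_a") << "\n";
  std::cout << get_or_clone("const_a") << "\n"; // same clone, no duplicate
}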

core/partitioning/SegmentedBlock.h
Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ struct SegmentedBlock {

 private:
  SegmentedBlockTarget target_;
-  std::vector<ir::InputRange> in_shape_; // REVIEW: This should just be ir::InputRange
+  std::vector<ir::InputRange> in_shape_;
  std::vector<torch::jit::Value*> inputs_;
  std::vector<torch::jit::Value*> outputs_;
  std::vector<torch::jit::Node*> nodes_;

core/partitioning/partitioning.cpp
Lines changed: 14 additions & 8 deletions

@@ -4,6 +4,7 @@
 #include "core/conversion/conversion.h"
 #include "core/partitioning/shape_analysis.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"

 namespace trtorch {
 namespace core {
@@ -203,13 +204,17 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
       }
     }
   }
-  // erase segments which still have no output
-  segmented_blocks.erase(
-      std::remove_if(
-          segmented_blocks.begin(),
-          segmented_blocks.end(),
-          [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
-      segmented_blocks.end());
+  std::for_each(
+      segmented_blocks.begin(),
+      segmented_blocks.end(),
+      [](SegmentedBlock& seg_block) { torch::jit::EliminateDeadCode(seg_block.g()); });
+  // erase segments which still have no output
+  segmented_blocks.erase(
+      std::remove_if(
+          segmented_blocks.begin(),
+          segmented_blocks.end(),
+          [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
+      segmented_blocks.end());

   return;
 }
@@ -225,8 +230,9 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
   // segment the nodes
   std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
   for (const auto n : nodes) {
-    if (n->kind() == torch::jit::prim::Constant)
+    if (n->kind() == torch::jit::prim::Constant) {
       continue;
+    }

     std::string node_string(n->kind().toQualString());
     if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
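Two standard-library idioms carry the middle hunk: std::for_each runs torch::jit::EliminateDeadCode over every segmented block's graph, and the erase-remove idiom then drops blocks left with no outputs. Below is a small self-contained sketch of both, using a placeholder Block type rather than the project's SegmentedBlock.

// Sketch of std::for_each plus the erase-remove idiom on a vector.
#include <algorithm>
#include <iostream>
#include <vector>

struct Block {
  int outputs;
};

int main() {
  std::vector<Block> blocks{{2}, {0}, {1}};

  // Analogue of running a dead-code-elimination pass on each block's graph.
  std::for_each(blocks.begin(), blocks.end(), [](Block& b) { /* cleanup pass on b */ });

  // Erase-remove idiom: remove_if shifts the kept elements to the front and
  // returns the new logical end; erase then drops the tail.
  blocks.erase(
      std::remove_if(blocks.begin(), blocks.end(), [](const Block& b) { return b.outputs == 0; }),
      blocks.end());

  std::cout << blocks.size() << " blocks kept\n"; // prints: 2 blocks kept
}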

core/partitioning/shape_analysis.cpp
Lines changed: 3 additions & 2 deletions

@@ -61,9 +61,10 @@ void getSegmentsOutputByRunning(
       jit_inputs_ivalues.push_back(ivalues_maps[input].toBool());
     } else if (input->type()->kind() == torch::jit::TypeKind::ListType) {
       jit_inputs_ivalues.push_back(ivalues_maps[input].toList());
-    } else {
-      TRTORCH_CHECK(input->type()->kind() == torch::jit::TypeKind::TupleType, "Input for mini graph is not TupleType.");
+    } else if (input->type()->kind() == torch::jit::TypeKind::TupleType) {
       jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple());
+    } else {
+      TRTORCH_THROW_ERROR("Unable to find type for value: " << input->debugName() << " to get the ivalues.\n");
     }
   }

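The rewritten branch ladder gives TupleType its own explicit case and reserves the final else for a hard error, rather than asserting that whatever is left must be a tuple. A hedged sketch of that control-flow shape, with an illustrative enum (not the real torch::jit::TypeKind values or IValue accessors):

// Sketch: one branch per handled kind, throw on anything unrecognized.
#include <iostream>
#include <stdexcept>
#include <string>

enum class TypeKind { Bool, List, Tuple, Unknown };

std::string pick_ivalue(TypeKind kind) {
  if (kind == TypeKind::Bool) {
    return "toBool()";
  } else if (kind == TypeKind::List) {
    return "toList()";
  } else if (kind == TypeKind::Tuple) {
    return "toTuple()"; // now an explicit branch rather than the fallback
  } else {
    throw std::runtime_error("Unable to find type for value"); // clear error instead of a misleading check
  }
}

int main() {
  std::cout << pick_ivalue(TypeKind::Tuple) << "\n";
  try {
    pick_ivalue(TypeKind::Unknown);
  } catch (const std::exception& e) {
    std::cout << "error: " << e.what() << "\n";
  }
}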

core/util/jit_util.h
Lines changed: 14 additions & 0 deletions

@@ -33,6 +33,20 @@ inline std::vector<int64_t> toVec(c10::IntArrayRef a) {
   return arr;
 }

+inline c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
+  std::vector<c10::Argument> args;
+  for (auto in : g->inputs()) {
+    args.push_back(c10::Argument(in->debugName(), in->type()));
+  }
+
+  std::vector<c10::Argument> returns;
+  for (auto out : g->outputs()) {
+    returns.push_back(c10::Argument(out->debugName(), out->type()));
+  }
+
+  return c10::FunctionSchema(method_name, method_name, args, returns);
+}
+
 } // namespace util
 } // namespace core
 } // namespace trtorch
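GenerateGraphSchema now lives in jit_util.h as an inline helper that derives a c10::FunctionSchema from a graph's inputs and outputs (it previously sat in trt_util.cpp, removed below). The following is a hedged usage sketch, assuming a TorchScript build environment; the include paths and the tiny identity graph are illustrative, and the helper's trtorch::core::util namespace is taken from the hunk above.

// Usage sketch (assumed build setup): build a one-input graph, then ask the
// helper for a schema whose arguments/returns mirror the graph's inputs/outputs.
#include <iostream>
#include <memory>

#include "torch/csrc/jit/ir/ir.h"
#include "core/util/jit_util.h"

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  auto in = g->addInput("x");
  in->setType(c10::TensorType::get());
  g->registerOutput(in); // identity graph: returns its input

  auto schema = trtorch::core::util::GenerateGraphSchema("forward", g);
  std::cout << schema.name() << " takes " << schema.arguments().size()
            << " argument(s) and returns " << schema.returns().size() << " value(s)\n";
}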

core/util/trt_util.cpp
Lines changed: 0 additions & 14 deletions

@@ -350,20 +350,6 @@ c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype) {
   }
 }

-c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
-  std::vector<c10::Argument> args;
-  for (auto in : g->inputs()) {
-    args.push_back(c10::Argument(in->debugName(), in->type()));
-  }
-
-  std::vector<c10::Argument> returns;
-  for (auto out : g->outputs()) {
-    returns.push_back(c10::Argument(out->debugName(), out->type()));
-  }
-
-  return c10::FunctionSchema(method_name, method_name, args, returns);
-}
-
 torch::jit::Value* getOrAddInputForValue(
     torch::jit::Value* old_value,
     std::shared_ptr<torch::jit::Graph>& graph,

core/util/trt_util.h
Lines changed: 0 additions & 1 deletion

@@ -109,7 +109,6 @@ std::string toStr(nvinfer1::Dims d);
 at::ScalarType toATenDType(nvinfer1::DataType t);
 nvinfer1::DataType toTRTDataType(at::ScalarType t);
 c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype);
-c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g);
 torch::jit::Value* getOrAddInputForValue(
     torch::jit::Value* old_value,
     std::shared_ptr<torch::jit::Graph>& graph,

0 commit comments
