Skip to content

Commit c1934c1

Browse files
committed
chore: improve some minor code problems
Signed-off-by: Bo Wang <[email protected]>
1 parent f722035 commit c1934c1

File tree

4 files changed

+21
-4
lines changed

4 files changed

+21
-4
lines changed

core/compiler.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
201201

202202
int trt_engine_id = 1;
203203
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
204+
// add global graph's input to old_to_new_g mapping
205+
for (auto input : g->inputs()) {
206+
util::getOrAddInputForValue(input, new_g, old_to_new_g);
207+
}
204208
for (auto& seg_block : segmented_blocks) {
205209
LOG_INFO(*g << "(MiniGraphInSegmentedBlock)\n");
206210
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {

core/util/trt_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ at::ScalarType toATenDType(nvinfer1::DataType t);
110110
nvinfer1::DataType toTRTDataType(at::ScalarType t);
111111
c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype);
112112
c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g);
113+
torch::jit::Value* getOrAddInputForValue(
114+
torch::jit::Value* old_value,
115+
std::shared_ptr<torch::jit::Graph>& graph,
116+
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);
113117
torch::jit::Node* cloneNode(
114118
torch::jit::Node* node,
115119
std::shared_ptr<torch::jit::Graph>& graph,

tests/core/partitioning/BUILD

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,20 @@ partitioning_test(
2929
name = "test_stitched_graph",
3030
)
3131

32-
partitioning_test(
32+
cc_test(
3333
name = "test_fallback_graph_output",
34+
srcs = ["test_fallback_graph_output.cpp"],
35+
deps = [
36+
"//tests/util",
37+
"//core",
38+
"@googletest//:gtest_main",
39+
] + select({
40+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
41+
"//conditions:default": ["@libtorch//:libtorch"],
42+
}),
43+
data = [
44+
":jit_models"
45+
]
3446
)
3547

3648
test_suite(

tests/core/partitioning/partitioning_test.bzl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,5 @@ def partitioning_test(name, visibility=None):
1111
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
1212
"//conditions:default": ["@libtorch//:libtorch"],
1313
}),
14-
data = [
15-
":jit_models"
16-
],
1714
timeout="short"
1815
)

0 commit comments

Comments
 (0)