Skip to content

Commit d4b7ad0

Browse files
committed
refactor: refactor SegmentedBlock inshape to ir::InputRange
Signed-off-by: Bo Wang <[email protected]>
1 parent d73dc42 commit d4b7ad0

File tree

5 files changed

+20
-21
lines changed

5 files changed

+20
-21
lines changed

core/compiler.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ void AddSegmentedBlockToGraph(
176176
return;
177177
}
178178

179+
179180
torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
180181
// TODO: Should be doing a functional transform but need PR #31978
181182
// [jit] More robust mangling
@@ -207,7 +208,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
207208
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
208209
std::vector<ir::InputRange> input_ranges;
209210
for (auto& shape : seg_block.in_shape()) {
210-
input_ranges.push_back(ir::InputRange(util::toVec(shape)));
211+
input_ranges.push_back(ir::InputRange(shape));
211212
}
212213
// update the input ranges for each segments
213214
convert_cfg.input_ranges = input_ranges;

core/partitioning/SegmentedBlock.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#include "NvInfer.h"
66
#include "torch/csrc/jit/ir/ir.h"
7-
7+
#include "core/ir/ir.h"
88
#include "core/partitioning/PartitionInfo.h"
99

1010
namespace trtorch {
@@ -61,10 +61,10 @@ struct SegmentedBlock {
6161
bool contain_raw_value(torch::jit::Value* input) {
6262
return old_to_new_.count(input);
6363
}
64-
void register_inshape(std::vector<nvinfer1::Dims>& in_shape) {
64+
void register_inshape(std::vector<ir::InputRange>& in_shape) {
6565
in_shape_ = in_shape;
6666
}
67-
const std::vector<nvinfer1::Dims>& in_shape() const {
67+
const std::vector<ir::InputRange>& in_shape() const {
6868
return in_shape_;
6969
}
7070
void update_target(SegmentedBlockTarget new_target) {
@@ -76,12 +76,11 @@ struct SegmentedBlock {
7676

7777
private:
7878
SegmentedBlockTarget target_;
79-
std::vector<nvinfer1::Dims> in_shape_; // REVIEW: This should just be ir::InputRange
79+
std::vector<ir::InputRange> in_shape_; // REVIEW: This should just be ir::InputRange
8080
std::vector<torch::jit::Value*> inputs_;
8181
std::vector<torch::jit::Value*> outputs_;
8282
std::vector<torch::jit::Node*> nodes_;
8383
std::shared_ptr<torch::jit::Graph> g_;
84-
std::string trt_engine;
8584
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
8685
};
8786

core/partitioning/shape_analysis.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ std::vector<torch::jit::IValue> generateRandomInputs(std::vector<ir::InputRange>
1919
return random_inputs;
2020
}
2121

22+
2223
void getSegmentsOutputByRunning(
2324
SegmentedBlock& seg_block,
2425
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
@@ -86,10 +87,10 @@ void getSegmentsOutputByRunning(
8687
}
8788

8889
// set input shape for each segmented block so we wil use it in conversion process
89-
std::vector<nvinfer1::Dims> input_shape;
90+
std::vector<ir::InputRange> input_shape;
9091
for (auto& i : seg_block.raw_inputs()) {
9192
if (ivalues_maps[i].isTensor()) {
92-
input_shape.push_back(util::toDims(ivalues_maps[i].toTensor().sizes()));
93+
input_shape.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
9394
}
9495
}
9596

tests/core/partitioning/test_fallback_graph_output.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
3131
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
3232
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
3333
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
34-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-8));
34+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
3535
}
3636

3737
TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
@@ -57,6 +57,7 @@ TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
5757
cfg.partition_info.enabled = true;
5858
cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");
5959

60+
6061
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
6162
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
6263
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();

tests/core/partitioning/test_shape_analysis.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
#include "core/partitioning/partitioning.h"
33
#include "core/util/trt_util.h"
44
#include "gtest/gtest.h"
5-
#include "torch/csrc/jit/ir/irparser.h"
65
#include "torch/script.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
78

89
bool checkSegmentedBlockInputShape(
910
std::vector<trtorch::core::partitioning::SegmentedBlock>& segmented_blocks,
@@ -15,7 +16,7 @@ bool checkSegmentedBlockInputShape(
1516
if (cur_block_in_shapes.size() != in_shape[i].size())
1617
return false;
1718
for (size_t j = 0; j < cur_block_in_shapes.size(); ++j) {
18-
auto cur_input_shape = trtorch::core::util::toVec(cur_block_in_shapes[j]);
19+
auto cur_input_shape = trtorch::core::util::toVec(cur_block_in_shapes[j].input_shape);
1920
for (size_t k = 0; k < cur_input_shape.size(); ++k) {
2021
if (cur_input_shape[k] != in_shape[i][j][k])
2122
return false;
@@ -61,11 +62,8 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) {
6162

6263
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
6364
trtorch::core::partitioning::Partition(g, input_ranges, partition_info);
64-
ASSERT_TRUE(checkSegmentedBlockInputShape(
65-
segmented_blocks,
66-
{{{3, 3, 16, 16}, {32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}},
67-
{{3, 16, 16, 16}},
68-
{{3, 16, 16, 16}, {8, 16, 3, 3}, {8}}}));
65+
ASSERT_TRUE(checkSegmentedBlockInputShape(segmented_blocks, {{{3, 3, 16, 16}, {32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}},
66+
{{3, 16, 16, 16}}, {{3, 16, 16, 16}, {8, 16, 3, 3}, {8}}}));
6967
}
7068

7169
TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
@@ -101,11 +99,10 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
10199
input_ranges.push_back(trtorch::core::ir::InputRange({16, 32, 3, 3}));
102100
input_ranges.push_back(trtorch::core::ir::InputRange({16}));
103101

104-
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
102+
103+
std::vector<trtorch::core::partitioning::SegmentedBlock> segmented_blocks =
105104
trtorch::core::partitioning::Partition(g, input_ranges, partition_info);
106105
ASSERT_TRUE(checkSegmentedBlockInputShape(
107-
segmented_blocks,
108-
{{{3, 3, 16, 16}, {32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}},
109-
{{3, 32, 16, 16}},
110-
{{3, 32, 16, 16}, {16, 32, 3, 3}, {16}, {3, 16, 16, 16}}}));
106+
segmented_blocks, {{{3, 3, 16, 16}, {32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}},
107+
{{3, 32, 16, 16}}, {{3, 32, 16, 16}, {16, 32, 3, 3}, {16}, {3, 16, 16, 16}}}));
111108
}

0 commit comments

Comments
 (0)