Skip to content

Commit 375bdfc

Browse files
committed
chore: Clean up and refactor code
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 86982e1 commit 375bdfc

File tree

8 files changed

+95
-48
lines changed

8 files changed

+95
-48
lines changed

core/compiler.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ partitioning::GraphAndMapping BuildHybridGraph(
137137
auto partitioning_info = cfg.partitioning_info;
138138

139139
auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
140-
// auto collection_input_ivalues_map =
141-
// partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
142140
partitioning_ctx.input_types_map = first_use_types;
143141
partitioning::partition(&partitioning_ctx);
144142

@@ -151,16 +149,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
151149
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
152150

153151
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
154-
// auto shapes = seg_block.in_shapes();
155-
// auto types = seg_block.in_types();
156-
// std::vector<ir::Input> inputs;
157-
// for (size_t i = 0; i < shapes.size(); i++) {
158-
// auto in = ir::Input(shapes[i]);
159-
// in.dtype = util::ScalarTypeToTRTDataType(types[i]);
160-
// inputs.push_back(in);
161-
// }
162152
auto inputs = seg_block.construct_inputs_spec();
163-
LOG_DEBUG("============ INPUTS: " << inputs);
164153
// update the input ranges for each segments
165154
convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
166155

core/partitioning/partitioning.cpp

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,20 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
436436
return;
437437
}
438438

439+
bool isInputDynamic(PartitioningCtx* ctx) {
440+
// Check if inputs have dynamic shapes
441+
bool input_is_dynamic = true;
442+
auto inputs_map = ctx->settings.collection_input_spec_map;
443+
for (auto inputs : inputs_map) {
444+
for (auto input : inputs.second) {
445+
if (!input.input_is_dynamic) {
446+
input_is_dynamic = false;
447+
}
448+
}
449+
}
450+
return input_is_dynamic;
451+
}
452+
439453
void partition(PartitioningCtx* ctx) {
440454
LOG_DEBUG(ctx->settings);
441455

@@ -446,24 +460,33 @@ void partition(PartitioningCtx* ctx) {
446460

447461
// It's possible that some TensorRT blocks have nonTensor inputs/outputs because they are interleaved by Torch blocks
448462
// resolve nonTensor inputs/outputs
463+
LOG_DEBUG("Resolving non-tensor inputs for segmented blocks");
449464
resolveTRTNonTensorInputs(ctx, block);
450465

451466
// register input/output torch::jit::Value for segmented graphs
452467
LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
453468
registerSegmentsOutputs(ctx, block);
454469

455-
// run shape analysis on each segmented block
456-
auto min_input_ivalues_map =
457-
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "min");
458-
auto opt_input_ivalues_map =
459-
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "opt");
460-
auto max_input_ivalues_map =
461-
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "max");
462-
463-
runShapeAnalysis(ctx, block, min_input_ivalues_map, "min");
464-
runShapeAnalysis(ctx, block, opt_input_ivalues_map, "opt");
465-
runShapeAnalysis(ctx, block, max_input_ivalues_map, "max");
466-
470+
// In case of dynamic shape inputs, run shape analysis on each segmented block for min/opt/max ranges and register
471+
// output shapes for each block accordingly
472+
if (isInputDynamic(ctx)) {
473+
LOG_DEBUG("Performing shape analysis for segmented blocks using min/opt/max shapes for inputs");
474+
auto min_input_ivalues_map =
475+
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "min");
476+
auto opt_input_ivalues_map =
477+
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "opt");
478+
auto max_input_ivalues_map =
479+
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "max");
480+
481+
runShapeAnalysis(ctx, block, min_input_ivalues_map, "min");
482+
runShapeAnalysis(ctx, block, opt_input_ivalues_map, "opt");
483+
runShapeAnalysis(ctx, block, max_input_ivalues_map, "max");
484+
} else {
485+
LOG_DEBUG("Performing shape analysis for segmented blocks using static shapes for inputs");
486+
auto opt_input_ivalues_map =
487+
partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "opt");
488+
runShapeAnalysis(ctx, block, opt_input_ivalues_map, "opt");
489+
}
467490
}
468491
}
469492

core/partitioning/partitioning.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,16 @@ typedef std::unordered_map<const torch::jit::Value*, c10::IValue> ExampleIValues
1818
typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
1919
GraphAndMapping;
2020

21-
ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types, const std::string& shape_mode = std::string("opt"));
22-
23-
void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps, const std::string& shape_mode);
21+
ExampleIValues generateRandomInputs(
22+
ir::CollectionInputSpecMap& input_ranges,
23+
ir::CollectionTypeMap& input_types,
24+
const std::string& shape_mode = std::string("opt"));
25+
26+
void runShapeAnalysis(
27+
PartitioningCtx* ctx,
28+
torch::jit::Block* block,
29+
ExampleIValues& ivalues_maps,
30+
const std::string& shape_mode);
2431

2532
void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
2633

core/partitioning/segmentedblock/SegmentedBlock.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,14 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
5959

6060
std::vector<ir::Input> SegmentedBlock::construct_inputs_spec() const {
6161
std::vector<ir::Input> inputs;
62-
if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()){
63-
LOG_DEBUG("====== IS DYNAMIC ====");
64-
for (uint64_t i=0; i < opt_shapes_.size(); i++){
62+
if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) {
63+
for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
6564
auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]);
6665
in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
6766
inputs.push_back(in);
6867
}
6968
} else {
70-
LOG_DEBUG("====== IS STATIC ====");
71-
for (uint64_t i=0; i < opt_shapes_.size(); i++){
69+
for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
7270
auto in = ir::Input(opt_shapes_[i]);
7371
in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
7472
inputs.push_back(in);

core/partitioning/segmentedblock/SegmentedBlock.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ struct SegmentedBlock {
7474
return old_to_new_.count(input);
7575
}
7676
void register_inshapes(std::vector<std::vector<int64_t>>& in_shapes, const std::string& shape_mode) {
77-
if (shape_mode.compare("min") == 0){
77+
if (shape_mode.compare("min") == 0) {
7878
min_shapes_ = in_shapes;
79-
} else if(shape_mode.compare("opt") == 0){
79+
} else if (shape_mode.compare("opt") == 0) {
8080
opt_shapes_ = in_shapes;
81-
} else{
81+
} else {
8282
max_shapes_ = in_shapes;
8383
}
8484
}

core/partitioning/shape_analysis.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@ namespace torch_tensorrt {
99
namespace core {
1010
namespace partitioning {
1111

12-
at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt, const std::string& shape_mode) {
12+
at::Tensor generateSingleInput(
13+
ir::Input& input,
14+
c10::optional<at::ScalarType>& type_opt,
15+
const std::string& shape_mode) {
1316
nvinfer1::Dims input_shape = input.input_shape;
14-
if (input.input_is_dynamic){
15-
if (shape_mode.compare("min") == 0){
17+
if (input.input_is_dynamic) {
18+
if (shape_mode.compare("min") == 0) {
1619
input_shape = input.min;
17-
} else if(shape_mode.compare("opt") == 0){
20+
} else if (shape_mode.compare("opt") == 0) {
1821
input_shape = input.opt;
1922
} else {
2023
input_shape = input.max;
@@ -188,7 +191,11 @@ void getSegmentsOutputByRunning(
188191
seg_block.register_intypes(input_types);
189192
}
190193

191-
void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map, const std::string& shape_mode) {
194+
void runShapeAnalysis(
195+
PartitioningCtx* ctx,
196+
torch::jit::Block* block,
197+
ExampleIValues& example_tensor_map,
198+
const std::string& shape_mode) {
192199
// register every segment's input shape, and it's running output IValues
193200
for (auto& seg_block : ctx->partitioned_blocks[block]) {
194201
torch::jit::ConstantPooling(seg_block.g());

tests/cpp/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ test_suite(
1515
":test_collections",
1616
":test_compiled_modules",
1717
":test_default_input_types",
18+
":test_dynamic_fallback",
1819
":test_example_tensors",
1920
":test_module_fallback",
20-
":test_dynamic_fallback",
2121
":test_modules_as_engines",
2222
":test_multiple_registered_engines",
2323
":test_runtime_thread_safety",
@@ -31,9 +31,9 @@ test_suite(
3131
":test_collections",
3232
":test_compiled_modules",
3333
":test_default_input_types",
34+
":test_dynamic_fallback",
3435
":test_example_tensors",
3536
":test_module_fallback",
36-
":test_dynamic_fallback",
3737
":test_modules_as_engines",
3838
":test_multiple_registered_engines",
3939
":test_runtime_thread_safety",

tests/cpp/test_dynamic_fallback.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "torch/script.h"
55
#include "torch_tensorrt/torch_tensorrt.h"
66

7-
TEST(CppAPITest, ResNet50DynamicFallbackGraphCorrectly) {
7+
TEST(CppAPITest, ResNet18DynamicBatchFallbackCorrectly) {
88
torch::jit::script::Module mod;
99
try {
1010
mod = torch::jit::load("tests/modules/resnet18_scripted.jit.pt");
@@ -16,17 +16,40 @@ TEST(CppAPITest, ResNet50DynamicFallbackGraphCorrectly) {
1616
const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}, {4, 3, 224, 224}, {8, 3, 224, 224}};
1717
std::vector<torch::jit::IValue> jit_inputs_ivalues;
1818
std::vector<torch::jit::IValue> trt_inputs_ivalues;
19-
auto in = at::randint(5, input_shapes[0], {at::kCUDA});
20-
jit_inputs_ivalues.push_back(in.clone());
21-
trt_inputs_ivalues.push_back(in.clone());
19+
auto in_bs1 = at::randint(5, input_shapes[0], {at::kCUDA});
20+
jit_inputs_ivalues.push_back(in_bs1.clone());
21+
trt_inputs_ivalues.push_back(in_bs1.clone());
2222

2323
std::vector<torch_tensorrt::Input> inputs;
2424
inputs.push_back(torch_tensorrt::Input(input_shapes[0], input_shapes[1], input_shapes[2]));
2525
torch_tensorrt::ts::CompileSpec cfg(inputs);
2626
cfg.torch_executed_ops.push_back("aten::add");
2727

28-
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
28+
auto jit_results_bs1 = mod.forward(jit_inputs_ivalues).toTensor();
29+
// Compile and build the hybrid graph with dynamic shapes
2930
auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
30-
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
31-
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results));
31+
auto trt_results_bs1 = trt_mod.forward(trt_inputs_ivalues).toTensor();
32+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs1, trt_results_bs1));
33+
jit_inputs_ivalues.clear();
34+
trt_inputs_ivalues.clear();
35+
36+
// Run with batch size of 4
37+
auto in_bs4 = at::randint(5, input_shapes[1], {at::kCUDA});
38+
jit_inputs_ivalues.push_back(in_bs4.clone());
39+
trt_inputs_ivalues.push_back(in_bs4.clone());
40+
41+
auto jit_results_bs4 = mod.forward(jit_inputs_ivalues).toTensor();
42+
auto trt_results_bs4 = trt_mod.forward(trt_inputs_ivalues).toTensor();
43+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs4, trt_results_bs4));
44+
jit_inputs_ivalues.clear();
45+
trt_inputs_ivalues.clear();
46+
47+
// Run with batch size of 8
48+
auto in_bs8 = at::randint(5, input_shapes[2], {at::kCUDA});
49+
jit_inputs_ivalues.push_back(in_bs8.clone());
50+
trt_inputs_ivalues.push_back(in_bs8.clone());
51+
52+
auto jit_results_bs8 = mod.forward(jit_inputs_ivalues).toTensor();
53+
auto trt_results_bs8 = trt_mod.forward(trt_inputs_ivalues).toTensor();
54+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs8, trt_results_bs8));
3255
}

0 commit comments

Comments
 (0)