
Commit 8e60a54
Parent: a558e2a

feat: Initial implementation of dynamic shapes + fallback

Signed-off-by: Dheeraj Peri <[email protected]>

8 files changed: +110 additions, -40 deletions

core/compiler.cpp
Lines changed: 14 additions & 12 deletions

@@ -137,10 +137,10 @@ partitioning::GraphAndMapping BuildHybridGraph(
   auto partitioning_info = cfg.partitioning_info;

   auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
-  auto collection_input_ivalues_map =
-      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
-
-  partitioning::partition(&partitioning_ctx, collection_input_ivalues_map);
+  // auto collection_input_ivalues_map =
+  //     partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
+  partitioning_ctx.input_types_map = first_use_types;
+  partitioning::partition(&partitioning_ctx);

   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;

@@ -151,14 +151,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
       trt_engine_id << reinterpret_cast<const int*>(&seg_block);

       if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-        auto shapes = seg_block.in_shapes();
-        auto types = seg_block.in_types();
-        std::vector<ir::Input> inputs;
-        for (size_t i = 0; i < shapes.size(); i++) {
-          auto in = ir::Input(shapes[i]);
-          in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-          inputs.push_back(in);
-        }
+        // auto shapes = seg_block.in_shapes();
+        // auto types = seg_block.in_types();
+        // std::vector<ir::Input> inputs;
+        // for (size_t i = 0; i < shapes.size(); i++) {
+        //   auto in = ir::Input(shapes[i]);
+        //   in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+        //   inputs.push_back(in);
+        // }
+        auto inputs = seg_block.construct_inputs_spec();
+        LOG_DEBUG("============ INPUTS: " << inputs);
         // update the input ranges for each segments
         convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
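For orientation, construct_inputs_spec() folds the per-segment min/opt/max shapes recorded during shape analysis into ir::Input specs. A minimal sketch of the dynamic-range constructor it relies on, assuming the core/ir and core/util headers are in scope (the concrete shape values are made up):

// Sketch only: one dynamic-range input spec built by hand, mirroring what
// SegmentedBlock::construct_inputs_spec() emits per segment input.
std::vector<int64_t> min_shape = {1, 3, 224, 224};   // hypothetical minimum shape
std::vector<int64_t> opt_shape = {8, 3, 224, 224};   // hypothetical optimal shape
std::vector<int64_t> max_shape = {32, 3, 224, 224};  // hypothetical maximum shape
auto in = ir::Input(min_shape, opt_shape, max_shape);
in.dtype = util::ScalarTypeToTRTDataType(at::kFloat);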

core/partitioning/partitioning.cpp
Lines changed: 12 additions & 2 deletions

@@ -436,7 +436,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   return;
 }

-void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
+void partition(PartitioningCtx* ctx) {
   LOG_DEBUG(ctx->settings);

   // Go through all the blocks to do the partitioning

@@ -453,7 +453,17 @@ void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
     registerSegmentsOutputs(ctx, block);

     // run shape analysis on each segmented block
-    runShapeAnalysis(ctx, block, example_tensor_map);
+    auto min_input_ivalues_map =
+        partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "min");
+    auto opt_input_ivalues_map =
+        partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "opt");
+    auto max_input_ivalues_map =
+        partitioning::generateRandomInputs(ctx->settings.collection_input_spec_map, ctx->input_types_map, "max");
+
+    runShapeAnalysis(ctx, block, min_input_ivalues_map, "min");
+    runShapeAnalysis(ctx, block, opt_input_ivalues_map, "opt");
+    runShapeAnalysis(ctx, block, max_input_ivalues_map, "max");
+
   }
 }
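Read as a loop, the new shape-analysis flow inside partition() amounts to the following. This condensed form is illustrative only and assumes the same ctx and block variables as the hunk above:

// Illustrative condensation of the three min/opt/max passes.
for (const std::string shape_mode : {"min", "opt", "max"}) {
  auto ivalues_map = partitioning::generateRandomInputs(
      ctx->settings.collection_input_spec_map, ctx->input_types_map, shape_mode);
  runShapeAnalysis(ctx, block, ivalues_map, shape_mode);
}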

core/partitioning/partitioning.h
Lines changed: 4 additions & 4 deletions

@@ -13,20 +13,20 @@ namespace torch_tensorrt {
 namespace core {
 namespace partitioning {

-typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ExampleIValues;
+typedef std::unordered_map<const torch::jit::Value*, c10::IValue> ExampleIValues;

 typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
     GraphAndMapping;

-ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types);
+ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types, const std::string& shape_mode = std::string("opt"));

-void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps);
+void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps, const std::string& shape_mode);

 void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);

 GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);

-void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map);
+void partition(PartitioningCtx* ctx);

 } // namespace partitioning
 } // namespace core
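Because shape_mode defaults to "opt", call sites that only care about a single representative shape can keep the old two-argument form. A small sketch, assuming input_ranges and input_types are already-populated collection maps:

// Equivalent under the new default argument (sketch).
auto ivalues_default = generateRandomInputs(input_ranges, input_types);        // implicitly "opt"
auto ivalues_opt     = generateRandomInputs(input_ranges, input_types, "opt"); // explicitly "opt"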

core/partitioning/partitioningctx/PartitioningCtx.h
Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ struct PartitioningCtx {
   bool shouldNodeRunInTorch(torch::jit::Node* n);
   bool shouldNodeRunInTensorRT(torch::jit::Node* n);
   std::vector<torch::jit::Node*> getNodesRunInTorch();
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types_map;

  private:
   void _load_nodes_into_decision_map(torch::jit::Block* b);
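The new member lets the caller hand first-use type information to the partitioner instead of pre-building example tensors. A hedged sketch of the intended usage, with block and partitioning_info as in BuildHybridGraph and a single graph input assumed:

// Sketch: seed the context with input types, then let partition() generate
// min/opt/max example tensors internally.
partitioning::PartitioningCtx ctx(block, partitioning_info);
const torch::jit::Value* graph_input = block->inputs()[0];  // assumes one input
ctx.input_types_map[graph_input] = {c10::optional<at::ScalarType>(at::kFloat)};
partitioning::partition(&ctx);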

core/partitioning/segmentedblock/SegmentedBlock.cpp
Lines changed: 21 additions & 0 deletions

@@ -1,4 +1,5 @@
 #include "SegmentedBlock.h"
+#include "core/util/prelude.h"

 namespace torch_tensorrt {
 namespace core {

@@ -56,6 +57,26 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
   }
 }

+std::vector<ir::Input> SegmentedBlock::construct_inputs_spec() const {
+  std::vector<ir::Input> inputs;
+  if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()){
+    LOG_DEBUG("====== IS DYNAMIC ====");
+    for (uint64_t i=0; i < opt_shapes_.size(); i++){
+      auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]);
+      in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
+      inputs.push_back(in);
+    }
+  } else {
+    LOG_DEBUG("====== IS STATIC ====");
+    for (uint64_t i=0; i < opt_shapes_.size(); i++){
+      auto in = ir::Input(opt_shapes_[i]);
+      in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
+      inputs.push_back(in);
+    }
+  }
+  return inputs;
+}
+
 torch::jit::Node* SegmentedBlock::cloneNode(torch::jit::Node* node) {
   auto* block = g_->block();
   auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v); };
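An end-to-end sketch of the new SegmentedBlock flow: register shapes for all three modes plus the input types, then ask for the spec. Here seg_block stands for an existing SegmentedBlock and all shape values are invented for illustration:

// Sketch: with min/opt/max registered at matching arity, construct_inputs_spec()
// takes the dynamic branch and emits one ranged ir::Input per segment input.
std::vector<std::vector<int64_t>> min_shapes = {{1, 3, 224, 224}};
std::vector<std::vector<int64_t>> opt_shapes = {{8, 3, 224, 224}};
std::vector<std::vector<int64_t>> max_shapes = {{32, 3, 224, 224}};
std::vector<at::ScalarType> types = {at::kFloat};

seg_block.register_inshapes(min_shapes, "min");
seg_block.register_inshapes(opt_shapes, "opt");
seg_block.register_inshapes(max_shapes, "max");
seg_block.register_intypes(types);

auto specs = seg_block.construct_inputs_spec();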

core/partitioning/segmentedblock/SegmentedBlock.h
Lines changed: 16 additions & 6 deletions

@@ -35,6 +35,7 @@ struct SegmentedBlock {
   SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);

   torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
+  std::vector<ir::Input> construct_inputs_spec() const;
   torch::jit::Node* cloneNode(torch::jit::Node* node);
   void appendNode(torch::jit::Node* n) {
     cloneNode(n);

@@ -72,18 +73,25 @@ struct SegmentedBlock {
   bool contain_raw_value(torch::jit::Value* input) const {
     return old_to_new_.count(input);
   }
-  void register_inshapes(std::vector<ir::Input>& in_shapes) {
-    in_shapes_ = in_shapes;
-  }
-  const std::vector<ir::Input>& in_shapes() const {
-    return in_shapes_;
+  void register_inshapes(std::vector<std::vector<int64_t>>& in_shapes, const std::string& shape_mode) {
+    if (shape_mode.compare("min") == 0){
+      min_shapes_ = in_shapes;
+    } else if(shape_mode.compare("opt") == 0){
+      opt_shapes_ = in_shapes;
+    } else{
+      max_shapes_ = in_shapes;
+    }
   }
+  // const std::vector<ir::Input>& in_shapes() const {
+  //   return in_shapes_;
+  // }
   void register_intypes(std::vector<at::ScalarType>& in_types) {
     in_types_ = in_types;
   }
   const std::vector<at::ScalarType>& in_types() const {
     return in_types_;
   }
+
   void update_id(BlockID new_id) {
     id_ = new_id;
   }

@@ -99,7 +107,9 @@ struct SegmentedBlock {
  private:
   BlockID id_;
   SegmentedBlockTarget target_;
-  std::vector<ir::Input> in_shapes_;
+  std::vector<std::vector<int64_t>> min_shapes_;
+  std::vector<std::vector<int64_t>> opt_shapes_;
+  std::vector<std::vector<int64_t>> max_shapes_;
   std::vector<at::ScalarType> in_types_;
   std::vector<torch::jit::Value*> inputs_;
   std::vector<torch::jit::Value*> outputs_;
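If the three shape vectors do not line up (for example, only "opt" was ever registered), construct_inputs_spec() falls back to the static branch and builds plain ir::Input(opt_shapes_[i]) specs. A brief sketch with the same hypothetical seg_block as above:

// Sketch: only "opt" registered, so min/max stay empty, the size check fails,
// and static single-shape specs are produced.
std::vector<std::vector<int64_t>> opt_only = {{4, 3, 224, 224}};
std::vector<at::ScalarType> types = {at::kFloat};
seg_block.register_inshapes(opt_only, "opt");
seg_block.register_intypes(types);
auto static_specs = seg_block.construct_inputs_spec();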

core/partitioning/shape_analysis.cpp
Lines changed: 25 additions & 16 deletions

@@ -9,25 +9,33 @@ namespace torch_tensorrt {
 namespace core {
 namespace partitioning {

-at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt) {
-  auto cur_shape = input.input_shape;
-  std::vector<int64_t> shape;
-  shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
-  // auto type_opt = types[input.first][i];
+at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt, const std::string& shape_mode) {
+  nvinfer1::Dims input_shape = input.input_shape;
+  if (input.input_is_dynamic){
+    if (shape_mode.compare("min") == 0){
+      input_shape = input.min;
+    } else if(shape_mode.compare("opt") == 0){
+      input_shape = input.opt;
+    } else {
+      input_shape = input.max;
+    }
+  }
+
   auto type = at::kFloat;
   if (type_opt) {
     type = type_opt.value();
   } else {
     LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32");
   }
-  auto in = at::randint(5, shape, {at::kCUDA}).to(type);
-  // ivalue_map[input.first] = in.clone();
+  auto in = at::randint(5, util::toVec(input_shape), {at::kCUDA}).to(type);
+
   return in;
 }

 std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
     std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>>& inputs,
-    std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types) {
+    std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types,
+    const std::string& shape_mode) {
   // generate random inputs for running pytorch segments
   std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map;

@@ -36,21 +44,21 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
       c10::TypePtr elementType = c10::TensorType::get();
       auto generic_list = c10::impl::GenericList(elementType);
       for (size_t i = 0; i < input.second.size(); i++) {
-        auto in = generateSingleInput(input.second[i], types[input.first][i]);
+        auto in = generateSingleInput(input.second[i], types[input.first][i], shape_mode);
         generic_list.push_back(in.clone());
       }
       ivalue_map[input.first] = c10::IValue(generic_list);
     } else if (input.first->type()->kind() == torch::jit::TypeKind::TupleType) {
       // create tuple
       std::vector<torch::jit::IValue> list;
       for (size_t i = 0; i < input.second.size(); i++) {
-        auto in = generateSingleInput(input.second[i], types[input.first][i]);
+        auto in = generateSingleInput(input.second[i], types[input.first][i], shape_mode);
         list.push_back(in.clone());
       }
       auto tuple = c10::ivalue::Tuple::create(list); // create tuple ptr
       ivalue_map[input.first] = c10::IValue(tuple);
     } else {
-      auto in = generateSingleInput(input.second[0], types[input.first][0]);
+      auto in = generateSingleInput(input.second[0], types[input.first][0], shape_mode);
       ivalue_map[input.first] = in.clone();
     }
   }

@@ -60,7 +68,8 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
 void getSegmentsOutputByRunning(
     SegmentedBlock& seg_block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
-    const PartitioningInfo& partitioning_info) {
+    const PartitioningInfo& partitioning_info,
+    const std::string& shape_mode) {
   // create a module to run the graph
   auto g = seg_block.g();
   auto copy_g = g->copy();

@@ -141,7 +150,7 @@ void getSegmentsOutputByRunning(
   }

   // set input shape for each segmented block so we wil use it in conversion process
-  std::vector<ir::Input> input_shapes;
+  std::vector<std::vector<int64_t>> input_shapes;
   std::vector<at::ScalarType> input_types;
   for (auto& i : seg_block.raw_inputs()) {
     if (ivalues_maps[i].isTensor()) {

@@ -175,15 +184,15 @@ void getSegmentsOutputByRunning(
     // TODO: tuple and list inputs in subgraph
   }

-  seg_block.register_inshapes(input_shapes);
+  seg_block.register_inshapes(input_shapes, shape_mode);
   seg_block.register_intypes(input_types);
 }

-void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
+void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map, const std::string& shape_mode) {
   // register every segment's input shape, and it's running output IValues
   for (auto& seg_block : ctx->partitioned_blocks[block]) {
     torch::jit::ConstantPooling(seg_block.g());
-    getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings);
+    getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
   }
   return;
 }
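In miniature, the tensor generation performed by generateSingleInput for the selected shape mode reduces to a single ATen call. A standalone libtorch sketch (the 1x3x8x8 shape is hypothetical and a CUDA device is assumed to be available):

#include <ATen/ATen.h>
#include <vector>

int main() {
  // Random integer tensor at the chosen (min/opt/max) shape, cast to the
  // recorded input type, as generateSingleInput now does via util::toVec.
  std::vector<int64_t> shape = {1, 3, 8, 8};
  auto t = at::randint(5, shape, {at::kCUDA}).to(at::kFloat);
  return 0;
}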

tests/cpp/BUILD
Lines changed: 17 additions & 0 deletions

@@ -17,6 +17,7 @@ test_suite(
         ":test_default_input_types",
         ":test_example_tensors",
         ":test_module_fallback",
+        ":test_dynamic_fallback",
         ":test_modules_as_engines",
         ":test_multiple_registered_engines",
         ":test_runtime_thread_safety",

@@ -32,6 +33,7 @@ test_suite(
         ":test_default_input_types",
         ":test_example_tensors",
         ":test_module_fallback",
+        ":test_dynamic_fallback",
         ":test_modules_as_engines",
         ":test_multiple_registered_engines",
         ":test_runtime_thread_safety",

@@ -125,6 +127,21 @@ cc_test(
     }),
 )

+cc_test(
+    name = "test_dynamic_fallback",
+    srcs = ["test_dynamic_fallback.cpp"],
+    data = [
+        "//tests/modules:jit_models",
+    ],
+    deps = [
+        "//tests/util",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+)
+
 cc_test(
     name = "test_collections",
     srcs = ["test_collections.cpp"],
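The new cc_test target references tests/cpp/test_dynamic_fallback.cpp, which is not shown in this diff. Assuming a standard Bazel invocation for this repository, the target can be run on its own with bazel test //tests/cpp:test_dynamic_fallback, or as part of the two test suites updated above.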
