Commit 7393fa8

feat: support for grouped inputs
Signed-off-by: Naren Dasan <[email protected]>
1 parent 2f896b3 commit 7393fa8

33 files changed: +1257 -165 lines

core/compiler.cpp

Lines changed: 70 additions & 50 deletions
@@ -256,6 +256,7 @@ GraphAndMapping ConstructFallbackGraph(
       // update the input ranges for each segments
       convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

+      // TODO mapping Inputs Ivalue to flatten one here
       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
@@ -306,57 +307,72 @@ void MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::TypeMap& first_use_type_map) {
-  // Associate input specs with inputs
-  cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
-
-  for (auto& in : g->inputs()) {
-    if (static_params.find(in) == static_params.end()) {
-      ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
-      auto est_type_opt = first_use_type_map.find(in)->second;
-      if (est_type_opt && !spec.dtype_is_user_defined) {
-        // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
-        // type
-        LOG_INFO(
-            "Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
-            << in->debugName() << " has type " << est_type_opt.value()
-            << ". If this is incorrect explicitly set dtype for input and file a bug");
-        spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
-      } else if (!est_type_opt && !spec.dtype_is_user_defined) {
-        // If we cannot calculate the type and the user did not define the type, then default to FP32
-        LOG_WARNING(
-            "Cannot infer input type from calcuations in graph for input "
-            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-        spec.dtype = nvinfer1::DataType::kFLOAT;
-      } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
-        if (!est_type_opt) {
-          LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
-          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
-        } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+    ir::CollectionTypeMap& first_use_type_map) {
+  cfg.convert_info.collection_input_spec_map = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+
+  auto collection_inputs = ir::get_collection_inputs(g, static_params);
+  LOG_DEBUG("In MapInputsAndDetermineDTypes, the g->inputs() size is " << g->inputs().size() << ", CollectionInputSpecMap size is" << collection_inputs.size());
+
+  for (auto in : collection_inputs) {
+    std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
+    std::vector<c10::optional<at::ScalarType>> est_type_opt;
+
+    auto est_it = first_use_type_map.find(in);
+    if (est_it != first_use_type_map.end()) {
+      est_type_opt = first_use_type_map.find(in)->second;
+    }
+    // traverse elements in est_type_out and spec
+    for (int i = 0; i < est_type_opt.size(); i++) {
+      if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
+        // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
+        // type
+        LOG_INFO(
+            "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
+            << in->debugName() << " has type " << est_type_opt[i].value());
+        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+      } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
+        // If we cannot calculate the type and the user did not define the type, then default to FP32
+        LOG_WARNING(
+            "Cannot infer input type from calcuations in graph for input "
+            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
+        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt[i]) {
+          LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
           std::stringstream ss;
           ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-          ss << cfg.convert_info.inputs.find(in)->second.dtype;
-          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-          ss << est_type_opt.value() << std::endl;
-          ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
-          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-          ss << "compatibility with PyTorch's data type convention is required.\n";
-          ss << "If you do indeed see errors at runtime either:\n";
-          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
           auto warn_str = ss.str();
           LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+
+        } else {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != est_type_opt[i].value()) {
+            std::stringstream ss;
+            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+            ss << est_type_opt[i].value() << std::endl;
+            ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+            ss << "compatibility with PyTorch's data type convention is required.\n";
+            ss << "If you do indeed see errors at runtime either:\n";
+            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+            ss << "- Disable partial compilation by setting require_full_compilation to True";
+            auto warn_str = ss.str();
+            LOG_WARNING(warn_str);
+            // Overwrite type map with user settings
+            first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+          }
         }
-        // Overwrite type map with user settings
-        // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
-        first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+      } else {
+        // The user defined the type so no changes are necessary
       }
-    } else {
-      // The user defined the type so no changes are necessary
     }
   }
-  }
+  // }
 }

 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -370,7 +386,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+  // auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

@@ -395,23 +412,25 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
   if (cfg.partition_info.enabled &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }

   if (cfg.partition_info.enabled &&
-      !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
-    auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)
+       || outputIsCollection)) {
+
     std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-    auto graph_and_mapping =
-        ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
+    auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
+    auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
@@ -429,6 +448,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   TORCHTRT_CHECK(
       conversion::VerifyConverterSupportForBlock(g->block()),
       "Not all operations in graph are supported by the compiler");
+  // TODO find the right
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
   AddEngineToGraph(new_mod, new_g, engine, cuda_device);
 }
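
As a reading aid for the new loop above, a sketch of the map layouts it implies; the aliases below are inferred from how the code indexes cfg.convert_info.collection_input_spec_map and first_use_type_map, and are assumptions rather than the verbatim definitions in core/ir/ir.h.

// Sketch only: layouts inferred from the usage above, not copied from core/ir/ir.h.
#include <unordered_map>
#include <vector>
#include "core/ir/ir.h"
#include "torch/csrc/jit/ir/ir.h"

namespace sketch { // hypothetical namespace

// One spec per element of a grouped (tuple/list) graph input.
using CollectionInputSpecMap =
    std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>>;

// Parallel structure: the dtype inferred from each element's first use, if any.
using CollectionTypeMap =
    std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>;

} // namespace sketch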

core/compiler.h

Lines changed: 4 additions & 2 deletions
@@ -8,13 +8,15 @@
 #include "core/partitioning/partitioning.h"
 #include "core/runtime/runtime.h"
 #include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/ir/ir.h"

 namespace torch_tensorrt {
 namespace core {

 struct CompileSpec {
-  CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {}
-  std::vector<ir::Input> inputs;
+  CompileSpec(std::vector<ir::Input> inputs) : graph_inputs(inputs) {}
+  CompileSpec(torch::jit::IValue& input_signature) : graph_inputs(input_signature) {}
+  ir::GraphInputs graph_inputs;
   conversion::ConversionInfo convert_info;
   lowering::LowerInfo lower_info;
   partitioning::PartitionInfo partition_info;
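
For orientation, a minimal sketch of how a caller could exercise the two constructors above; the demo namespace and helper functions are hypothetical, and the shape-only ir::Input constructor is assumed.

// Sketch only: hypothetical helpers exercising the two CompileSpec constructors above.
#include "core/compiler.h"

namespace demo { // hypothetical

torch_tensorrt::core::CompileSpec from_flat_inputs() {
  // Existing path: a flat list of Input specs, one per graph input tensor
  // (assumes the shape-only ir::Input constructor).
  std::vector<torch_tensorrt::core::ir::Input> inputs{
      torch_tensorrt::core::ir::Input({1, 3, 224, 224}),
      torch_tensorrt::core::ir::Input({1, 10})};
  return torch_tensorrt::core::CompileSpec(inputs);
}

torch_tensorrt::core::CompileSpec from_signature(torch::jit::IValue& input_signature) {
  // New path: the whole (possibly nested) signature arrives as a single IValue,
  // e.g. a tuple or list of Input objects produced by the frontend.
  return torch_tensorrt::core::CompileSpec(input_signature);
}

} // namespace demo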

core/conversion/conversion.cpp

Lines changed: 22 additions & 4 deletions
@@ -138,7 +138,10 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
 void AddInputs(
     ConversionCtx* ctx,
     c10::ArrayRef<const torch::jit::Value*> inputs,
-    std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs) {
+    ConversionInfo& conversion_info) {
+  std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs = conversion_info.inputs;
+  std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>> collection_input_spec = conversion_info.collection_input_spec_map;
+
   std::vector<const torch::jit::Value*> input_tensors;
   for (auto in : inputs) {
     // Disregarding inputs that are not tensors
@@ -166,9 +169,15 @@ void AddInputs(
   for (auto input : input_tensors) {
     const torch::jit::Value* in = input;
     TORCHTRT_CHECK(
-        input_specs.find(in) != input_specs.end(),
+        input_specs.find(in) != input_specs.end() || collection_input_spec.find(in) != collection_input_spec.end(),
         "Cannot find an input spec associated with input: " << in->debugName());
-    ir::Input& spec = input_specs.find(in)->second;
+    ir::Input spec;
+    if (input_specs.find(in) != input_specs.end()) {
+      spec = input_specs.find(in)->second;
+    } else {
+      spec = collection_input_spec.find(in)->second[0]; // assume input is tensor
+    }
+    // ir::Input& spec = input_specs.find(in)->second;

     std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
     LOG_INFO(
@@ -408,7 +417,7 @@ void ConvertBlockToNetDef(

   auto inputs = b->inputs();
   AddParamsToCtxValueMap(ctx, static_params);
-  AddInputs(ctx, inputs, build_info.inputs);
+  AddInputs(ctx, inputs, build_info);

   auto nodes = b->nodes();

@@ -549,6 +558,15 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }

+bool OutputIsCollection(const torch::jit::Block* b) {
+  for (auto out: b->outputs()) {
+    if(out->type()->kind() == torch::jit::TypeKind::TupleType || out->type()->kind() == torch::jit::TypeKind::ListType) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
   auto unsupported_ops = GetUnsupportedOpsInBlock(b);
   if (unsupported_ops.size() != 0) {
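
A short, hedged usage sketch for the new OutputIsCollection helper; the wrapper function is hypothetical and only illustrates the check that CompileGraph (see core/compiler.cpp above) now performs before choosing the partitioning path.

// Sketch only: hypothetical wrapper around the new OutputIsCollection helper.
#include "core/conversion/conversion.h"
#include "torch/csrc/jit/api/module.h"

bool forward_returns_collection(const torch::jit::Module& mod) { // hypothetical
  auto g = mod.get_method("forward").graph();
  // True when any graph output is a TupleType or ListType; per the CompileGraph
  // change in this commit, that forces the fallback/partitioning path even if
  // every operator in the block is otherwise convertible.
  return torch_tensorrt::core::conversion::OutputIsCollection(g->block());
}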

core/conversion/conversion.h

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,7 @@ namespace conversion {

 struct ConversionInfo {
   ir::InputSpecMap inputs;
+  ir::CollectionInputSpecMap collection_input_spec_map;
   BuilderSettings engine_settings;
 };

@@ -25,6 +26,8 @@ std::string ConvertBlockToEngine(

 bool OpSupported(const torch::jit::Node* n);

+bool OutputIsCollection(const torch::jit::Block* b);
+
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);

 c10::optional<torch::jit::IValue> EvaluateNode(

core/conversion/evaluators/aten.cpp

Lines changed: 0 additions & 15 deletions
@@ -264,21 +264,6 @@ auto aten_registrations TORCHTRT_UNUSED =
                       },
                       EvalOptions().validSchemas(
                           {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})})
-        .evaluator({c10::Symbol::fromQualString("aten::__getitem__"),
-                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
-                      auto idx = args.at(n->input(1)).unwrapToInt();
-
-                      const int64_t list_size = list.size();
-                      const int64_t normalized_idx = normalizeIndex(idx, list_size);
-                      TORCHTRT_CHECK(
-                          normalized_idx >= 0 || normalized_idx < list_size,
-                          "List index out of range (aten::__getitem__)");
-                      return list.get(normalized_idx);
-                    },
-                    EvalOptions().validSchemas({
-                        "aten::__getitem__.t(t[](a) list, int idx) -> (t(*))",
-                    })})
         .evaluator({c10::Symbol::fromQualString("aten::append"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();

core/ir/BUILD

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@ cc_library(
     srcs = [
         "ir.cpp",
         "Input.cpp",
-        "StaticParams.cpp"
+        "StaticParams.cpp",
+        "GraphInputs.cpp"
     ],
     deps = [
         "@tensorrt//:nvinfer",

core/ir/GraphInputs.cpp

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+#include "core/ir/ir.h"
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace ir {
+
+void flatten_dfs(std::vector<torch_tensorrt::core::ir::Input>& flattened_inputs, std::vector<std::vector<torch_tensorrt::core::ir::Input>>& collection_inputs,
+                 torch::jit::IValue input_ivalue, int level, int index) {
+  if (input_ivalue.isTuple()) {
+    auto input_tuple = input_ivalue.toTuple();
+    int idx = 0;
+    if (level == 0) {
+      collection_inputs.resize(input_tuple->elements().size());
+    }
+    for (auto item: input_tuple->elements()) {
+      torch::jit::IValue converted_item;
+      int cur_idx = level < 1 ? idx: index;
+      flatten_dfs(flattened_inputs, collection_inputs, item, level+1, cur_idx);
+      idx++;
+    }
+  } else if(input_ivalue.isList()) {
+    auto input_list = input_ivalue.toList().vec();
+    if (level == 0) {
+      collection_inputs.resize(input_list.size());
+    }
+    c10::TypePtr type = input_list[0].type();
+    auto converted_elements = c10::impl::GenericList(type);
+    int idx = 0;
+    for (auto item: input_list) {
+      int cur_idx = level < 1 ? idx: index;
+      flatten_dfs(flattened_inputs, collection_inputs, item, level+1, cur_idx);
+      idx++;
+    }
+  } else if(input_ivalue.isCustomClass()) {
+    torch_tensorrt::core::ir::Input cur_input = *(input_ivalue.toCustomClass<torch_tensorrt::core::ir::Input>());
+    flattened_inputs.push_back(cur_input);
+    if (level == 0) { // a single value like A
+      collection_inputs.resize(1);
+      collection_inputs[0].push_back(cur_input);
+    } else if (level == 1) { // like A in [A, A] or [(B, B), A]
+      collection_inputs[index].push_back(cur_input);
+    } else if (level == 2) { // like A in [(A, A), C]
+      collection_inputs[index].push_back(cur_input);
+    } else {// only support 2 level
+      LOG_ERROR("Input nesting depth exceeds currently supported depth (3), use 1 level: [A, B], or 2 level: [A, (B, C)]");
+    }
+  }
+}
+
+
+GraphInputs::GraphInputs(std::vector<ir::Input> inputs_) {
+  LOG_DEBUG("Construct GraphInput with ir::Input");
+  inputs = inputs_;
+  collection_inputs.resize(inputs_.size());
+  for (int i = 0; i < inputs_.size(); i++) {
+    collection_inputs[i].push_back(inputs_[i]);
+  }
+}
+
+GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) {
+  LOG_DEBUG("Construct GraphInput with IValue");
+
+  std::vector<torch_tensorrt::core::ir::Input> flattened_inputs;
+  std::vector<std::vector<torch_tensorrt::core::ir::Input>> collection_inputs_;
+
+  flatten_dfs(flattened_inputs, collection_inputs_, input_signature_, 0, 0);
+  inputs = flattened_inputs;
+  input_signature = input_signature_;
+  collection_inputs = collection_inputs_;
+  LOG_DEBUG("Collection Input Size: " << collection_inputs_.size());
+}
+
+} // namespace ir
+} // namespace core
+} // namespace torch_tensorrt
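
To make the traversal above easier to follow, a hedged sketch of what the two constructors populate; the nested case is described in comments because building the Input custom-class IValue by hand depends on frontend registration code not shown in this commit.

// Sketch only: expected contents of GraphInputs for a couple of signatures,
// based on the flatten_dfs logic above. The function name is hypothetical.
#include "core/ir/ir.h"

void graph_inputs_sketch() { // hypothetical
  using torch_tensorrt::core::ir::GraphInputs;
  using torch_tensorrt::core::ir::Input;

  // Flat constructor: every Input becomes its own single-element group.
  std::vector<Input> flat{Input({1, 3, 224, 224}), Input({1, 10})};
  GraphInputs gi(flat);
  // gi.inputs            -> {A, B}
  // gi.collection_inputs -> {{A}, {B}}

  // IValue constructor (not built here): for a signature of the form ((A, B), C),
  // flatten_dfs is expected to yield
  //   inputs            -> {A, B, C}
  //   collection_inputs -> {{A, B}, {C}}
  // i.e. one group per top-level element, with order preserved.
}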

core/ir/StaticParams.cpp

Lines changed: 4 additions & 1 deletion
@@ -11,7 +11,10 @@ StaticParams get_static_params(c10::ArrayRef<torch::jit::Value*> inputs, std::ve
   StaticParams static_params;
   auto param_it = params.begin();
   for (auto in : inputs) {
-    if (in->type() != c10::TensorType::get() && param_it != params.end()) {
+    // handle TensorType, TupleType and ListType
+    if (in->type() != c10::TensorType::get() &&
+        in->type()->kind() != torch::jit::TypeKind::TupleType &&
+        in->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.end()) {
       static_params[in] = *param_it;
       ++param_it;
     }
