@@ -256,6 +256,7 @@ GraphAndMapping ConstructFallbackGraph(
// update the input ranges for each segment
convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+ // TODO: map input IValues to the flattened form here
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_cfg.engine_settings.device;
@@ -306,57 +307,72 @@ void MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
- ir::TypeMap& first_use_type_map) {
- // Associate input specs with inputs
- cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
-
- for (auto& in : g->inputs()) {
- if (static_params.find(in) == static_params.end()) {
- ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
- auto est_type_opt = first_use_type_map.find(in)->second;
- if (est_type_opt && !spec.dtype_is_user_defined) {
- // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
- // type
- LOG_INFO(
- "Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
- << in->debugName() << " has type " << est_type_opt.value()
- << ". If this is incorrect explicitly set dtype for input and file a bug");
- spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
- } else if (!est_type_opt && !spec.dtype_is_user_defined) {
- // If we cannot calculate the type and the user did not define the type, then default to FP32
- LOG_WARNING(
- "Cannot infer input type from calcuations in graph for input "
- << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
- spec.dtype = nvinfer1::DataType::kFLOAT;
- } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
- if (!est_type_opt) {
- LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
- first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
- } else {
- if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+ ir::CollectionTypeMap& first_use_type_map) {
+ cfg.convert_info.collection_input_spec_map = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+
+ auto collection_inputs = ir::get_collection_inputs(g, static_params);
+ LOG_DEBUG("In MapInputsAndDetermineDTypes, the g->inputs() size is " << g->inputs().size() << ", CollectionInputSpecMap size is " << collection_inputs.size());
+
+ for (auto in : collection_inputs) {
+ std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
+ std::vector<c10::optional<at::ScalarType>> est_type_opt;
+
+ auto est_it = first_use_type_map.find(in);
+ if (est_it != first_use_type_map.end()) {
+ est_type_opt = first_use_type_map.find(in)->second;
+ }
+ // traverse elements in est_type_opt and spec
+ for (int i = 0; i < est_type_opt.size(); i++) {
+ if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
+ // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
+ // type
+ LOG_INFO(
+ "Since input type is not explicitly defined, inferring using first tensor calculation\n Inferred input "
+ << in->debugName() << " has type " << est_type_opt[i].value());
+ spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+ } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
+ // If we cannot calculate the type and the user did not define the type, then default to FP32
+ LOG_WARNING(
+ "Cannot infer input type from calculations in graph for input "
+ << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
+ spec[i].dtype = nvinfer1::DataType::kFLOAT;
+ } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+ if (!est_type_opt[i]) {
+ LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
- ss << cfg.convert_info.inputs.find(in)->second.dtype;
- ss << ", however when inspecting the graph, the input type expected was inferred to be ";
- ss << est_type_opt.value() << std::endl;
- ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
- ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
- ss << "compatibility with PyTorch's data type convention is required.\n";
- ss << "If you do indeed see errors at runtime either:\n";
- ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
- ss << "- Disable partial compilation by setting require_full_compilation to True";
+ ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+ ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
auto warn_str = ss.str();
LOG_WARNING(warn_str);
+ // Overwrite type map with user settings
+ first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+
+ } else {
+ if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != est_type_opt[i].value()) {
+ std::stringstream ss;
+ ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+ ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+ ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+ ss << est_type_opt[i].value() << std::endl;
+ ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+ ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+ ss << "compatibility with PyTorch's data type convention is required.\n";
+ ss << "If you do indeed see errors at runtime either:\n";
+ ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+ ss << "- Disable partial compilation by setting require_full_compilation to True";
+ auto warn_str = ss.str();
+ LOG_WARNING(warn_str);
+ // Overwrite type map with user settings
+ first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+ }
}
- // Overwrite type map with user settings
- // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
- first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+ } else {
+ // The user defined the type so no changes are necessary
}
- } else {
- // The user defined the type so no changes are necessary
}
}
- }
+ // }
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
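
The hunk above generalizes dtype handling from one ir::Input per graph input to a vector of ir::Input per collection input, applying the same resolution rule to each element. The following standalone sketch restates that per-element rule; the DType enum, ResolvedDType struct, and ResolveElementDType function are hypothetical stand-ins for nvinfer1::DataType / at::ScalarType and the real spec structures (not the torch_tensorrt API), kept minimal so the control flow is easy to compare with the diff.

// Minimal sketch of the per-element dtype resolution rule used above.
// All names here are illustrative assumptions; only the control flow mirrors the hunk.
#include <iostream>
#include <optional>

enum class DType { kFloat32, kHalf, kInt32 };

struct ResolvedDType {
  DType dtype;             // dtype the converter will use for this element
  bool overwrite_type_map; // true when the user setting should overwrite the inferred entry
};

ResolvedDType ResolveElementDType(
    std::optional<DType> inferred, // type estimated from the first tensor calculation, if any
    bool user_defined,             // did the user explicitly set a dtype for this element?
    DType user_dtype,              // the user-provided dtype (ignored unless user_defined)
    bool partitioning_enabled) {
  if (inferred && !user_defined) {
    // Type recoverable from the graph and not set by the user: use the inferred type
    return {*inferred, false};
  }
  if (!inferred && !user_defined) {
    // Nothing known: default to FP32, as the compiler does
    std::cerr << "Cannot infer input type; assuming Float32\n";
    return {DType::kFloat32, false};
  }
  if (user_defined && partitioning_enabled && !inferred) {
    // No estimate available: trust the user setting and push it into the type map
    // so the partitioner sees scalar types consistent with the spec
    return {user_dtype, true};
  }
  if (user_defined && partitioning_enabled && *inferred != user_dtype) {
    // Conflict between user setting and inferred type: warn, keep the user setting,
    // and overwrite the inferred entry
    std::cerr << "User dtype conflicts with inferred dtype; using the user setting\n";
    return {user_dtype, true};
  }
  // User-defined type with no conflict (or partitioning disabled): no changes necessary
  return {user_dtype, false};
}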
@@ -370,7 +386,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
- auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+ // auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+ auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
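
For context on the swap above: get_block_first_calc_dtypes_opt produced at most one estimated type per graph input, while the _collection variant produces one estimated type per element of a grouped (tuple/list) input. A rough sketch of that shape difference, using plain string keys and a local enum as assumed stand-ins for torch::jit::Value* and at::ScalarType (not the ir:: definitions), is:

// Shape of the first-use type maps, simplified for illustration only.
#include <map>
#include <optional>
#include <string>
#include <vector>

enum class ScalarType { Float, Half, Int };

// Before: at most one estimated type per graph input
using FlatTypeMap = std::map<std::string, std::optional<ScalarType>>;

// After: one estimated type per element of a collection-typed input
using CollectionTypeMap = std::map<std::string, std::vector<std::optional<ScalarType>>>;

int main() {
  CollectionTypeMap first_use_types;
  // A tuple input with two tensors: the first element's type was inferred, the second was not
  first_use_types["input0"] = {ScalarType::Float, std::nullopt};
  return 0;
}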
@@ -395,23 +412,25 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
- auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+ auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+ auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partition_info.enabled &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
LOG_INFO("Skipping partitioning since model is fully supported");
}

if (cfg.partition_info.enabled &&
- !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
- cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
- auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+ (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+ cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)
+ || outputIsCollection)) {
+
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
- auto graph_and_mapping =
- ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
+ auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
+ auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
new_g = graph_and_mapping.first;
// rename the graph inputs after fallback to ensure PyTorch deserializes the module correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
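
The condition rewritten in this hunk now triggers partitioning not only when the block is not fully convertible, but also when the graph output is a collection. A condensed sketch of that predicate, with plain bools standing in for the cfg.partition_info / cfg.lower_info fields (hypothetical parameter names, not the compiler's API):

// Condensed form of the partitioning trigger after this change; illustration only.
#include <cstdio>

bool ShouldPartition(
    bool partitioning_enabled,
    bool has_forced_fallback_modules,
    bool has_forced_fallback_ops,
    bool block_convertible,
    bool output_is_collection) {
  const bool fully_supported =
      !has_forced_fallback_modules && !has_forced_fallback_ops && block_convertible;
  // Partition when enabled and either some op must fall back to PyTorch
  // or the output is a tuple/list that the single-engine path does not return directly
  return partitioning_enabled && (!fully_supported || output_is_collection);
}

int main() {
  // Even a fully supported graph is partitioned when its output is a collection
  std::printf("%d\n", ShouldPartition(true, false, false, true, true)); // prints 1
  return 0;
}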
@@ -429,6 +448,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
TORCHTRT_CHECK(
conversion::VerifyConverterSupportForBlock(g->block()),
"Not all operations in graph are supported by the compiler");
+ // TODO find the right
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
}