@@ -297,46 +297,48 @@ void MapInputsAndDetermineDTypes(
   cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
 
   for (auto& in : g->inputs()) {
-    auto est_type_opt = first_use_type_map.find(in)->second;
-    ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
-    if (est_type_opt && !spec.dtype_is_user_defined) {
-      // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
-      // type
-      LOG_INFO(
-          "Since input type is not explicitly defined, infering using first tensor calculation\n  Found input "
-          << in->debugName() << " has type " << est_type_opt.value()
-          << ". If this is incorrect explicitly set dtype for input and file a bug");
-      spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
-    } else if (!est_type_opt && !spec.dtype_is_user_defined) {
-      // If we cannot calculate the type and the user did not define the type, then default to FP32
-      LOG_WARNING(
-          "Cannot infer input type from calcuations in graph for input "
-          << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-      spec.dtype = nvinfer1::DataType::kFLOAT;
-    } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
-      if (!est_type_opt) {
-        LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
-      } else {
-        if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
-          std::stringstream ss;
-          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-          ss << cfg.convert_info.inputs.find(in)->second.dtype;
-          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-          ss << est_type_opt.value() << std::endl;
-          ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
-          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-          ss << "compatibility with PyTorch's data type convention is required.\n";
-          ss << "If you do indeed see errors at runtime either:\n";
-          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-          ss << "- Disable partial compilation by setting require_full_compilation to True";
-          auto warn_str = ss.str();
-          LOG_WARNING(warn_str);
-          // Overwrite type map with user settings
-          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+    if (static_params.find(in) == static_params.end()) {
+      ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
+      auto est_type_opt = first_use_type_map.find(in)->second;
+      if (est_type_opt && !spec.dtype_is_user_defined) {
+        // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
+        // type
+        LOG_INFO(
+            "Since input type is not explicitly defined, infering using first tensor calculation\n  Found input "
+            << in->debugName() << " has type " << est_type_opt.value()
+            << ". If this is incorrect explicitly set dtype for input and file a bug");
+        spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
+      } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+        // If we cannot calculate the type and the user did not define the type, then default to FP32
+        LOG_WARNING(
+            "Cannot infer input type from calcuations in graph for input "
+            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
+        spec.dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt) {
+          LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+        } else {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+            std::stringstream ss;
+            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+            ss << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+            ss << est_type_opt.value() << std::endl;
+            ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+            ss << "compatibility with PyTorch's data type convention is required.\n";
+            ss << "If you do indeed see errors at runtime either:\n";
+            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+            ss << "- Disable partial compilation by setting require_full_compilation to True";
+            auto warn_str = ss.str();
+            LOG_WARNING(warn_str);
+            // Overwrite type map with user settings
+            first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+          }
         }
+      } else {
+        // The user defined the type so no changes are necessary
       }
-    } else {
-      // The user defined the type so no changes are necessary
     }
   }
 }
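Note: the hunk above only rearranges control flow around existing logic. The standalone sketch below (illustrative names only, not the TRTorch/Torch-TensorRT API) shows the pattern being introduced: graph inputs that appear in static_params are skipped, so dtype inference and the FP32 fallback apply only to real runtime inputs.

// Minimal, self-contained illustration of the guard added in the hunk above.
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical inputs: one runtime tensor plus two frozen parameters.
  std::vector<std::string> graph_inputs = {"input_0", "weight_0", "bias_0"};
  std::map<std::string, int> static_params = {{"weight_0", 1}, {"bias_0", 2}};

  for (const auto& in : graph_inputs) {
    if (static_params.find(in) == static_params.end()) {
      // Only genuine runtime inputs reach the dtype-inference / default-FP32 path.
      std::cout << "resolve dtype for " << in << "\n";
    } else {
      // Static params (e.g. frozen weights) are skipped entirely.
      std::cout << "skip static param " << in << "\n";
    }
  }
  return 0;
}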
@@ -375,7 +377,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
 
-  return std::move(engine);
+  return engine;
 }
 
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
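Note: swapping return std::move(engine) for return engine is a behavior-neutral cleanup. Returning a named local by value already allows NRVO or, failing that, an implicit move, while wrapping it in std::move blocks copy elision and typically draws warnings such as Clang's -Wpessimizing-move. A minimal sketch (hypothetical function name, not part of this codebase):

#include <string>

// Hypothetical helper used only to illustrate the return-value change above.
std::string BuildSerializedEngine() {
  std::string engine = "serialized TensorRT engine bytes";
  return engine;  // NRVO or implicit move applies here
  // return std::move(engine);  // legal, but prevents copy elision (pessimizing move)
}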