File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -329,7 +329,6 @@ void MapInputsAndDetermineDTypes(
329
329
} else if (spec.dtype_is_user_defined && cfg.partition_info .enabled ) {
330
330
if (!est_type_opt) {
331
331
LOG_INFO (" Cannot infer input tensor dtype in graph, unable to verify user input dtype settings" );
332
- first_use_type_map[in] = {util::TRTDataTypeToScalarType (cfg.convert_info .inputs .find (in)->second .dtype )};
333
332
} else {
334
333
if (util::TRTDataTypeToScalarType (cfg.convert_info .inputs .find (in)->second .dtype ) != est_type_opt.value ()) {
335
334
std::stringstream ss;
@@ -345,9 +344,10 @@ void MapInputsAndDetermineDTypes(
345
344
ss << " - Disable partial compilation by setting require_full_compilation to True" ;
346
345
auto warn_str = ss.str ();
347
346
LOG_WARNING (warn_str);
348
- // Overwrite type map with user settings
349
- first_use_type_map[in] = {util::TRTDataTypeToScalarType (cfg.convert_info .inputs .find (in)->second .dtype )};
350
347
}
348
+ // Overwrite type map with user settings
349
+ // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
350
+ first_use_type_map[in] = {util::TRTDataTypeToScalarType (cfg.convert_info .inputs .find (in)->second .dtype )};
351
351
}
352
352
} else {
353
353
// The user defined the type so no changes are necessary
You can’t perform that action at this time.
0 commit comments