Skip to content

Commit 5d8fb5a

Browse files
committed
refactor(//core): Unify the dtype work for the partition section
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 01c89d1 commit 5d8fb5a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

core/compiler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,6 @@ void MapInputsAndDetermineDTypes(
329329
} else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
330330
if (!est_type_opt) {
331331
LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
332-
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
333332
} else {
334333
if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
335334
std::stringstream ss;
@@ -345,9 +344,10 @@ void MapInputsAndDetermineDTypes(
345344
ss << "- Disable partial compilation by setting require_full_compilation to True";
346345
auto warn_str = ss.str();
347346
LOG_WARNING(warn_str);
348-
// Overwrite type map with user settings
349-
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
350347
}
348+
// Overwrite type map with user settings
349+
// We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
350+
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
351351
}
352352
} else {
353353
// The user defined the type so no changes are necessary

0 commit comments

Comments
 (0)