Skip to content

Commit d638730

Browse files
authored
fix: Bugfix in TRT Engine deserialization indexing (#1646)
1 parent 4fc2935 commit d638730

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,13 @@ TRTEngine::TRTEngine(
128128
for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
129129
auto binding_name = _out_binding_names[pyt_idx];
130130
auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
131-
std::string engine_binded_name = cuda_engine->getIOTensorName(inputs_size + pyt_idx);
132-
TORCHTRT_CHECK(
133-
(binding_name == engine_binded_name),
134-
"Could not find a TensorRT engine binding for output named " << binding_name);
131+
TORCHTRT_CHECK((trt_idx != -1), "Could not find a TensorRT engine binding for output named " << binding_name);
135132
TORCHTRT_CHECK(
136133
!(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
137134
"Binding " << binding_name << " specified as output but found as input in TensorRT engine");
138-
LOG_DEBUG("Output binding name: " << binding_name << "pyt return idx: " << inputs_size + pyt_idx << ")");
135+
LOG_DEBUG(
136+
"Output binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
137+
<< ", Torch binding index: " << inputs_size + pyt_idx);
139138
out_binding_map[trt_idx] = pyt_idx;
140139
out_binding_names[pyt_idx] = binding_name;
141140
}

0 commit comments

Comments
 (0)