Skip to content

Commit 4edada6

Browse files
authored
Remove explicit batch network flag for TRT 10+ (microsoft#24298)
### Description `nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH` [has been deprecated since 10.0](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/namespacenvinfer1.html#aa8f406be96c14b7dbea548cf19f09a08) and is always implicitly set for versions 10.0+. Change the EP code to only set this flag for TRT versions 8 and below. ### Motivation and Context Removes deprecated API usages in the TRT EP code. Signed-off-by: Kevin Chen <[email protected]>
1 parent 10e51d2 commit 4edada6

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,8 +2296,9 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
22962296
auto network_flags = 0;
22972297
#if NV_TENSORRT_MAJOR > 8
22982298
network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
2299-
#endif
2299+
#else
23002300
network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
2301+
#endif
23012302

23022303
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
23032304
auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@@ -2907,8 +2908,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
29072908
auto network_flags = 0;
29082909
#if NV_TENSORRT_MAJOR > 8
29092910
network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
2910-
#endif
2911+
#else
29112912
network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
2913+
#endif
29122914
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
29132915
auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
29142916
auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));

0 commit comments

Comments
 (0)