Skip to content

Commit b132a91

Browse files
feat(inference_engine): add TensorRT engine support (#623)
feat(inference_engine): add TensorRT engine support (#623)

* feat(inference_engine): add TensorRT engine support
  - Add TensorRT inference session with dynamic shape support
  - Implement automatic engine building and caching per GPU architecture
  - Add pre-allocated buffers with pinned memory optimization
  - Support FP16/FP32/INT8 precision modes
  - Add optimization profiles for det/rec/cls models
  - Include comprehensive error handling and resource cleanup
  - Add unit tests for TensorRT engine (skipped when unavailable)
  - Add TensorRT configuration section in config.yaml
* fix(tensorrt): dynamically allocate output buffers to prevent overflow
* Update test_engine.py with TensorRT support final version
* feat(tensorrt): add model_type to engine cache key
  - Include model_type in TensorRT engine cache key generation
  - Ensures proper cache differentiation based on model_type
* fix(tensorrt): update maximum shape configurations for input and output tensors
  - Increased max_shape values in config.yaml and engine_builder.py to 2048 for both detection and recognition tasks.
  - Updated default maximum shapes in memory_utils.py to reflect the new limits for input and output buffers.
* feat(tensorrt): implement square input handling for MULTI models
  - Added functionality to check if the model requires square input based on configuration.
  - Implemented padding of input to square shape and cropping of output back to original dimensions.
  - Introduced methods to determine maximum profile size and handle square input requirements for TensorRT inference.

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 771b558 commit b132a91

File tree

8 files changed

+1107
-0
lines changed

8 files changed

+1107
-0
lines changed

python/rapidocr/config.yaml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,32 @@ EngineConfig:
8181
npu_ep_cfg:
8282
device_id: 0
8383

84+
tensorrt:
85+
device_id: 0
86+
use_fp16: true
87+
use_int8: false
88+
workspace_size: 1073741824 # 1GB = 1 << 30
89+
90+
# Engine caching
91+
cache_dir: null # null = use default models dir
92+
force_rebuild: false
93+
94+
# Dynamic shape optimization profiles
95+
det_profile:
96+
min_shape: [1, 3, 32, 32]
97+
opt_shape: [1, 3, 736, 736]
98+
max_shape: [1, 3, 2048, 2048]
99+
100+
rec_profile:
101+
min_shape: [1, 3, 48, 32]
102+
opt_shape: [6, 3, 48, 320]
103+
max_shape: [6, 3, 48, 2048]
104+
105+
cls_profile:
106+
min_shape: [1, 3, 48, 32]
107+
opt_shape: [6, 3, 48, 192]
108+
max_shape: [6, 3, 48, 192]
109+
84110
mnn: {}
85111

86112
Det:

python/rapidocr/inference_engine/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ def get_engine(engine_type: EngineType):
5353

5454
return TorchInferSession
5555

56+
if engine_type == EngineType.TENSORRT:
57+
if not import_package("tensorrt"):
58+
raise ImportError("tensorrt is not installed")
59+
60+
from .tensorrt import TRTInferSession
61+
62+
return TRTInferSession
63+
5664
if engine_type == EngineType.MNN:
5765
if not import_package("MNN"):
5866
raise ImportError("MNN is not installed")
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# -*- encoding: utf-8 -*-
2+
# @Author: SWHL
3+
# @Contact: liekkaskono@163.com
4+
from .main import TRTInferSession
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# -*- encoding: utf-8 -*-
2+
# @Author: SWHL
3+
# @Contact: liekkaskono@163.com
4+
from pathlib import Path
5+
from typing import Any, Dict
6+
7+
import tensorrt as trt
8+
9+
from ...utils.log import logger
10+
11+
12+
class TRTEngineBuilder:
    """Build and cache a TensorRT engine from an ONNX model.

    The engine is built with dynamic-shape optimization profiles selected by
    ``task_type`` ("det", "rec" or "cls"), serialized to ``engine_path`` so
    subsequent runs can skip the (slow) build, then deserialized and returned.
    """

    def __init__(
        self,
        onnx_path: Path,
        engine_path: Path,
        cfg: Dict[str, Any],
        task_type: str,
        trt_logger: trt.Logger,
    ):
        # Path of the source ONNX model to parse.
        self.onnx_path = onnx_path
        # Where the serialized engine is cached (written on first build).
        self.engine_path = engine_path
        # Engine config: workspace_size, use_fp16, use_int8, <task>_profile.
        self.cfg = cfg
        # One of "det" / "rec" / "cls"; selects the default shape profile.
        self.task_type = task_type
        self.trt_logger = trt_logger

    def build(self) -> trt.ICudaEngine:
        """Build a TensorRT engine from the ONNX model, cache it, return it.

        Returns:
            The deserialized ``trt.ICudaEngine``.

        Raises:
            RuntimeError: if ONNX parsing, engine building, or engine
                deserialization fails.
        """
        builder = trt.Builder(self.trt_logger)
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        network = builder.create_network(network_flags)
        parser = trt.OnnxParser(network, self.trt_logger)

        # Parse ONNX model; log every parser error before failing so the
        # root cause is visible, not just the first message.
        with open(self.onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                for i in range(parser.num_errors):
                    logger.error(f"ONNX parse error: {parser.get_error(i)}")
                raise RuntimeError("Failed to parse ONNX model")

        # Configure builder
        config = builder.create_builder_config()

        # Set workspace size
        workspace_size = self.cfg.get("workspace_size", 1 << 30)  # 1GB default
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size)

        # Set precision. FP16 is used only when the GPU natively supports it.
        if self.cfg.get("use_fp16", True) and builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            logger.info("Using FP16 precision")
        else:
            logger.info("Using FP32 precision")

        if self.cfg.get("use_int8", False) and builder.platform_has_fast_int8:
            config.set_flag(trt.BuilderFlag.INT8)
            # NOTE(review): no INT8 calibrator (or per-tensor dynamic ranges)
            # is configured here, so INT8 accuracy is undefined — warn loudly.
            logger.warning(
                "INT8 enabled without a calibrator; accuracy may degrade"
            )
            logger.info("Using INT8 precision")

        # Add optimization profile for dynamic shapes
        profile = builder.create_optimization_profile()
        self._set_dynamic_shapes(network, profile)
        config.add_optimization_profile(profile)

        # Build engine
        logger.info("Building TensorRT engine (this may take a few minutes)...")
        serialized_engine = builder.build_serialized_network(network, config)

        if serialized_engine is None:
            raise RuntimeError("Failed to build TensorRT engine")

        # Save engine to cache. Write to a temp file and atomically rename so
        # a crash mid-write never leaves a truncated engine in the cache.
        self.engine_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.engine_path.with_suffix(self.engine_path.suffix + ".tmp")
        with open(tmp_path, "wb") as f:
            f.write(serialized_engine)
        tmp_path.replace(self.engine_path)
        logger.info(f"TensorRT engine saved to {self.engine_path}")

        # Deserialize and return
        runtime = trt.Runtime(self.trt_logger)
        engine = runtime.deserialize_cuda_engine(serialized_engine)
        if engine is None:
            # deserialize_cuda_engine returns None on failure (e.g. runtime /
            # build mismatch); fail explicitly instead of handing back None.
            raise RuntimeError("Failed to deserialize TensorRT engine")
        return engine

    def _set_dynamic_shapes(
        self, network: trt.INetworkDefinition, profile: trt.IOptimizationProfile
    ):
        """Register min/opt/max input shapes on *profile* for this task type.

        Shapes come from ``cfg["<task_type>_profile"]`` when present; otherwise
        task-specific NCHW defaults are used.
        """
        profile_key = f"{self.task_type}_profile"
        profile_cfg = self.cfg.get(profile_key, {})

        # Default profiles based on task type
        if self.task_type == "det":
            min_shape = profile_cfg.get("min_shape", (1, 3, 32, 32))
            opt_shape = profile_cfg.get("opt_shape", (1, 3, 736, 736))
            max_shape = profile_cfg.get("max_shape", (1, 3, 2048, 2048))
        elif self.task_type == "rec":
            min_shape = profile_cfg.get("min_shape", (1, 3, 48, 32))
            opt_shape = profile_cfg.get("opt_shape", (6, 3, 48, 320))
            max_shape = profile_cfg.get("max_shape", (6, 3, 48, 2048))
        elif self.task_type == "cls":
            min_shape = profile_cfg.get("min_shape", (1, 3, 48, 32))
            opt_shape = profile_cfg.get("opt_shape", (6, 3, 48, 192))
            max_shape = profile_cfg.get("max_shape", (6, 3, 48, 192))
        else:
            # Generic fallback for unknown task types.
            min_shape = (1, 3, 32, 32)
            opt_shape = (1, 3, 224, 224)
            max_shape = (1, 3, 2048, 2048)

        # Only the first network input gets a profile; these OCR models are
        # single-input. tuple() normalizes YAML lists for TensorRT.
        input_tensor = network.get_input(0)
        input_name = input_tensor.name

        profile.set_shape(
            input_name,
            min=tuple(min_shape),
            opt=tuple(opt_shape),
            max=tuple(max_shape),
        )
        logger.info(
            f"Set optimization profile for {input_name}: "
            f"min={min_shape}, opt={opt_shape}, max={max_shape}"
        )

0 commit comments

Comments
 (0)