Skip to content

Commit 6af9534

Browse files
authored
【Paddle TensorRT】Modified the serialization save path for TensorRT and added an attribute name to the Input class (#71772) (#71767)
* fix * merge
1 parent 42c47d9 commit 6af9534

File tree

5 files changed

+76
-65
lines changed

5 files changed

+76
-65
lines changed

python/paddle/tensorrt/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def convert_subgraph_to_trt(self, program, group_op):
496496
int(hashlib.sha256(group_str.encode('utf-8')).hexdigest(), 16)
497497
% 10**8
498498
)
499-
CACHE_ROOT = get_cache_path()
499+
CACHE_ROOT = get_cache_path(self.trt_config.save_model_dir)
500500
CACHE_FILE = f"{CACHE_ROOT}/engine_{engine_name}_{self.engine_num}.trt"
501501
with open(CACHE_FILE, "wb") as f:
502502
f.write(trt_engine)

python/paddle/tensorrt/export.py

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -51,56 +51,6 @@
5151

5252

5353
class Input:
54-
"""
55-
A class used to configure input data for models. This class serves two purposes:
56-
57-
1. Random Data Generation: When no input data is supplied, it automatically generates random input data based on the specified minimum, optimal, and maximum shapes. In this mode,you can configure the data type (e.g., 'float32', 'int64', etc.) and the range of values (e.g.,(0.0, 1.0) for floats or (1, 10) for integers).
58-
59-
2. User-Provided Input: Alternatively, you can supply your own input data via the `warmup_data` argument. In this case, the provided data will be used directly, and the`input_data_type` and `input_range` settings will be ignored.
60-
61-
Args:
62-
warmup_data (tuple):
63-
The tuple of actual input data (for the automatic shape collection mechanism).
64-
min_input_shape (tuple):
65-
The shape of the minimum input tensor.
66-
max_input_shape (tuple):
67-
The shape of the maximum input tensor.
68-
optim_input_shape (tuple, optional):
69-
The shape of the optimal input tensor (default is None).
70-
input_data_type (str, optional):
71-
The data type for the input tensors, such as 'float32' or 'int64' or 'float32' or 'int32' (default is float32).
72-
This option only applies when min_input_shape, optim_input_shape, and max_input_shape are provided; it does not apply to warmup_data.
73-
input_range (tuple, optional):
74-
The range of values used to generate input data. For floats, the default range is (0.0, 1.0). For integers, the default range is (1, 10).
75-
This option only applies when min_input_shape, optim_input_shape, and max_input_shape are provided; it does not apply to warmup_data.
76-
Returns:
77-
None
78-
79-
Examples:
80-
.. code-block:: python
81-
82-
>>> # example 1:
83-
>>> from paddle.tensorrt.export import Input
84-
>>> input_config = Input(
85-
>>> min_input_shape=(1,100),
86-
>>> optim_input_shape=(4,100),
87-
>>> max_input_shape=(8,100),
88-
>>> )
89-
>>> input_config.input_data_type='int64'
90-
>>> input_config.input_range=(1,10)
91-
92-
>>> # example 2:
93-
>>> from paddle.tensorrt.export import Input
94-
>>> import numpy as np
95-
>>> input_config = Input(
96-
>>> warmup_data=(
97-
>>> np.random.rand(1,100).astype(np.float32),
98-
>>> np.random.rand(4,100).astype(np.float32),
99-
>>> np.random.rand(8,100).astype(np.float32),
100-
>>> )
101-
>>> )
102-
"""
103-
10454
def __init__(
10555
self,
10656
warmup_data: tuple[np.ndarray, ...] | None = None,
@@ -109,7 +59,59 @@ def __init__(
10959
optim_input_shape: tuple | None = None,
11060
input_data_type: str | None = 'float32',
11161
input_range: tuple | None = None,
62+
name: str | None = None,
11263
) -> None:
64+
"""
65+
A class used to configure input data for models. This class serves two purposes:
66+
67+
1. Random Data Generation: When no input data is supplied, it automatically generates random input data based on the specified minimum, optimal, and maximum shapes. In this mode, you can configure the data type (e.g., 'float32', 'int64', etc.) and the range of values (e.g., (0.0, 1.0) for floats or (1, 10) for integers).
68+
69+
2. User-Provided Input: Alternatively, you can supply your own input data via the `warmup_data` argument. In this case, the provided data will be used directly, and the `input_data_type` and `input_range` settings will be ignored.
70+
71+
Args:
72+
warmup_data (tuple):
73+
The tuple of actual input data (for the automatic shape collection mechanism).
74+
min_input_shape (tuple):
75+
The shape of the minimum input tensor.
76+
max_input_shape (tuple):
77+
The shape of the maximum input tensor.
78+
optim_input_shape (tuple):
79+
The shape of the optimal input tensor.
80+
input_data_type (str, optional):
81+
The data type for the input tensors, such as 'float32', 'int64', or 'int32' (default is 'float32').
82+
This option only applies when min_input_shape, optim_input_shape, and max_input_shape are provided; it does not apply to warmup_data.
83+
input_range (tuple, optional):
84+
The range of values used to generate input data. For floats, the default range is (0.0, 1.0). For integers, the default range is (1, 10).
85+
This option only applies when min_input_shape, optim_input_shape, and max_input_shape are provided; it does not apply to warmup_data.
86+
name (str, optional):
87+
The name of the input to the model.
88+
Returns:
89+
None
90+
91+
Examples:
92+
.. code-block:: python
93+
94+
>>> # example 1:
95+
>>> from paddle.tensorrt.export import Input
96+
>>> input_config = Input(
97+
>>> min_input_shape=(1,100),
98+
>>> optim_input_shape=(4,100),
99+
>>> max_input_shape=(8,100),
100+
>>> )
101+
>>> input_config.input_data_type='int64'
102+
>>> input_config.input_range=(1,10)
103+
104+
>>> # example 2:
105+
>>> from paddle.tensorrt.export import Input
106+
>>> import numpy as np
107+
>>> input_config = Input(
108+
>>> warmup_data=(
109+
>>> np.random.rand(1,100).astype(np.float32),
110+
>>> np.random.rand(4,100).astype(np.float32),
111+
>>> np.random.rand(8,100).astype(np.float32),
112+
>>> )
113+
>>> )
114+
"""
113115
if warmup_data is not None:
114116
if min_input_shape or max_input_shape or optim_input_shape:
115117
raise ValueError(
@@ -132,6 +134,7 @@ def __init__(
132134
self.optim_input_shape = optim_input_shape
133135
self.input_data_type = input_data_type
134136
self.input_range = input_range
137+
self.name = name
135138

136139
def generate_input_data(self):
137140
"""
@@ -331,17 +334,19 @@ def convert_to_trt(program, trt_config, scope):
331334
assert len({len(t) for t in input_tuples}) == 1
332335
num_samples = len(input_tuples[0])
333336
for sample_idx in range(num_samples):
334-
feed_dict = {
335-
name: input_tuples[i][sample_idx]
336-
for i, name in enumerate(feed_name)
337-
}
337+
feed_dict = {}
338+
for i, inp in enumerate(trt_config.inputs):
339+
name = inp.name if inp.name is not None else feed_name[i]
340+
feed_dict[name] = input_tuples[i][sample_idx]
338341
feeds.append(feed_dict)
339342
else:
340343
input_tuples = [i.generate_input_data() for i in trt_config.inputs]
341-
feeds = [
342-
{name: t[i] for t, name in zip(input_tuples, feed_name)}
343-
for i in range(len(input_tuples[0]))
344-
]
344+
for i in range(len(input_tuples[0])):
345+
feed_dict = {}
346+
for j, inp in enumerate(trt_config.inputs):
347+
name = inp.name if inp.name is not None else feed_name[j]
348+
feed_dict[name] = input_tuples[j][i]
349+
feeds.append(feed_dict)
345350
# run pir pass (including trt_op_marker_pass)
346351
program_with_pir = run_pir_pass(
347352
program,
@@ -640,7 +645,8 @@ def convert(model_path, config):
640645
>>> input_config = Input(
641646
>>> min_input_shape=[1, input_dim],
642647
>>> optim_input_shape=[2, input_dim],
643-
>>> max_input_shape=[4, input_dim]
648+
>>> max_input_shape=[4, input_dim],
649+
>>> name='x',
644650
>>> )
645651
646652
>>> trt_config = TensorRTConfig(inputs=[input_config])
@@ -693,7 +699,8 @@ def convert(model_path, config):
693699
>>> np.random.rand(1,3).astype(np.float32),
694700
>>> np.random.rand(2,3).astype(np.float32),
695701
>>> np.random.rand(4,3).astype(np.float32),
696-
>>> )
702+
>>> ),
703+
>>> name='x',
697704
>>> )
698705
699706
>>> trt_config = TensorRTConfig(inputs=[input_config])

python/paddle/tensorrt/util.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,12 @@ def is_shape_tensor(value):
334334
return total_elements <= 8 and total_elements >= 1 and is_int_dtype
335335

336336

337-
def get_cache_path():
338-
home_path = os.path.expanduser("~")
339-
cache_path = os.path.join(home_path, ".pp_trt_cache")
337+
def get_cache_path(cache_path):
338+
if cache_path is not None:
339+
cache_path = cache_path
340+
else:
341+
home_path = os.path.expanduser("~")
342+
cache_path = os.path.join(home_path, ".pp_trt_cache")
340343

341344
if not os.path.exists(cache_path):
342345
os.makedirs(cache_path)

test/ir/inference/test_trt_convert_range.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
144144
), 1e-2
145145

146146
def test(self):
147-
self.run_test(run_pir=True)
147+
self.run_test()
148148

149149

150150
class TrtConvertRangeStaticTest(TrtLayerAutoScanTest):

test/tensorrt/test_converter_model_resnet50.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def test_paddle_to_tensorrt_conversion_r50(self):
6060
optim_input_shape=(1, 3, 224, 224),
6161
max_input_shape=(4, 3, 224, 224),
6262
input_data_type='float32',
63+
name='input',
6364
)
6465
_, input_optim_data, _ = input_config.generate_input_data()
6566

0 commit comments

Comments
 (0)