Skip to content

Commit a2d9152

Browse files
authored
【SCU】【Paddle TensorRT No.36】Add pd_op.flip converter (#69724)
* add * fix codestyle * fix codestyle * fix codestyle * update * fix codestyle * add
1 parent a61ce21 commit a2d9152

File tree

3 files changed

+105
-0
lines changed

3 files changed

+105
-0
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ DEFINE_GENERAL_PATTERN(Floor, paddle::dialect::FloorOp)
8686
DEFINE_GENERAL_PATTERN(Roll, paddle::dialect::RollOp)
8787
DEFINE_GENERAL_PATTERN(Softplus, paddle::dialect::SoftplusOp)
8888
DEFINE_GENERAL_PATTERN(ThresholdedRelu, paddle::dialect::ThresholdedReluOp)
89+
DEFINE_GENERAL_PATTERN(Flip, paddle::dialect::FlipOp)
8990

9091
#undef DEFINE_GENERAL_PATTERN
9192

@@ -2140,6 +2141,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
21402141
ADD_PATTERN(Roll)
21412142
ADD_PATTERN(Softplus)
21422143
ADD_PATTERN(ThresholdedRelu)
2144+
ADD_PATTERN(Flip)
21432145
#if IS_TRT_VERSION_GE(8600)
21442146
ADD_PATTERN(Layer_norm)
21452147
#endif

python/paddle/tensorrt/impls/linalg.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
import tensorrt as trt
1717

1818
from paddle.tensorrt.converter_utils import (
19+
add_1D_constant_layer,
1920
broadcast,
21+
get_shape_tensor_element,
22+
trt_shape,
23+
trt_sum,
2024
)
2125
from paddle.tensorrt.register import converter_registry
2226

@@ -71,3 +75,42 @@ def bmm_converter(network, paddle_op, inputs):
7175
inputs[0], trt.MatrixOperation.NONE, inputs[1], trt.MatrixOperation.NONE
7276
)
7377
return out.get_output(0)
78+
79+
80+
@converter_registry.register("pd_op.flip", trt_version="8.x")
81+
def flip_converter(network, paddle_op, inputs):
82+
input_tensor = inputs[0]
83+
input_dims = input_tensor.shape
84+
rank = len(input_dims)
85+
axis = paddle_op.attrs()["axis"]
86+
axis = [a + rank if a < 0 else a for a in axis]
87+
shape_tensor = trt_shape(network, input_tensor)
88+
89+
def get_axis_length(axis_idx):
90+
dim_val = input_dims[axis_idx]
91+
if dim_val >= 0:
92+
return add_1D_constant_layer(network, [dim_val], is_scalar=True)
93+
else:
94+
return get_shape_tensor_element(
95+
network, shape_tensor, axis_idx, is_scalar=True
96+
)
97+
98+
for axis_idx in axis:
99+
loop_layer = network.add_loop()
100+
trip_limit = get_axis_length(axis_idx)
101+
loop_layer.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
102+
iterator = loop_layer.add_iterator(input_tensor, axis_idx, reverse=True)
103+
zero_tensor = add_1D_constant_layer(network, [0])
104+
one_tensor = add_1D_constant_layer(network, [1])
105+
iRec_layer = loop_layer.add_recurrence(zero_tensor)
106+
iCur = iRec_layer.get_output(0)
107+
iNext_layer = trt_sum(network, iCur, one_tensor)
108+
iRec_layer.set_input(1, iNext_layer)
109+
loop_out_layer = loop_layer.add_loop_output(
110+
iterator.get_output(0), trt.LoopOutput.CONCATENATE, axis_idx
111+
)
112+
loop_out_layer.set_input(1, trip_limit)
113+
input_tensor = loop_out_layer.get_output(0)
114+
115+
identity_layer = network.add_identity(input_tensor)
116+
return identity_layer.get_output(0)

test/tensorrt/test_converter_linalg.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,65 @@ def test_trt_result(self):
6767
self.check_trt_result()
6868

6969

70+
class TestFlipTRTPattern(TensorRTBaseTest):
71+
def setUp(self):
72+
self.python_api = paddle.flip
73+
self.api_args = {
74+
"x": np.random.randn(2, 3, 4).astype("float32"),
75+
"axis": [0, 2],
76+
}
77+
self.program_config = {"feed_list": ["x"]}
78+
self.min_shape = {"x": [1, 3, 4]}
79+
self.max_shape = {"x": [5, 3, 4]}
80+
81+
def test_trt_result(self):
82+
self.check_trt_result()
83+
84+
85+
class TestFlipNegAxisTRTPattern(TensorRTBaseTest):
86+
def setUp(self):
87+
self.python_api = paddle.flip
88+
self.api_args = {
89+
"x": np.random.randn(2, 3, 4).astype("float32"),
90+
"axis": [-1, -3],
91+
}
92+
self.program_config = {"feed_list": ["x"]}
93+
self.min_shape = {"x": [1, 3, 4]}
94+
self.max_shape = {"x": [5, 3, 4]}
95+
96+
def test_trt_result(self):
97+
self.check_trt_result()
98+
99+
100+
class TestFlipIntTRTPattern(TensorRTBaseTest):
101+
def setUp(self):
102+
self.python_api = paddle.flip
103+
self.api_args = {
104+
"x": np.random.randn(2, 3, 4).astype("int64"),
105+
"axis": [0, 2],
106+
}
107+
self.program_config = {"feed_list": ["x"]}
108+
self.min_shape = {"x": [1, 3, 4]}
109+
self.max_shape = {"x": [5, 3, 4]}
110+
111+
def test_trt_result(self):
112+
self.check_trt_result()
113+
114+
115+
class TestFlipIntNegAxisTRTPattern(TensorRTBaseTest):
116+
def setUp(self):
117+
self.python_api = paddle.flip
118+
self.api_args = {
119+
"x": np.random.randn(2, 3, 4).astype("int64"),
120+
"axis": [-1, -3],
121+
}
122+
self.program_config = {"feed_list": ["x"]}
123+
self.min_shape = {"x": [1, 3, 4]}
124+
self.max_shape = {"x": [5, 3, 4]}
125+
126+
def test_trt_result(self):
127+
self.check_trt_result()
128+
129+
70130
if __name__ == '__main__':
71131
unittest.main()

0 commit comments

Comments
 (0)