Skip to content

Commit 0e1a3aa

Browse files
author
q.yao
authored
[Feature] support MMRotate model with le135 (#788)
* support MMRotate model with le135 * cse before fuse select assign * remove unused import
1 parent 5b31d7a commit 0e1a3aa

File tree

3 files changed

+137
-2
lines changed

3 files changed

+137
-2
lines changed

csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <torch/csrc/jit/passes/dead_code_elimination.h>
44

55
#include "../../ir/subgraph_matcher.h"
6+
#include "common_subgraph_elimination.h"
67
#include "torch/csrc/jit/ir/irparser.h"
78

89
namespace mmdeploy {
@@ -126,14 +127,16 @@ void FuseSelectAssign(Block* block, std::unordered_map<std::string, Tensor>& par
126127

127128
void FuseSelectAssign(std::shared_ptr<Graph>& graph,
128129
std::unordered_map<std::string, Tensor>& params) {
130+
// cse before search
131+
CommonSubgraphElimination(graph, params);
132+
129133
std::string pattern_str = R"IR(
130-
graph(%y, %z, %cmp_1, %cmp_2, %start, %axes):
134+
graph(%y, %z, %cmp_1, %cmp_2, %start, %axes, %shape_2):
131135
%nz_1 = onnx::NonZero(%cmp_1)
132136
%trans_1 = onnx::Transpose(%nz_1)
133137
%gather_1 = onnx::GatherND(%z, %trans_1)
134138
%reshape_1_shape = onnx::Constant()
135139
%reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape)
136-
%shape_2 = onnx::Shape(%y)
137140
%expand_2 = onnx::Expand(%cmp_2, %shape_2)
138141
%nz_2 = onnx::NonZero(%expand_2)
139142
%trans_2 = onnx::Transpose(%nz_2)

mmdeploy/codebase/mmrotate/core/bbox/transforms.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,65 @@ def poly2obb_le90__tensorrt(ctx, polys: torch.Tensor) -> torch.Tensor:
3838
width, _ = torch.max(edges, 1)
3939
height, _ = torch.min(edges, 1)
4040
return torch.stack([x_ctr, y_ctr, width, height, angles], 1)
41+
42+
43+
@FUNCTION_REWRITER.register_rewriter(
44+
func_name='mmrotate.core.bbox.transforms.poly2obb_le135')
45+
def poly2obb_le135__default(ctx, polys):
46+
"""This is a rewrite for poly2obb to remove NonZero ops.
47+
48+
Args:
49+
polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
50+
51+
Returns:
52+
obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
53+
"""
54+
polys = torch.reshape(polys, [-1, 8])
55+
pt1, pt2, pt3, pt4 = polys[..., :8].chunk(4, 1)
56+
edge1 = torch.sqrt(
57+
torch.pow(pt1[..., 0] - pt2[..., 0], 2) +
58+
torch.pow(pt1[..., 1] - pt2[..., 1], 2))
59+
edge2 = torch.sqrt(
60+
torch.pow(pt2[..., 0] - pt3[..., 0], 2) +
61+
torch.pow(pt2[..., 1] - pt3[..., 1], 2))
62+
angles1 = torch.atan2((pt2[..., 1] - pt1[..., 1]),
63+
(pt2[..., 0] - pt1[..., 0]))
64+
angles2 = torch.atan2((pt4[..., 1] - pt1[..., 1]),
65+
(pt4[..., 0] - pt1[..., 0]))
66+
angles = torch.where(edge1 > edge2, angles1, angles2)
67+
angles = norm_angle(angles, 'le135')
68+
x_ctr = (pt1[..., 0] + pt3[..., 0]) / 2.0
69+
y_ctr = (pt1[..., 1] + pt3[..., 1]) / 2.0
70+
edges = torch.stack([edge1, edge2], dim=1)
71+
width, _ = torch.max(edges, 1)
72+
height, _ = torch.min(edges, 1)
73+
return torch.stack([x_ctr, y_ctr, width, height, angles], 1)
74+
75+
76+
@FUNCTION_REWRITER.register_rewriter(
77+
func_name='mmrotate.core.bbox.transforms.obb2poly_le135')
78+
def obb2poly_le135__default(ctx, rboxes):
79+
"""Support batched input.
80+
81+
Args:
82+
ctx : context of rewriter
83+
obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
84+
85+
Returns:
86+
polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
87+
"""
88+
B, N = rboxes.shape[:2]
89+
x_ctr, y_ctr, width, height, angle = rboxes[..., 0], rboxes[
90+
..., 1], rboxes[..., 2], rboxes[..., 3], rboxes[..., 4]
91+
tl_x, tl_y, br_x, br_y = \
92+
-width * 0.5, -height * 0.5, \
93+
width * 0.5, height * 0.5
94+
rects = torch.stack([tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y],
95+
dim=-1).reshape(B, N, 2, 4)
96+
sin, cos = torch.sin(angle), torch.cos(angle)
97+
M = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(B, N, 2, 2)
98+
polys = M.matmul(rects).permute(0, 1, 3, 2)
99+
xy_ctr = torch.stack([x_ctr, y_ctr], dim=-1).unsqueeze(-2)
100+
polys += xy_ctr
101+
polys = polys.reshape(B, N, 8)
102+
return polys.contiguous()

tests/test_codebase/test_mmrotate/test_mmrotate_core.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,76 @@ def poly2obb_le90(*args, **kwargs):
312312
assert rewrite_outputs is not None
313313

314314

315+
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
316+
def test_poly2obb_le135(backend_type: Backend):
317+
check_backend(backend_type)
318+
polys = torch.rand(1, 10, 8)
319+
deploy_cfg = mmcv.Config(
320+
dict(
321+
onnx_config=dict(output_names=None, input_shape=None),
322+
backend_config=dict(
323+
type=backend_type.value,
324+
model_inputs=[
325+
dict(
326+
input_shapes=dict(
327+
polys=dict(
328+
min_shape=polys.shape,
329+
opt_shape=polys.shape,
330+
max_shape=polys.shape)))
331+
]),
332+
codebase_config=dict(type='mmrotate', task='RotatedDetection')))
333+
334+
# wrap function to enable rewrite
335+
def poly2obb_le135(*args, **kwargs):
336+
import mmrotate
337+
return mmrotate.core.bbox.transforms.poly2obb_le135(*args, **kwargs)
338+
339+
# wrap function to nn.Module, enable torch.onnx.export
340+
wrapped_func = WrapFunction(poly2obb_le135)
341+
rewrite_outputs, is_backend_output = get_rewrite_outputs(
342+
wrapped_func,
343+
model_inputs={'polys': polys},
344+
deploy_cfg=deploy_cfg,
345+
run_with_backend=False)
346+
347+
assert rewrite_outputs is not None
348+
349+
350+
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
351+
def test_obb2poly_le135(backend_type: Backend):
352+
check_backend(backend_type)
353+
rboxes = torch.rand(1, 10, 5)
354+
deploy_cfg = mmcv.Config(
355+
dict(
356+
onnx_config=dict(output_names=None, input_shape=None),
357+
backend_config=dict(
358+
type=backend_type.value,
359+
model_inputs=[
360+
dict(
361+
input_shapes=dict(
362+
rboxes=dict(
363+
min_shape=rboxes.shape,
364+
opt_shape=rboxes.shape,
365+
max_shape=rboxes.shape)))
366+
]),
367+
codebase_config=dict(type='mmrotate', task='RotatedDetection')))
368+
369+
# wrap function to enable rewrite
370+
def obb2poly_le135(*args, **kwargs):
371+
import mmrotate
372+
return mmrotate.core.bbox.transforms.obb2poly_le135(*args, **kwargs)
373+
374+
# wrap function to nn.Module, enable torch.onnx.export
375+
wrapped_func = WrapFunction(obb2poly_le135)
376+
rewrite_outputs, is_backend_output = get_rewrite_outputs(
377+
wrapped_func,
378+
model_inputs={'rboxes': rboxes},
379+
deploy_cfg=deploy_cfg,
380+
run_with_backend=False)
381+
382+
assert rewrite_outputs is not None
383+
384+
315385
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
316386
def test_gvfixcoder__decode(backend_type: Backend):
317387
check_backend(backend_type)

0 commit comments

Comments
 (0)