1+ #!/usr/bin/env python3
2+ # -*- coding: utf-8 -*-
3+ # Copyright (c) Megvii, Inc. and its affiliates.
4+
5+ # ROS2 rclpy -- Ar-Ray-code 2022
6+ import argparse
7+ import os
8+
9+ import cv2
10+ import numpy as np
11+
12+ import onnxruntime
13+
14+ from yolox .data .data_augment import preproc as preprocess
15+ from yolox .data .datasets import COCO_CLASSES
16+ from yolox .utils import mkdir , multiclass_nms , demo_postprocess , vis
17+
18+ # ROS2 =====================================
19+ import rclpy
20+ from rclpy .node import Node
21+
22+ from std_msgs .msg import Header
23+ from cv_bridge import CvBridge
24+ from sensor_msgs .msg import Image
25+
26+ from bboxes_ex_msgs .msg import BoundingBoxes
27+ from bboxes_ex_msgs .msg import BoundingBox
28+
29+ # from darkself.net_ros_msgs.msg import BoundingBoxes
30+ # from darkself.net_ros_msgs.msg import BoundingBox
31+
32+ class yolox_ros (Node ):
33+ def __init__ (self ) -> None :
34+
35+ # ROS2 init
36+ super ().__init__ ('yolox_ros' )
37+
38+ self .setting_yolox_exp ()
39+
40+ if (self .imshow_isshow ):
41+ cv2 .namedWindow ("YOLOX" )
42+
43+ self .bridge = CvBridge ()
44+
45+ self .pub = self .create_publisher (BoundingBoxes ,"yolox/bounding_boxes" , 10 )
46+ self .pub_image = self .create_publisher (Image ,"yolox/image_raw" , 10 )
47+ self .sub = self .create_subscription (Image ,"image_raw" ,self .imageflow_callback , 10 )
48+
49+ def setting_yolox_exp (self ) -> None :
50+ # set environment variables for distributed training
51+
52+ # ==============================================================
53+
54+ ONNX_PATH = './install/yolox_ros_py/share/yolox_ros_py/yolox_nano.onnx'
55+
56+ self .declare_parameter ('imshow_isshow' ,True )
57+
58+ self .declare_parameter ('model_path' , ONNX_PATH )
59+ self .declare_parameter ('conf' , 0.3 )
60+ self .declare_parameter ('with_p6' , False )
61+ self .declare_parameter ('input_shape/width' , 416 )
62+ self .declare_parameter ('input_shape/height' , 416 )
63+
64+ self .declare_parameter ('image_size/width' , 640 )
65+ self .declare_parameter ('image_size/height' , 480 )
66+
67+ # =============================================================
68+ self .imshow_isshow = self .get_parameter ('imshow_isshow' ).value
69+
70+ self .model_path = self .get_parameter ('model_path' ).value
71+ self .conf = self .get_parameter ('conf' ).value
72+
73+ self .input_width = self .get_parameter ('image_size/width' ).value
74+ self .input_height = self .get_parameter ('image_size/height' ).value
75+ self .input_shape_w = self .get_parameter ('input_shape/width' ).value
76+ self .input_shape_h = self .get_parameter ('input_shape/height' ).value
77+
78+ # ==============================================================
79+ self .with_p6 = self .get_parameter ('with_p6' ).value
80+
81+ self .get_logger ().info ('model_path: {}' .format (self .model_path ))
82+ self .get_logger ().info ('conf: {}' .format (self .conf ))
83+ self .get_logger ().info ('input_shape: {}' .format ((self .input_shape_w , self .input_shape_h )))
84+ self .get_logger ().info ('image_size: {}' .format ((self .input_width , self .input_height )))
85+
86+
87+ self .input_shape = (self .input_shape_h , self .input_shape_w )
88+
89+
90+ def yolox2bboxes_msgs (self , bboxes , scores , cls , cls_names , img_header :Header ):
91+ bboxes_msg = BoundingBoxes ()
92+ bboxes_msg .header = img_header
93+ i = 0
94+ for bbox in bboxes :
95+ one_box = BoundingBox ()
96+ one_box .xmin = int (bbox [0 ])
97+ one_box .ymin = int (bbox [1 ])
98+ one_box .xmax = int (bbox [2 ])
99+ one_box .ymax = int (bbox [3 ])
100+ one_box .probability = float (scores [i ])
101+ one_box .class_id = str (cls_names [int (cls [i ])])
102+ bboxes_msg .bounding_boxes .append (one_box )
103+ i = i + 1
104+
105+ return bboxes_msg
106+
107+ def imageflow_callback (self ,msg :Image ) -> None :
108+ try :
109+ # fps start
110+ start_time = cv2 .getTickCount ()
111+ bboxes = BoundingBoxes ()
112+ origin_img = self .bridge .imgmsg_to_cv2 (msg ,"bgr8" )
113+ # resize
114+ img = cv2 .resize (origin_img , (self .input_width , self .input_height ))
115+
116+ # preprocess
117+ img , self .ratio = preprocess (origin_img , self .input_shape )
118+
119+ session = onnxruntime .InferenceSession (self .model_path )
120+
121+ ort_inputs = {session .get_inputs ()[0 ].name : img [None , :, :, :]}
122+ output = session .run (None , ort_inputs )
123+
124+ predictions = demo_postprocess (output [0 ], self .input_shape , p6 = self .with_p6 )[0 ]
125+
126+ boxes = predictions [:, :4 ]
127+ scores = predictions [:, 4 :5 ] * predictions [:, 5 :]
128+
129+ boxes_xyxy = np .ones_like (boxes )
130+ boxes_xyxy [:, 0 ] = boxes [:, 0 ] - boxes [:, 2 ]/ 2.
131+ boxes_xyxy [:, 1 ] = boxes [:, 1 ] - boxes [:, 3 ]/ 2.
132+ boxes_xyxy [:, 2 ] = boxes [:, 0 ] + boxes [:, 2 ]/ 2.
133+ boxes_xyxy [:, 3 ] = boxes [:, 1 ] + boxes [:, 3 ]/ 2.
134+ boxes_xyxy /= self .ratio
135+ dets = multiclass_nms (boxes_xyxy , scores , nms_thr = 0.45 , score_thr = self .conf )
136+ if dets is not None :
137+ self .final_boxes , self .final_scores , self .final_cls_inds = dets [:, :4 ], dets [:, 4 ], dets [:, 5 ]
138+ origin_img = vis (origin_img , self .final_boxes , self .final_scores , self .final_cls_inds ,
139+ conf = self .conf , class_names = COCO_CLASSES )
140+
141+ end_time = cv2 .getTickCount ()
142+ time_took = (end_time - start_time ) / cv2 .getTickFrequency ()
143+
144+ # rclpy log FPS
145+ self .get_logger ().info (f'FPS: { 1 / time_took } ' )
146+
147+ try :
148+ bboxes = self .yolox2bboxes_msgs (dets [:, :4 ], self .final_scores , self .final_cls_inds , COCO_CLASSES , msg .header )
149+ # self.get_logger().info(f'bboxes: {bboxes}')
150+ if (self .imshow_isshow ):
151+ cv2 .imshow ("YOLOX" ,origin_img )
152+ cv2 .waitKey (1 )
153+
154+ except :
155+ # self.get_logger().info('No object detected')
156+ if (self .imshow_isshow ):
157+ cv2 .imshow ("YOLOX" ,origin_img )
158+ cv2 .waitKey (1 )
159+
160+ self .pub .publish (bboxes )
161+ self .pub_image .publish (self .bridge .cv2_to_imgmsg (origin_img ,"bgr8" ))
162+
163+ except Exception as e :
164+ self .get_logger ().info (f'Error: { e } ' )
165+ pass
166+
167+
168+ def ros_main (args = None ):
169+ rclpy .init (args = args )
170+ ros_class = yolox_ros ()
171+
172+ try :
173+ rclpy .spin (ros_class )
174+ except KeyboardInterrupt :
175+ pass
176+ finally :
177+ ros_class .destroy_node ()
178+ rclpy .shutdown ()
179+
180+ if __name__ == "__main__" :
181+ ros_main ()
0 commit comments