1+ import os
2+ import cv2
3+ import numpy as np
4+ import onnxruntime as ort
5+
6+
class YOLODetection:
    """ONNX Runtime wrapper around a YOLO-style detector.

    The pipeline is: letterbox preprocessing -> ONNX inference -> decoding of
    the raw output into boxes -> per-class non-maximum suppression -> drawing.

    Detections are represented throughout as rows of
    ``[x1, y1, x2, y2, class_id, score]`` in original-image pixel coordinates.
    """

    def __init__(self, model_path, labels_path):
        """Load the class labels and create the inference session.

        :param model_path: path to the ONNX model file.
        :param labels_path: path to a text file with one class name per line.
        """
        # Directory two levels above this file.
        self.path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_path = model_path
        self.labels_path = labels_path
        self.class_conf = 0.3   # minimum class score for a detection to be kept
        self.nms_thresh = 0.45  # IoU above which a lower-scored box is suppressed
        # Use a context manager so the file handle is closed deterministically.
        with open(self.labels_path, 'r') as labels_file:
            self.labels = [line.strip() for line in labels_file]
        self.infer_session = self.init_infer_session()
        self.warm_up_times = 1
        model_input = self.infer_session.get_inputs()[0]
        self.input_name = model_input.name
        self.output_name = self.infer_session.get_outputs()[0].name
        # Entries 2 and 3 of the declared input shape; the rest of the class
        # uses them as (height, width).
        # NOTE(review): assumes the model declares static integer dimensions.
        self.input_size = model_input.shape[2:4]

    def init_infer_session(self):
        """Create the ONNX Runtime session for ``self.model_path``.

        :return: an ``ort.InferenceSession`` limited to 4 intra-op threads.
        """
        session_options = ort.SessionOptions()
        session_options.intra_op_num_threads = 4

        # Load the ONNX model.
        return ort.InferenceSession(self.model_path, sess_options=session_options)

    def warm_up(self):
        """Run ``self.warm_up_times`` inferences on random input of model size."""
        height, width = self.input_size[0], self.input_size[1]
        warm_up_img = np.random.rand(1, 3, height, width).astype(np.float32)

        for _ in range(self.warm_up_times):
            self.infer_session.run([self.output_name], {self.input_name: warm_up_img})

    def infer(self, image):
        """Detect objects in ``image``.

        :param image: image array; ``preprocess`` converts it from BGR to RGB.
        :return: ``(center_x, center_y, class_ids, result_img)`` where the
            centre belongs to the last detection returned by NMS, ``class_ids``
            lists the class index of every detection and ``result_img`` is an
            annotated copy of ``image``; ``None`` when ``convert_rect_list``
            yields no centre (e.g. nothing was detected).
        """
        # Preprocessing (does not modify ``image``).
        input_tensor = self.preprocess(image, self.input_size)
        # Inference.
        outputs = self.infer_session.run([self.output_name], {self.input_name: input_tensor})
        output = outputs[0]
        offset = output.shape[1]   # rows per anchor: 4 box values + class scores
        anchors = output.shape[2]  # number of candidate boxes

        # Post-processing.
        dets = self.postprocess(image, output, anchors, offset, self.class_conf, self.input_size)
        dets = self.nms(dets)

        center = self.convert_rect_list(dets)
        if center is None:
            return None
        center_x, center_y = center

        # Draw the detections on a copy of the input image.
        result_img = self.draw_result(image, dets, self.labels)
        # Class index of every detected object.
        class_ids = [int(det[4]) for det in dets]
        return center_x, center_y, class_ids, result_img

    @staticmethod
    def _letterbox_params(image_shape, input_size):
        """Compute the letterbox geometry shared by pre- and post-processing.

        :param image_shape: ``(height, width)`` of the source image.
        :param input_size: ``(height, width)`` of the model input.
        :return: ``(r, new_unpad, dw, dh)``: the scale ratio, the resized
            ``(width, height)`` before padding, and the per-side horizontal and
            vertical padding (possibly fractional).
        """
        r = min(input_size[0] / image_shape[0], input_size[1] / image_shape[1])
        new_unpad = (int(round(image_shape[1] * r)), int(round(image_shape[0] * r)))
        dw = (input_size[1] - new_unpad[0]) / 2
        dh = (input_size[0] - new_unpad[1]) / 2
        return r, new_unpad, dw, dh

    def preprocess(self, image, input_size=(320, 320)):
        """Letterbox ``image`` to ``input_size`` and convert it to a model tensor.

        :param image: 3-channel image array in BGR order.
        :param input_size: ``(height, width)`` of the model input.
        :return: float32 array of shape ``(1, 3, height, width)``, RGB order,
            scaled to ``[0, 1]``.
        """
        shape = image.shape[:2]
        pad_color = (0, 0, 0)
        _, new_unpad, dw, dh = self._letterbox_params(shape, input_size)

        if shape[::-1] != new_unpad:  # resize only when the size actually changes
            image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # The +/-0.1 nudges split an odd amount of padding between the two
        # sides so that they always sum to the full padding.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        image = cv2.copyMakeBorder(image, top, bottom, left, right,
                                   cv2.BORDER_CONSTANT, value=pad_color)

        # Normalise to [0, 1].
        image = image.astype(np.float32) / 255.0
        # Reorder to [batch, channel, height, width].
        image = np.transpose(image, (2, 0, 1))
        return np.expand_dims(image, axis=0)

    def postprocess(self, image, output, anchors, offset, conf_threshold,
                    input_size=(320, 320), exclude_class_ids=(0,)):
        """Decode raw model output into boxes in original-image coordinates.

        :param image: the original image; only its height and width are used.
        :param output: raw model output holding ``offset * anchors`` values laid
            out as ``(offset, anchors)``, optionally with a leading batch axis
            of size 1. Rows 0-3 are centre x, centre y, width and height; rows
            ``4:offset`` are per-class scores.
        :param anchors: number of candidate boxes.
        :param offset: number of rows per candidate (4 + number of classes).
        :param conf_threshold: candidates whose best class score is not above
            this value are dropped.
        :param input_size: ``(height, width)`` of the model input.
        :param exclude_class_ids: class indices to discard. Defaults to ``(0,)``,
            which drops class 0 ('person' in COCO) so that hands are not
            misdetected.
        :return: list of ``[x1, y1, x2, y2, class_id, score]`` rows (floats).
        """
        _, _, dw, dh = self._letterbox_params(image.shape[:2], input_size)
        r = min(input_size[0] / image.shape[0], input_size[1] / image.shape[1])

        # Drop the batch axis explicitly; an axis-less squeeze() would also
        # collapse the anchor axis when the model yields a single anchor.
        output = np.asarray(output).reshape(offset, anchors)

        # Box parameters (centre and size) of every anchor.
        center_x, center_y, box_width, box_height = output[:4]
        # Class scores of every anchor.
        class_probs = output[4:offset]

        # Best class and its score for each anchor.
        best_ids = np.argmax(class_probs, axis=0)
        best_probs = class_probs[best_ids, np.arange(anchors)]

        # Keep confident anchors whose class is not excluded.
        keep = best_probs > conf_threshold
        if exclude_class_ids:
            keep &= ~np.isin(best_ids, list(exclude_class_ids))

        half_width = box_width[keep] / 2
        half_height = box_height[keep] / 2
        kept_x = center_x[keep]
        kept_y = center_y[keep]

        # Undo the letterbox (remove padding, then rescale) and clamp at 0.
        x1 = np.maximum(0, (kept_x - half_width - dw) / r).astype(int)
        x2 = np.maximum(0, (kept_x + half_width - dw) / r).astype(int)
        y1 = np.maximum(0, (kept_y - half_height - dh) / r).astype(int)
        y2 = np.maximum(0, (kept_y + half_height - dh) / r).astype(int)

        return np.column_stack((x1, y1, x2, y2, best_ids[keep], best_probs[keep])).tolist()

    def nms(self, dets):
        """Apply per-class non-maximum suppression.

        :param dets: sequence of ``[x1, y1, x2, y2, class_id, score]`` rows.
        :return: list of kept rows (NumPy arrays), grouped by ascending class id
            and sorted by descending score within each class; an empty
            ``(0, 6)`` array when ``dets`` is empty.
        """
        if len(dets) == 0:
            return np.empty((0, 6))

        dets_array = np.array(dets)
        final_dets = []

        # Process each class independently.
        for label in np.unique(dets_array[:, 4]):
            candidates = dets_array[dets_array[:, 4] == label]
            # Highest score first.
            candidates = candidates[np.argsort(-candidates[:, 5])]

            while candidates.shape[0] > 0:
                # Keep the best remaining box ...
                best = candidates[0]
                final_dets.append(best)
                rest = candidates[1:]
                if rest.shape[0] == 0:
                    break
                # ... and drop every other box that overlaps it too much.
                ious = self.calculate_iou(best, rest)
                candidates = rest[ious < self.nms_thresh]

        return final_dets

    def calculate_iou(self, box, boxes):
        """Compute the IoU of one box against a set of boxes.

        :param box: single box; the first four values are ``x1, y1, x2, y2``.
        :param boxes: array of shape ``[N, >=4]`` with the same column layout.
        :return: array of ``N`` IoU values; 0 where the union area is zero.
        """
        # Intersection rectangle.
        x1 = np.maximum(box[0], boxes[:, 0])
        y1 = np.maximum(box[1], boxes[:, 1])
        x2 = np.minimum(box[2], boxes[:, 2])
        y2 = np.minimum(box[3], boxes[:, 3])
        inter_area = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)

        # Union area.
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        union_area = box_area + boxes_area - inter_area

        # Degenerate (zero-area) boxes have an empty union; report IoU 0 there
        # instead of dividing by zero and propagating NaN.
        iou = np.zeros(np.shape(union_area), dtype=float)
        np.divide(inter_area, union_area, out=iou, where=union_area > 0)
        return iou

    # Visualisation
    def draw_result(self, image, dets, class_names, color=(0, 255, 0), thickness=2):
        """Draw boxes, centre points and labels on a copy of ``image``.

        :param image: image array to annotate (not modified).
        :param dets: iterable of ``[x1, y1, x2, y2, class_id, score]`` rows.
        :param class_names: sequence mapping class index to display name.
        :param color: colour of the box and the label text.
        :param thickness: line thickness of the box and the label text.
        :return: the annotated copy.
        """
        image = image.copy()

        for x1, y1, x2, y2, label, score in dets:
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            center = ((x1 + x2) // 2, (y1 + y2) // 2)

            # Fall back to the numeric id when the label list has no such entry.
            class_id = int(label)
            if 0 <= class_id < len(class_names):
                name = class_names[class_id]
            else:
                name = str(class_id)

            # Bounding box.
            cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
            # Centre point (filled red dot).
            cv2.circle(image, center, 5, (0, 0, 255), -1)
            cv2.putText(image, f'{name}: {score:.2f}', (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, thickness)

        return image

    def convert_rect_list(self, original_list):
        """Return the centre of the last detection in ``original_list``.

        :param original_list: sequence of ``[x1, y1, x2, y2, class_id, score]``
            rows.
        :return: ``(center_x, center_y)`` of the last row, or ``None`` when the
            sequence is empty or the computed centre is at the origin
            (``center_x + center_y`` is not positive).
        """
        if len(original_list) == 0:
            return None

        x1, y1, x2, y2, _label, _prob = original_list[-1]
        center_x = x1 + (x2 - x1) // 2
        center_y = y1 + (y2 - y1) // 2
        if center_x + center_y > 0:
            return center_x, center_y
        return None
0 commit comments