@@ -69,11 +69,11 @@ def visualize_results_usual_yolo_inference(
6969 if random_object_colors :
7070 random .seed (int (delta_colors ))
7171
72+ class_names = model .names
73+
7274 # Process each prediction
7375 for pred in predictions :
7476
75- class_names = pred .names
76-
7777 # Get the bounding boxes and convert them to a list of lists
7878 boxes = pred .boxes .xyxy .cpu ().int ().tolist ()
7979
@@ -86,11 +86,11 @@ def visualize_results_usual_yolo_inference(
8686 num_objects = len (classes )
8787
8888 if segment :
89- # Get the masks
89+ # Get the polygons
9090 try :
91- masks = pred .masks .data . cpu (). numpy ()
91+ polygons = pred .masks .xy
9292 except :
93- masks = []
93+ polygons = []
9494
9595 # Visualization
9696 for i in range (num_objects ):
@@ -111,26 +111,17 @@ def visualize_results_usual_yolo_inference(
111111 box = boxes [i ]
112112 x_min , y_min , x_max , y_max = box
113113
114- if segment :
115- mask = masks [i ]
116- # Resize mask to the size of the original image using nearest neighbor interpolation
117- mask_resized = cv2 .resize (
118- np .array (mask ), (img .shape [1 ], img .shape [0 ]), interpolation = cv2 .INTER_NEAREST
119- )
120- # Add label to the mask
121- mask_contours , _ = cv2 .findContours (
122- mask_resized .astype (np .uint8 ), cv2 .RETR_EXTERNAL , cv2 .CHAIN_APPROX_SIMPLE
123- )
124-
125- if fill_mask :
126- if alpha == 1 :
127- cv2 .fillPoly (labeled_image , pts = mask_contours , color = color )
128- else :
129- color_mask = np .zeros_like (img )
130- color_mask [mask_resized > 0 ] = color
131- labeled_image = cv2 .addWeighted (labeled_image , 1 , color_mask , alpha , 0 )
132-
133- cv2 .drawContours (labeled_image , mask_contours , - 1 , color , thickness )
114+ if segment and len (polygons ) > 0 :
115+ if len (polygons [i ]) > 0 :
116+ points = np .array (polygons [i ].reshape ((- 1 , 1 , 2 )), dtype = np .int32 )
117+ if fill_mask :
118+ if alpha == 1 :
119+ cv2 .fillPoly (labeled_image , pts = [points ], color = color )
120+ else :
121+ mask_from_poly = np .zeros_like (img )
122+ color_mask_from_poly = cv2 .fillPoly (mask_from_poly , pts = [points ], color = color )
123+ labeled_image = cv2 .addWeighted (labeled_image , 1 , color_mask_from_poly , alpha , 0 )
124+ cv2 .drawContours (labeled_image , [points ], - 1 , color , thickness )
134125
135126 # Write class label
136127 if show_boxes :
0 commit comments