Skip to content

Commit 6655863

Browse files
JeansBoussier and ladaapp
authored and committed
detection dataset: Allow more than one class per sample
1 parent c0e6c15 commit 6655863

File tree

11 files changed

+128
-92
lines changed

11 files changed

+128
-92
lines changed

lada/bpjdet/inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,9 @@ def _post_process_batch(data, imgs, paths, shapes, body_dets, part_dets):
112112

113113
return batch_bboxes, batch_points, batch_scores, batch_imgids, batch_parts_dict, img_indexs
114114

115-
def get_model(device: str, weights_path: str):
115+
def get_model(device: str):
116116
torch_device = torch.device(device)
117+
weights_path = os.path.join(MODEL_WEIGHTS_DIR, '3rd_party', 'ch_head_s_1536_e150_best_mMR.pt')
117118
return attempt_load(weights_path, map_location=torch_device)
118119

119120
def inference(model, image_path, imgz, data, conf_thres=0.45, iou_thres=0.75) -> list[Box]:

lada/lib/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class VideoMetadata:
5050

5151
@dataclass
5252
class Detection:
53-
cls: str
53+
cls: int
5454
box: Box
5555
mask: Mask # Binary segmentation mask. Values can be either 0 (background) or mask_val
5656

@@ -67,7 +67,7 @@ class Detections:
6767
Mask value is a non-zero value used in binary mask (Mask) to indicate if pixel belongs to the class
6868
"""
6969
DETECTION_CLASSES = {
70-
"nsfw": dict(class_id=0, mask_value=255),
71-
"sfw_head": dict(class_id=1, mask_value=127),
72-
"sfw_face": dict(class_id=2, mask_value=192),
70+
"nsfw": dict(cls=0, mask_value=255),
71+
"sfw_head": dict(cls=1, mask_value=127),
72+
"sfw_face": dict(cls=2, mask_value=192),
7373
}

lada/lib/box_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from lada.lib import Box
2+
3+
def box_overlap(box1: Box, box2: Box):
4+
y1min, x1min, y1max, x1max = box1
5+
y2min, x2min, y2max, x2max = box2
6+
return x1min < x2max and x2min < x1max and y1min < y2max and y2min < y1max

lada/lib/mosaic_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Optional
55

66
from lada.lib.ultralytics_utils import convert_yolo_boxes
7-
from lada.lib.scene_utils import box_overlap
7+
from lada.lib.box_utils import box_overlap
88
from lada.lib import Image, Box
99
from ultralytics import YOLO
1010

lada/lib/mosaic_detector.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ultralytics.engine.results import Results
1414
from lada.lib import Box, Mask, Image, VideoMetadata, threading_utils
1515
from lada.lib import image_utils
16+
from lada.lib.box_utils import box_overlap
1617
from lada.lib.mosaic_detection_model import MosaicDetectionModel
1718
from lada.lib.scene_utils import crop_to_box_v3
1819
from lada.lib import video_utils
@@ -67,16 +68,11 @@ def get_masks(self):
6768
def get_boxes(self):
6869
return [box for _, _, box in self.data]
6970

70-
def box_overlaps(self, box1: Box, box2: Box) -> bool:
71-
y_overlaps = (box1[0] <= box2[0] <= box1[2] or box1[0] <= box2[2] <= box1[2]) or (box2[0] <= box1[0] <= box2[2] or box2[0] <= box1[2] <= box2[2])
72-
x_overlaps = (box1[1] <= box2[1] <= box1[3] or box1[1] <= box2[3] <= box1[3]) or (box2[1] <= box1[1] <= box2[3] or box2[1] <= box1[3] <= box2[3])
73-
return y_overlaps and x_overlaps
74-
7571
def belongs(self, box: Box):
7672
if len(self.data) == 0:
7773
return False
7874
last_scene_box = self.data[-1][2]
79-
return self.box_overlaps(last_scene_box, box)
75+
return box_overlap(last_scene_box, box)
8076

8177
def __iter__(self):
8278
return self

lada/lib/nsfw_frame_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ def __init__(self, model: ultralytics.models.YOLO, device=None, random_extend_ma
3939
self.conf = conf
4040

4141
def detect(self, file_path: str) -> Detections | None:
42-
for results in self.model.predict(source=file_path, stream=False, verbose=False, device=self.device, conf=0.4, iou=0.):
42+
for results in self.model.predict(source=file_path, stream=False, verbose=False, device=self.device, conf=self.conf, iou=0.):
4343
return get_nsfw_frames(results, self.random_extend_masks)

lada/lib/nudenet_nsfw_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Optional
55

66
from lada.lib.ultralytics_utils import convert_yolo_boxes
7-
from lada.lib.scene_utils import box_overlap
7+
from lada.lib.box_utils import box_overlap
88
from lada.lib import Image, Box
99
from ultralytics import YOLO
1010

lada/lib/scene_utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,6 @@
55

66
from lada.lib import Box, Mask, Image
77

8-
def box_overlap(box1: Box, box2: Box):
9-
t1, l1, b1, r1 = box1
10-
t2, l2, b2, r2 = box2
11-
t = max(t1, t2)
12-
l = max(l1, l2)
13-
b = min(b1, b2)
14-
r = min(r1, r2)
15-
return r > l and b > t
168

179
def crop_to_box_v3(box: Box, img: Image, mask_img: Mask, target_size: tuple[int, int], max_box_expansion_factor=1.0, border_size=0):
1810
"""

lada/lib/ultralytics_utils.py

Lines changed: 34 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -77,85 +77,66 @@ def choose_biggest_detection(result: ultralytics.engine.results.Results, trackin
7777
mask = yolo_mask
7878
return box, mask
7979

80-
def convert_segment_masks_to_yolo_segmentation_labels(masks_dir, output_dir, pixel_to_class_mapping):
80+
def _get_unique_pixel_values(mask: Mask) -> list[int]:
81+
# get unique values except background (0)
82+
unique_values = np.unique(mask).tolist()
83+
if 0 in unique_values: unique_values.remove(0) # remove background class
84+
return unique_values
85+
86+
def convert_segment_masks_to_yolo_labels(masks_dir, output_dir_segmentation_labels, output_dir_detection_labels, pixel_to_class_mapping):
8187
"""
8288
pixel_to_class_mapping is a dict providing a mapping from pixel value to class id.
8389
e.g. if you only have a single class with id 0 and binary masks use pixel value 255 then this would be:
8490
pixel_to_class_mapping = {255: 0}
8591
86-
source: ultralytics.data.converter.convert_segment_masks_to_yolo_seg
92+
Based of: ultralytics.data.converter.convert_segment_masks_to_yolo_seg
8793
"""
94+
def get_yolo_box(contour) -> tuple[float]:
95+
x, y, w, h = cv2.boundingRect(contour)
96+
h, w = mask.shape[:2]
97+
center_x = x + w / 2
98+
center_y = y + h / 2
99+
yolo_box = center_x / w, center_y / h, w / w, h / h
100+
return yolo_box
101+
88102
for mask_path in Path(masks_dir).iterdir():
89103
if mask_path.suffix in {".png", ".jpg"}:
90104
mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
91105
img_height, img_width = mask.shape
92106

93-
unique_values = np.unique(mask) # Get unique pixel values representing different classes
94-
yolo_format_data = []
107+
unique_values = _get_unique_pixel_values(mask)
108+
yolo_segmentation_format_data = []
109+
yolo_detection_format_data = []
95110

96111
for value in unique_values:
97-
if value == 0:
98-
continue # Skip background
99112
class_index = pixel_to_class_mapping.get(value, -1)
100113
if class_index == -1:
101114
print(f"Unknown class for pixel value {value} in file {mask_path}, skipping.")
102115
continue
103116

104117
# Create a binary mask for the current class and find contours
105-
contours, _ = cv2.findContours(
106-
(mask == value).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
107-
) # Find contours
118+
binary_mask_for_current_class = (mask == value).astype(np.uint8)
119+
contours, _ = cv2.findContours(binary_mask_for_current_class, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
108120

109121
for contour in contours:
110122
if len(contour) >= 3: # YOLO requires at least 3 points for a valid segmentation
111123
contour = contour.squeeze() # Remove single-dimensional entries
112-
yolo_format = [class_index]
124+
yolo_segmentation_format = [class_index]
113125
for point in contour:
114126
# Normalize the coordinates
115-
yolo_format.append(round(point[0] / img_width, 6)) # Rounding to 6 decimal places
116-
yolo_format.append(round(point[1] / img_height, 6))
117-
yolo_format_data.append(yolo_format)
127+
yolo_segmentation_format.append(round(point[0] / img_width, 6)) # Rounding to 6 decimal places
128+
yolo_segmentation_format.append(round(point[1] / img_height, 6))
129+
yolo_segmentation_format_data.append(yolo_segmentation_format)
130+
yolo_detection_format_data.append(get_yolo_box(contour))
131+
118132
# Save Ultralytics YOLO format data to file
119-
output_path = Path(output_dir) / f"{mask_path.stem}.txt"
133+
output_path = Path(output_dir_segmentation_labels) / f"{mask_path.stem}.txt"
120134
with open(output_path, "w", encoding="utf-8") as file:
121-
for item in yolo_format_data:
135+
for item in yolo_segmentation_format_data:
136+
line = " ".join(map(str, item))
137+
file.write(line + "\n")
138+
output_path = Path(output_dir_detection_labels) / f"{mask_path.stem}.txt"
139+
with open(output_path, "w", encoding="utf-8") as file:
140+
for item in yolo_detection_format_data:
122141
line = " ".join(map(str, item))
123142
file.write(line + "\n")
124-
125-
126-
def convert_binary_mask_to_yolo_detection_labels(masks_dir, output_dir, pixel_to_class_mapping):
127-
"""
128-
pixel_to_class_mapping is a dict providing a mapping from pixel value to class id.
129-
e.g. if you only have a single class with id 0 and binary masks use pixel value 255 then this would be:
130-
pixel_to_class_mapping = {255: 0}
131-
132-
"""
133-
134-
def _convert_binary_mask_to_yolo_detection_labels(mask: Mask) -> tuple[float]:
135-
t, l, b, r = mask_utils.get_box(mask)
136-
h, w = mask.shape[:2]
137-
box_width = r - l
138-
box_height = b - t
139-
box_center_x = l + box_width / 2
140-
box_center_y = t + box_height / 2
141-
yolo_box = box_center_x / w, box_center_y / h, box_width / w, box_height / h
142-
return yolo_box
143-
144-
def _get_class_id(mask: Mask) -> int:
145-
unique_values = np.unique(mask).tolist()
146-
if 0 in unique_values: unique_values.remove(0) # remove background class
147-
assert len(unique_values) == 1, f"only single class / binary segmentation mask supported but found these values: {unique_values}"
148-
mask_val = unique_values[0]
149-
150-
class_id = pixel_to_class_mapping.get(mask_val, -1)
151-
assert class_id != -1, f"Unknown class for pixel value {mask_val} in file {mask_path}"
152-
return class_id
153-
154-
for mask_path in Path(masks_dir).iterdir():
155-
if mask_path.suffix in {".png", ".jpg"}:
156-
mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
157-
class_id = _get_class_id(mask)
158-
yolo_box = _convert_binary_mask_to_yolo_detection_labels(mask)
159-
label_file_path = Path(output_dir).joinpath(Path(mask_path).with_suffix('.txt').name)
160-
with open(label_file_path, 'a') as file:
161-
file.write(f"{class_id} {yolo_box[0]} {yolo_box[1]} {yolo_box[2]} {yolo_box[3]}")

lada/lib/watermark_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Optional
55

66
from lada.lib.ultralytics_utils import convert_yolo_boxes
7-
from lada.lib.scene_utils import box_overlap
7+
from lada.lib.box_utils import box_overlap
88
from lada.lib import Image, Box
99
from ultralytics import YOLO
1010

0 commit comments

Comments (0)