Skip to content

Commit 6064b38

Browse files
add segmentation benchmark
1 parent 7d86141 commit 6064b38

File tree

5 files changed

+121
-2
lines changed

5 files changed

+121
-2
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Benchmark:
2+
name: "Image Segmentation Benchmark"
3+
type: "Segmentation"
4+
data:
5+
path: "data/image_segmentation"
6+
files: ["messi5.jpg", "100040721_1.jpg"]
7+
sizes: # [[w1, h1], ...], Omit to run at original scale
8+
- [640, 640]
9+
metric:
10+
warmup: 30
11+
repeat: 10
12+
backend: "default"
13+
target: "cpu"
14+
15+
Model:
16+
name: "EfficientSAM"

benchmark/utils/dataloaders/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from .classification import ClassificationImageLoader
33
from .recognition import RecognitionImageLoader
44
from .tracking import TrackingVideoLoader
5+
from .segmentation import SegmentationImageLoader
56

6-
__all__ = ['BaseImageLoader', 'BaseVideoLoader', 'ClassificationImageLoader', 'RecognitionImageLoader', 'TrackingVideoLoader']
7+
__all__ = ['BaseImageLoader', 'BaseVideoLoader', 'ClassificationImageLoader', 'RecognitionImageLoader', 'SegmentationImageLoader', 'TrackingVideoLoader']
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import os
2+
3+
import numpy as np
4+
import cv2 as cv
5+
6+
from .base_dataloader import _BaseImageLoader
7+
from ..factory import DATALOADERS
8+
9+
@DATALOADERS.register
10+
class SegmentationImageLoader(_BaseImageLoader):
11+
def __init__(self, **kwargs):
12+
super().__init__(**kwargs)
13+
14+
self._to_rgb = kwargs.pop('toRGB', False)
15+
self._point_label= self._load_point_and_label()
16+
17+
def _load_point_and_label(self):
18+
points_labels = dict.fromkeys(self._files, None)
19+
for filename in self._files:
20+
if os.path.exists(os.path.join(self._path, '{}.txt'.format(filename[:-4]))):
21+
points_labels[filename] = np.loadtxt(os.path.join(self._path, '{}.txt'.format(filename[:-4])), ndmin=2)
22+
else:
23+
points_labels[filename] = None
24+
# for filename in self._files:
25+
# label_file = os.path.join(self._path, '{}.txt'.format(filename[:-4]))
26+
# if os.path.exists(label_file):
27+
# # 假设标签文件的每一行格式为:x y label
28+
# # 其中 x, y 是点的坐标,label 是标签(0 或 1)
29+
# with open(label_file, 'r') as file:
30+
# lines = file.readlines()
31+
# current_point_label = []
32+
# for line in lines:
33+
# parts = line.strip().split()
34+
# if len(parts) == 3:
35+
# x, y, label = map(int, parts)
36+
# current_point_label.append((x, y, label))
37+
# points_labels[filename] = current_point_label
38+
# else:
39+
# points_labels[filename] = None
40+
return points_labels
41+
42+
43+
def _toRGB(self, image):
44+
return cv.cvtColor(image, cv.COLOR_BGR2RGB)
45+
46+
def __iter__(self):
47+
for filename in self._files:
48+
image = cv.imread(os.path.join(self._path, filename))
49+
50+
if self._to_rgb:
51+
image = self._toRGB(image)
52+
53+
if [0, 0] in self._sizes:
54+
point_and_label = self._point_label.get(filename)
55+
if point_and_label is not None:
56+
yield filename, image, point_and_label
57+
else:
58+
yield filename, image, None
59+
else:
60+
for size in self._sizes:
61+
image_r = cv.resize(image, size)
62+
point_and_label = self._point_label.get(filename)
63+
if point_and_label is not None:
64+
yield filename, image_r, point_and_label
65+
else:
66+
yield filename, image_r, None

benchmark/utils/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from .detection import Detection
33
from .recognition import Recognition
44
from .tracking import Tracking
5+
from .segmentation import Segmentation
56

6-
__all__ = ['Base', 'Detection', 'Recognition', 'Tracking']
7+
__all__ = ['Base', 'Detection', 'Recognition', 'Segmentation', 'Tracking']
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import cv2 as cv
2+
3+
from .base_metric import BaseMetric
4+
from ..factory import METRICS
5+
6+
@METRICS.register
7+
class Segmentation(BaseMetric):
8+
def __init__(self, **kwargs):
9+
super().__init__(**kwargs)
10+
11+
def forward(self, model, *args, **kwargs):
12+
img, point_and_label = args
13+
size = [img.shape[1], img.shape[0]]
14+
self._timer.reset()
15+
if point_and_label is not None:
16+
for idx, pl in enumerate(point_and_label):
17+
point = [[pl[0], pl[1]]]
18+
label = [[pl[2]]]
19+
for _ in range(self._warmup):
20+
model.infer(img, point, label)
21+
for _ in range(self._repeat):
22+
self._timer.start()
23+
model.infer(img, point, label)
24+
self._timer.stop()
25+
else:
26+
point = [[int(size[0]/2), int(size[1]/2)]]
27+
label = [[1]]
28+
for _ in range(self._warmup):
29+
model.infer(img, point, label)
30+
for _ in range(self._repeat):
31+
self._timer.start()
32+
model.infer(img, point, label)
33+
self._timer.stop()
34+
35+
return self._timer.getRecords()

0 commit comments

Comments
 (0)