Skip to content

Commit 8587c1f

Browse files
committed
add script of check dataset information
1 parent 2717c9b commit 8587c1f

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ frankfurt_000001_079206_leftImg8bit.png,frankfurt_000001_079206_gtFine_labelIds.
109109
...
110110
```
111111
Each line is a pair of training sample and ground truth image path, which are separated by a single comma `,`.
112+
I recommend checking the information of your dataset with this script:
113+
```
114+
$ python tools/check_dataset_info.py --im_root /path/to/your/data_root --im_anns /path/to/your/anno_file
115+
```
112116
Then you need to change the `im_root` and `train/val_im_anns` fields in the config file. I prepared a demo config file for you named [`bisenet_customer.py`](./configs/bisenet_customer.py). You can start from this config file.
113117

114118

tools/check_dataset_info.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
2+
import os
3+
import os.path as osp
4+
import argparse
5+
from tqdm import tqdm
6+
7+
import cv2
8+
import numpy as np
9+
10+
11+
# Command-line interface: the dataset root directory and the annotation list
# file to inspect. The attribute names on `args` are derived by argparse from
# the option strings (`--im_root` -> `args.im_root`).
parse = argparse.ArgumentParser()
parse.add_argument(
    '--im_root',
    type=str,
    default='./datasets/cityscapes',
)
parse.add_argument(
    '--im_anns',
    type=str,
    default='./datasets/cityscapes/train.txt',
)
args = parse.parse_args()
15+
16+
17+
## annotation file: one "image_path,label_path" pair per line
with open(args.im_anns, 'r') as fr:
    # Drop blank lines (e.g. a trailing empty line) so they neither crash the
    # unpacking below nor inflate the reported pair count.
    lines = [ln for ln in fr.read().splitlines() if ln.strip()]

n_pairs = len(lines)
impaths, lbpaths = [], []
for ln in lines:
    # Exactly one comma is expected per line; any other count raises
    # ValueError from the tuple unpacking.
    impth, lbpth = ln.split(',')
    # Both paths in the annotation file are relative to the dataset root.
    impaths.append(osp.join(args.im_root, impth))
    lbpaths.append(osp.join(args.im_root, lbpth))
28+
29+
30+
## shapes
# One pass over every image/label pair that records the (height, width) of the
# extreme samples by area, by height and by width, plus the overall range of
# label values. The initial values are sentinels that the first sample
# overrides.
max_shape_area, min_shape_area = [0, 0], [100000, 100000]
max_shape_height, min_shape_height = [0, 0], [100000, 100000]
max_shape_width, min_shape_width = [0, 0], [100000, 100000]
max_lb_val, min_lb_val = -1, 10000000
for impth, lbpth in tqdm(zip(impaths, lbpaths), total=n_pairs):
    im = cv2.imread(impth)
    lb = cv2.imread(lbpth, 0)  # flag 0: load as single-channel grayscale
    # cv2.imread returns None instead of raising when a file is missing or
    # cannot be decoded, so fail here with the offending path.
    if im is None:
        raise FileNotFoundError(f'cannot read image: {impth}')
    if lb is None:
        raise FileNotFoundError(f'cannot read label image: {lbpth}')
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if im.shape[:2] != lb.shape:
        raise ValueError(
            f'image/label shape mismatch: {impth} {im.shape[:2]} '
            f'vs {lbpth} {lb.shape}')

    shape = lb.shape
    area = shape[0] * shape[1]
    if area > max_shape_area[0] * max_shape_area[1]:
        max_shape_area = shape
    if area < min_shape_area[0] * min_shape_area[1]:
        min_shape_area = shape

    if shape[0] > max_shape_height[0]:
        max_shape_height = shape
    if shape[0] < min_shape_height[0]:
        min_shape_height = shape

    if shape[1] > max_shape_width[1]:
        max_shape_width = shape
    if shape[1] < min_shape_width[1]:
        min_shape_width = shape

    # Convert to plain Python ints: numpy uint8 scalars can wrap around in
    # later arithmetic such as `max_lb_val + 1`.
    max_lb_val = max(max_lb_val, int(lb.max()))
    min_lb_val = min(min_lb_val, int(lb.min()))
59+
60+
61+
## label info
# Histogram of label values over the whole dataset. Bin `i` counts pixels whose
# label value is `i + min_lb_val`, so the histogram spans exactly
# [min_lb_val, max_lb_val].
# Plain ints avoid uint8 wrap-around when computing the number of bins.
lb_minlength = int(max_lb_val) + 1 - int(min_lb_val)
lb_hist = np.zeros(lb_minlength)
# Iterate over the label paths themselves; reusing a leftover `lbpth` from an
# earlier loop would count the same label image once per sample.
for lbpth in tqdm(lbpaths):
    lb = cv2.imread(lbpth, 0)
    if lb is None:
        raise FileNotFoundError(f'cannot read label image: {lbpth}')
    # Shift values so the smallest label maps to bin 0. Widen to int64 first so
    # the subtraction cannot wrap in the image's 8-bit dtype.
    lb = lb.ravel().astype(np.int64) - int(min_lb_val)
    lb_hist += np.bincount(lb, minlength=lb_minlength)

# Label values inside the observed range that never occur in any label image.
lb_missing_vals = [ind + int(min_lb_val)
        for ind, el in enumerate(lb_hist.tolist()) if el == 0]
# Fraction of all labelled pixels that carry each value.
lb_ratios = (lb_hist / lb_hist.sum()).tolist()
71+
72+
73+
## pixel mean/std
# Per-channel (RGB order) mean and population standard deviation over every
# pixel of every image. The sum and the sum of squares are accumulated in
# float64 in a single pass: float32 accumulators lose precision once the
# running totals exceed 2**24, and a single pass decodes each image only once.
rgb_sum = np.zeros(3, dtype=np.float64)
rgb_sqsum = np.zeros(3, dtype=np.float64)
n_pixels = 0
for impth in tqdm(impaths):
    im = cv2.imread(impth)
    if im is None:
        raise FileNotFoundError(f'cannot read image: {impth}')
    # Reverse the channel axis (cv2 loads BGR) and flatten to (num_pixels, 3).
    im = im[:, :, ::-1].reshape(-1, 3).astype(np.float64)
    n_pixels += im.shape[0]
    rgb_sum += im.sum(axis=0)
    rgb_sqsum += (im ** 2).sum(axis=0)
rgb_mean = rgb_sum / n_pixels

# var = E[x^2] - E[x]^2; clamp at zero to absorb tiny negative rounding error
# before taking the square root.
rgb_std = np.sqrt(np.maximum(rgb_sqsum / n_pixels - rgb_mean ** 2, 0.))
91+
92+
93+
## report
# Print the collected statistics: pair count, extreme image shapes, label value
# range / missing values / ratios, and the per-channel pixel mean and std.
print(f'there are {n_pairs} lines in {args.im_anns}, which means {n_pairs} image/label image pairs')
print('\n')

print('max and min image shapes by area are: ')
print(f'\t{max_shape_area}, {min_shape_area}')
print('max and min image shapes by height are: ')
print(f'\t{max_shape_height}, {min_shape_height}')
print('max and min image shapes by width are: ')
print(f'\t{max_shape_width}, {min_shape_width}')
print('\n')

print(f'label values are within range of ({min_lb_val}, {max_lb_val})')
print('label values that are missing: ')
print('\t', lb_missing_vals)
print('ratios of each label value: ')
print('\t', lb_ratios)
print('\n')

# The statistics are stored in `rgb_mean` / `rgb_std`; the names `mean` and
# `std` are never defined in this script.
print('pixel mean rgb: ', rgb_mean)
print('pixel std rgb: ', rgb_std)

0 commit comments

Comments
 (0)