Skip to content

Commit 22f0674

Browse files
authored
feat: Split validset with StratifiedGroupKFold (#14)
* feat: Add split_validset.py #12 * feat: Add split_validset_yolo.py #12
1 parent 5c9db08 commit 22f0674

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

utils/split_validset.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import os
2+
import json
3+
import numpy as np
4+
from sklearn.model_selection import StratifiedGroupKFold
5+
6+
def main():
7+
annotation = '/data/ephemeral/home/dataset/train.json'
8+
9+
with open(annotation) as f:
10+
data = json.load(f)
11+
12+
var = [(ann['image_id'], ann['category_id']) for ann in data['annotations']]
13+
X = np.ones((len(data['annotations']), 1))
14+
y = np.array([v[1] for v in var])
15+
groups = np.array([v[0] for v in var])
16+
17+
save_dir = '/data/ephemeral/home/dataset/split'
18+
if not os.path.isdir(save_dir):
19+
os.mkdir(save_dir)
20+
21+
SEED = 42
22+
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=SEED)
23+
24+
for fold, (train_idx, val_idx) in enumerate(sgkf.split(X, y, groups), start=1):
25+
26+
train_img_ids = set([data['annotations'][idx]['image_id'] for idx in train_idx])
27+
train_imgs = [data['images'][idx] for idx in train_img_ids]
28+
29+
train_data = {
30+
'images': train_imgs,
31+
'categories': data['categories'],
32+
'annotations': [data['annotations'][idx] for idx in train_idx]
33+
}
34+
35+
val_img_ids = set([data['annotations'][idx]['image_id'] for idx in val_idx])
36+
val_imgs = [data['images'][idx] for idx in val_img_ids]
37+
38+
val_data = {
39+
'images': val_imgs,
40+
'categories' : data['categories'],
41+
'annotations': [data['annotations'][idx] for idx in val_idx]
42+
}
43+
44+
train_path = os.path.join(save_dir, f'train_{SEED}_fold_{fold}.json')
45+
with open(train_path, 'w') as f:
46+
json.dump(train_data, f, indent=4)
47+
48+
val_path = os.path.join(save_dir, f'val_{SEED}_fold_{fold}.json')
49+
with open(val_path, 'w') as f:
50+
json.dump(val_data, f, indent=4)
51+
52+
if __name__ == '__main__':
53+
main()

utils/split_validset_yolo.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import os
2+
import json
3+
import numpy as np
4+
from sklearn.model_selection import StratifiedGroupKFold
5+
6+
def main():
7+
annotation = '/data/ephemeral/home/dataset/train.json'
8+
9+
with open(annotation) as f:
10+
data = json.load(f)
11+
12+
var = [(ann['image_id'], ann['category_id']) for ann in data['annotations']]
13+
X = np.ones((len(data['annotations']), 1))
14+
y = np.array([v[1] for v in var])
15+
groups = np.array([v[0] for v in var])
16+
17+
img_dir = '/data/ephemeral/home/dataset/yolo/images'
18+
19+
save_dir = '/data/ephemeral/home/dataset/yolo/split'
20+
if not os.path.isdir('/data/ephemeral/home/dataset/yolo'):
21+
os.mkdir('/data/ephemeral/home/dataset/yolo')
22+
if not os.path.isdir(save_dir):
23+
os.mkdir(save_dir)
24+
25+
SEED = 42
26+
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=SEED)
27+
28+
for fold, (train_idx, val_idx) in enumerate(sgkf.split(X, y, groups), start=1):
29+
save_train_path = os.path.join(save_dir, f'train_{SEED}_fold_{fold}.txt')
30+
save_val_path = os.path.join(save_dir, f'val_{SEED}_fold_{fold}.txt')
31+
32+
train_paths, val_paths = '', ''
33+
34+
train_img_ids = set([data['annotations'][idx]['image_id'] for idx in train_idx])
35+
val_img_ids = set([data['annotations'][idx]['image_id'] for idx in val_idx])
36+
37+
for idx in train_img_ids:
38+
train_paths += os.path.join(img_dir, data['images'][idx]['file_name'][6:]) + '\n'
39+
for idx in val_img_ids:
40+
val_paths += os.path.join(img_dir, data['images'][idx]['file_name'][6:]) + '\n'
41+
42+
with open(save_train_path, 'w') as f:
43+
f.write(train_paths)
44+
45+
with open(save_val_path, 'w') as f:
46+
f.write(val_paths)
47+
48+
if __name__ == '__main__':
49+
main()

0 commit comments

Comments
 (0)