1+ import os
2+ import json
3+ import numpy as np
4+ from sklearn .model_selection import StratifiedGroupKFold
5+
6+ def main ():
7+ annotation = '/data/ephemeral/home/dataset/train.json'
8+
9+ with open (annotation ) as f :
10+ data = json .load (f )
11+
12+ var = [(ann ['image_id' ], ann ['category_id' ]) for ann in data ['annotations' ]]
13+ X = np .ones ((len (data ['annotations' ]), 1 ))
14+ y = np .array ([v [1 ] for v in var ])
15+ groups = np .array ([v [0 ] for v in var ])
16+
17+ save_dir = '/data/ephemeral/home/dataset/split'
18+ if not os .path .isdir (save_dir ):
19+ os .mkdir (save_dir )
20+
21+ SEED = 42
22+ sgkf = StratifiedGroupKFold (n_splits = 5 , shuffle = True , random_state = SEED )
23+
24+ for fold , (train_idx , val_idx ) in enumerate (sgkf .split (X , y , groups ), start = 1 ):
25+
26+ train_img_ids = set ([data ['annotations' ][idx ]['image_id' ] for idx in train_idx ])
27+ train_imgs = [data ['images' ][idx ] for idx in train_img_ids ]
28+
29+ train_data = {
30+ 'images' : train_imgs ,
31+ 'categories' : data ['categories' ],
32+ 'annotations' : [data ['annotations' ][idx ] for idx in train_idx ]
33+ }
34+
35+ val_img_ids = set ([data ['annotations' ][idx ]['image_id' ] for idx in val_idx ])
36+ val_imgs = [data ['images' ][idx ] for idx in val_img_ids ]
37+
38+ val_data = {
39+ 'images' : val_imgs ,
40+ 'categories' : data ['categories' ],
41+ 'annotations' : [data ['annotations' ][idx ] for idx in val_idx ]
42+ }
43+
44+ train_path = os .path .join (save_dir , f'train_{ SEED } _fold_{ fold } .json' )
45+ with open (train_path , 'w' ) as f :
46+ json .dump (train_data , f , indent = 4 )
47+
48+ val_path = os .path .join (save_dir , f'val_{ SEED } _fold_{ fold } .json' )
49+ with open (val_path , 'w' ) as f :
50+ json .dump (val_data , f , indent = 4 )
51+
52+ if __name__ == '__main__' :
53+ main ()
0 commit comments