forked from tsuyoikaze/powerful-gnns
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path split_validation.py
More file actions
30 lines (25 loc) · 887 Bytes
/
split_validation.py
File metadata and controls
30 lines (25 loc) · 887 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import json
import os
import shutil
import sys
from subprocess import call

import numpy as np
# usage: python split_validation.py <path of training set> <path of validation set> <ratio of data removed from training> <patient_to_labels.json>
#
# Moves a per-class random sample of patient entries from the training
# directory into a newly created validation directory.

USAGE = ('usage: python split_validation.py <path of training set> '
         '<path of validation set> <ratio of data removed from training> '
         '<patient_to_labels.json>')


def group_by_label(patients, patient_to_labels):
    """Group patient names by their label.

    Args:
        patients: iterable of patient names (entries of the training dir).
        patient_to_labels: mapping from patient name to its label.

    Returns:
        dict mapping each label to the list of its patients, in input order.

    Raises:
        KeyError: if a patient has no entry in ``patient_to_labels``.
    """
    groups = {}
    for patient in patients:
        groups.setdefault(patient_to_labels[patient], []).append(patient)
    return groups


def validation_size(class_size, ratio):
    """Return how many of ``class_size`` patients go to validation.

    The nominal count ``int(class_size * ratio)`` is clamped to
    ``[1, class_size]``: every class contributes at least one patient,
    and never more patients than it actually has.
    """
    return min(class_size, max(1, int(class_size * ratio)))


def select_validation(groups, ratio, rng=np.random):
    """Pick validation patients independently for every label.

    Args:
        groups: dict label -> list of patient names (see ``group_by_label``).
        ratio: fraction of each class to move to validation.
        rng: object exposing ``choice(a, size, replace)``; defaults to the
            global ``numpy.random`` state (pass a seeded RNG for
            reproducibility).

    Returns:
        list of selected patient names; each name appears at most once.
    """
    selected = []
    for label, members in groups.items():
        expected = len(members) * ratio
        if int(expected) < 1:
            # %s rather than %d: labels loaded from JSON may be strings.
            print('In class %s, expected number of patient in validation is %f'
                  % (label, expected))
        num = validation_size(len(members), ratio)
        # replace=False: a patient must never be drawn (and moved) twice.
        for idx in rng.choice(len(members), size=num, replace=False):
            print(members[idx])
            selected.append(members[idx])
    return selected


def main(argv=None):
    """Split a validation set off the training directory.

    Args:
        argv: ``[train_dir, val_dir, ratio, labels_json]``; defaults to
            ``sys.argv[1:]``.

    Raises:
        SystemExit: on a wrong number of arguments.
        FileExistsError: if the validation directory already exists.
        KeyError: if a training entry is missing from the labels file.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) != 4:
        sys.exit(USAGE)
    train_dir, val_dir, ratio_arg, labels_path = argv
    ratio = float(ratio_arg)  # parse before touching the filesystem

    # List first so a validation dir created inside train_dir is not sampled.
    patients = os.listdir(train_dir)
    os.mkdir(val_dir)
    with open(labels_path) as f:
        patient_to_labels = json.load(f)

    groups = group_by_label(patients, patient_to_labels)
    for patient in select_validation(groups, ratio):
        # shutil.move instead of a `mv` shell string: safe for names with
        # spaces or shell metacharacters.
        shutil.move(os.path.join(train_dir, patient),
                    os.path.join(val_dir, patient))


if __name__ == '__main__':
    main()