forked from elijahcole/caltech-ee148-spring2020-hw02
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path generate_split.py
More file actions
69 lines (51 loc) · 1.96 KB
/
generate_split.py
File metadata and controls
69 lines (51 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import os
import json
# Fix the global NumPy seed so the random shuffle later in this script always
# produces the same train/test split.
np.random.seed(2020)

# Input images, ground-truth annotations, and output location for the split.
data_path = './data/RedLights2011_Medium'
gts_path = './data/hw02_annotations'
split_path = './data/hw02_splits'

# Create the annotation and split directories if they do not exist yet.
os.makedirs(gts_path, exist_ok=True)
os.makedirs(split_path, exist_ok=True)

# Set to True (and run) once annotations.json is available in gts_path, so the
# ground-truth annotations are split alongside the file names.
split_test = True
# Fraction of the images assigned to the training split; the rest go to test.
train_frac = 0.85
# Get a sorted list of files so the shuffle below starts from a deterministic
# order (os.listdir returns entries in arbitrary order).
file_names = sorted(os.listdir(data_path))
# Keep only JPEG images. Match on the extension rather than a substring, so
# names such as "foo.jpg.bak" are not mistaken for images.
file_names = [f for f in file_names if f.endswith('.jpg')]

# Shuffle the file names with the seeded global RNG, then assign the first
# train_frac of them to training and the remainder to testing.
shuffled_file_names = np.random.permutation(file_names)
train_num = int(train_frac * len(file_names))
file_names_train = shuffled_file_names[:train_num]
file_names_test = shuffled_file_names[train_num:]

# Sanity checks: together the splits cover every file, and no file is in both.
assert (len(file_names_train) + len(file_names_test)) == len(file_names)
assert len(np.intersect1d(file_names_train, file_names_test)) == 0

np.save(os.path.join(split_path, 'file_names_train.npy'), file_names_train)
np.save(os.path.join(split_path, 'file_names_test.npy'), file_names_test)
if split_test:
    # Load the full ground-truth annotations. The keys are compared against
    # image file names below; the value format is opaque here (presumably a
    # list of boxes per image — TODO confirm against the annotation tool).
    with open(os.path.join(gts_path, 'annotations.json'), 'r',
              encoding='utf-8') as f:
        gts = json.load(f)

    # Use file_names_train and file_names_test to apply the split to the
    # annotations. Build sets once so each membership test is O(1) instead of
    # a linear scan of a NumPy array for every annotated image.
    train_names = set(file_names_train)
    test_names = set(file_names_test)

    gts_train = {}
    gts_test = {}
    for fname, annotation in gts.items():
        if fname in train_names:
            gts_train[fname] = annotation
        elif fname in test_names:
            gts_test[fname] = annotation
        # Annotations for files that are in neither split are dropped.

    with open(os.path.join(gts_path, 'annotations_train.json'), 'w',
              encoding='utf-8') as f:
        json.dump(gts_train, f)
    with open(os.path.join(gts_path, 'annotations_test.json'), 'w',
              encoding='utf-8') as f:
        json.dump(gts_test, f)