forked from wyharveychen/CloserLookFewShot
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmake_hdf5.py
More file actions
67 lines (57 loc) · 2.43 KB
/
make_hdf5.py
File metadata and controls
67 lines (57 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
from data.dataset import SimpleDataset
from data.datamgr import SimpleDataManager, ResizeDataManager
import os
from tqdm import tqdm
import numpy as np
import h5py
# CLI: select dataset/split, image size, memory layout, and batching for the export.
parser = argparse.ArgumentParser(description= 'yeeeee, im description of make_hdf5 parser!')
parser.add_argument('--dataset', choices=['CUB','miniImagenet','omniglot','emnist'], required=True)
parser.add_argument('--mode', choices=['all','train','val','test','noLatin'], required=True)
parser.add_argument('--aug', action='store_true', help='use the augmented data or not')
parser.add_argument('--channel_order', default='NCHW', help='NCHW for PyTorch / NHWC for TF, but LrLiVAE have dealed with NCHW so don\'t need to do NHWC')
parser.add_argument('--img_size', help='image size', required=True, type=int)
parser.add_argument('--batch_size', help='The batch size when processing the data', default=50, type=int)
parser.add_argument('--debug', action='store_true', help='debug mode on')
args = parser.parse_args()

# Resolve the JSON file list describing the requested split, e.g. filelists/CUB/base.json.
data_path = os.path.join('filelists', args.dataset)
file = {
    'all':'all.json', 'train':'base.json', 'val':'val.json', 'test':'novel.json',
    'noLatin':'noLatin.json'}
file_path = os.path.join(data_path, file[args.mode])

# FIX: batch size was hard-coded to 50 here, silently ignoring --batch_size.
datamgr = ResizeDataManager(args.img_size, batch_size=args.batch_size)
data_loader = datamgr.get_data_loader(data_file=file_path, aug=args.aug)

# Accumulate per-batch arrays; concatenated once at the end to avoid repeated copies.
imgs_list = []
labels_list = []
t_loader = tqdm(data_loader)
for i, data in enumerate(t_loader):
    batch_x, batch_y = data[0].numpy(), data[1].numpy()
    if args.channel_order == 'NHWC':
        # FIX: np.transpose takes 'axes', not 'axis' — the old keyword raised
        # TypeError, so the NHWC path could never run.
        batch_x = np.transpose(batch_x, axes=(0, 2, 3, 1))
    imgs_list.append(batch_x)
    labels_list.append(batch_y)
    if i == 2 and args.debug:
        # Debug mode: stop after three batches to keep a smoke-test run short.
        break
imgs = np.concatenate(imgs_list, axis=0)
labels = np.concatenate(labels_list, axis=0)
print('Final: imgs shape:', imgs.shape, ', labels shape:', labels.shape)

# Output name encodes mode, layout, size, and augmentation, e.g. train-NCHW-84-aug.h5.
filename = args.mode + '-' + args.channel_order + '-' + str(args.img_size)
if args.aug:
    filename += '-aug'
filename += '.h5'
out_path = os.path.join(data_path, 'hdf5')
if not os.path.exists(out_path):
    print('make directory:', out_path)
    os.makedirs(out_path)
out_file = os.path.join(out_path, filename)
print('making file:', out_file)
# 'w-' mode fails if the file already exists, so a prior export is never clobbered.
with h5py.File(out_file, 'w-') as f:
    f['images'] = imgs
    f['labels'] = labels
print('Finish writing file:', out_file)