Skip to content

Commit 0d65699

Browse files
fix some bugs of unzip and reading val list
test=develop
1 parent b46e467 commit 0d65699

File tree

1 file changed

+46
-37
lines changed

1 file changed

+46
-37
lines changed

paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def process_image(img_path, mode, color_jitter, rotate):
7171

7272

7373
def download_unzip():
74+
int8_download = 'int8/download'
7475

75-
tmp_folder = 'int8/download'
76+
target_name = 'data'
7677

77-
cache_folder = os.path.expanduser('~/.cache/' + tmp_folder)
78+
cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
79+
int8_download)
80+
81+
target_folder = os.path.join(cache_folder, target_name)
7882

7983
data_urls = []
8084
data_md5s = []
@@ -89,8 +93,9 @@ def download_unzip():
8993
data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5')
9094

9195
file_names = []
96+
9297
for i in range(0, len(data_urls)):
93-
download(data_urls[i], tmp_folder, data_md5s[i])
98+
download(data_urls[i], cache_folder, data_md5s[i])
9499
file_names.append(data_urls[i].split('/')[-1])
95100

96101
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
@@ -101,16 +106,15 @@ def download_unzip():
101106
cat_command += ' ' + os.path.join(cache_folder, file_name)
102107
cat_command += ' > ' + zip_path
103108
os.system(cat_command)
109+
print('Data is downloaded at {0}\n').format(zip_path)
104110

105-
if not os.path.exists(cache_folder):
106-
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(cache_folder, zip_path)
107-
108-
cmd = 'rm -rf {3} && ln -s {1} {0}'.format("data", cache_folder, zip_path)
109-
110-
os.system(cmd)
111-
112-
data_dir = os.path.expanduser(cache_folder + 'data')
111+
if not os.path.exists(target_folder):
112+
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, zip_path)
113+
os.system(cmd)
114+
print('Data is unzipped at {0}\n'.format(target_folder))
113115

116+
data_dir = os.path.join(target_folder, 'ILSVRC2012')
117+
print('ILSVRC2012 full val set at {0}\n'.format(data_dir))
114118
return data_dir
115119

116120

@@ -121,32 +125,37 @@ def reader():
121125
with open(file_list) as flist:
122126
lines = [line.strip() for line in flist]
123127
num_images = len(lines)
124-
125-
with open(output_file, "w+b") as of:
126-
#save num_images(int64_t) to file
127-
of.seek(0)
128-
num = np.array(int(num_images)).astype('int64')
129-
of.write(num.tobytes())
130-
for idx, line in enumerate(lines):
131-
img_path, label = line.split()
132-
img_path = os.path.join(data_dir, img_path)
133-
if not os.path.exists(img_path):
134-
continue
135-
136-
#save image(float32) to file
137-
img = process_image(
138-
img_path, 'val', color_jitter=False, rotate=False)
139-
np_img = np.array(img)
140-
of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
141-
idx)
142-
of.write(np_img.astype('float32').tobytes())
143-
144-
#save label(int64_t) to file
145-
label_int = (int)(label)
146-
np_label = np.array(label_int)
147-
of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
148-
num_images + idx * SIZE_INT64)
149-
of.write(np_label.astype('int64').tobytes())
128+
if not os.path.exists(output_file):
129+
print(
130+
'Preprocessing to binary file...<num_images><all images><all labels>...\n'
131+
)
132+
with open(output_file, "w+b") as of:
133+
#save num_images(int64_t) to file
134+
of.seek(0)
135+
num = np.array(int(num_images)).astype('int64')
136+
of.write(num.tobytes())
137+
for idx, line in enumerate(lines):
138+
img_path, label = line.split()
139+
img_path = os.path.join(data_dir, img_path)
140+
if not os.path.exists(img_path):
141+
continue
142+
143+
#save image(float32) to file
144+
img = process_image(
145+
img_path, 'val', color_jitter=False, rotate=False)
146+
np_img = np.array(img)
147+
of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3
148+
* idx)
149+
of.write(np_img.astype('float32').tobytes())
150+
151+
#save label(int64_t) to file
152+
label_int = (int)(label)
153+
np_label = np.array(label_int)
154+
of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3
155+
* num_images + idx * SIZE_INT64)
156+
of.write(np_label.astype('int64').tobytes())
157+
158+
print('The preprocessed binary file path {}\n'.format(output_file))
150159

151160

152161
if __name__ == '__main__':

0 commit comments

Comments
 (0)