Skip to content

Commit bbb654e

Browse files
lidanqing-intelsfraczek
authored and committed
fix preprocess script with processbar, integrity check and logs (#16608)
* fix preprocess script with processbar, integrity check and logs * delete unnecessary empty lines, change function name test=release/1.4
1 parent 627ca4a commit bbb654e

File tree

1 file changed

+132
-72
lines changed

1 file changed

+132
-72
lines changed
Lines changed: 132 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
2-
#
32
# licensed under the apache license, version 2.0 (the "license");
43
# you may not use this file except in compliance with the license.
54
# you may obtain a copy of the license at
@@ -11,6 +10,7 @@
1110
# without warranties or conditions of any kind, either express or implied.
1211
# see the license for the specific language governing permissions and
1312
# limitations under the license.
13+
import hashlib
1414
import unittest
1515
import os
1616
import numpy as np
@@ -21,16 +21,20 @@
2121
import contextlib
2222
from PIL import Image, ImageEnhance
2323
import math
24-
from paddle.dataset.common import download
24+
from paddle.dataset.common import download, md5file
25+
import tarfile
2526

2627
random.seed(0)
2728
np.random.seed(0)
2829

2930
DATA_DIM = 224
30-
3131
SIZE_FLOAT32 = 4
3232
SIZE_INT64 = 8
33-
33+
FULL_SIZE_BYTES = 30106000008
34+
FULL_IMAGES = 50000
35+
DATA_DIR_NAME = 'ILSVRC2012'
36+
IMG_DIR_NAME = 'var'
37+
TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2'
3438
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
3539
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
3640

@@ -70,19 +74,9 @@ def process_image(img_path, mode, color_jitter, rotate):
7074
return img
7175

7276

73-
def download_unzip():
74-
int8_download = 'int8/download'
75-
76-
target_name = 'data'
77-
78-
cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
79-
int8_download)
80-
81-
target_folder = os.path.join(cache_folder, target_name)
82-
77+
def download_concat(cache_folder, zip_path):
8378
data_urls = []
8479
data_md5s = []
85-
8680
data_urls.append(
8781
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa'
8882
)
@@ -91,72 +85,138 @@ def download_unzip():
9185
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab'
9286
)
9387
data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5')
94-
9588
file_names = []
96-
89+
print("Downloading full ImageNet Validation dataset ...")
9790
for i in range(0, len(data_urls)):
9891
download(data_urls[i], cache_folder, data_md5s[i])
99-
file_names.append(data_urls[i].split('/')[-1])
100-
101-
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
102-
92+
file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1])
93+
file_names.append(file_name)
94+
print("Downloaded part {0}\n".format(file_name))
10395
if not os.path.exists(zip_path):
104-
cat_command = 'cat'
105-
for file_name in file_names:
106-
cat_command += ' ' + os.path.join(cache_folder, file_name)
107-
cat_command += ' > ' + zip_path
108-
os.system(cat_command)
109-
print('Data is downloaded at {0}\n').format(zip_path)
110-
111-
if not os.path.exists(target_folder):
112-
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, zip_path)
113-
os.system(cmd)
114-
print('Data is unzipped at {0}\n'.format(target_folder))
115-
116-
data_dir = os.path.join(target_folder, 'ILSVRC2012')
117-
print('ILSVRC2012 full val set at {0}\n'.format(data_dir))
118-
return data_dir
96+
with open(zip_path, "w+") as outfile:
97+
for fname in file_names:
98+
with open(fname) as infile:
99+
outfile.write(infile.read())
100+
101+
102+
def extract(zip_path, extract_folder):
    """Extract the concatenated ImageNet validation tarball.

    Extraction is skipped when the image directory already exists and
    holds the expected number of images (FULL_IMAGES).

    Args:
        zip_path: path to the .tar.gz archive to extract.
        extract_folder: directory the archive is extracted into.
    """
    data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
    img_dir = os.path.join(data_dir, IMG_DIR_NAME)
    print("Extracting...\n")

    if not (os.path.exists(img_dir) and
            len(os.listdir(img_dir)) == FULL_IMAGES):
        # Context manager closes the tar handle even if extractall raises
        # (the original leaked it on error).
        with tarfile.open(zip_path) as tar:
            tar.extractall(path=extract_folder)
    print('Extracted. Full Imagenet Validation dataset is located at {0}\n'.
          format(data_dir))
116+
def print_processbar(done, total):
    """Draw an in-place ASCII progress bar on stdout.

    Args:
        done: number of finished parts; callers may pass a float
            (result of true division on Python 3), so it is coerced.
        total: total number of parts; coerced to int for the same reason.
    """
    # int() guards against float arguments: "10.0 * '='" raises TypeError.
    done = int(done)
    total = int(total)
    done_filled = done * '='
    empty_filled = (total - done) * ' '
    percentage_done = done * 100 / total
    sys.stdout.write("\r[%s%s]%d%%" %
                     (done_filled, empty_filled, percentage_done))
    sys.stdout.flush()
123+
124+
125+
def check_integrity(filename, target_hash):
    """Return True iff the MD5 digest of `filename` equals `target_hash`.

    The file is streamed in chunks while a progress bar sized for the
    expected full binary (FULL_SIZE_BYTES) is drawn.

    Args:
        filename: path of the binary file to verify.
        target_hash: expected hex MD5 digest string.
    """
    print('\nThe binary file exists. Checking file integrity...\n')
    md = hashlib.md5()
    count = 0
    total_parts = 50
    chunk_size = 8192
    # Chunks per progress tick; floor division keeps this an int on
    # Python 3 so the modulo below stays exact.
    onepart = FULL_SIZE_BYTES // chunk_size // total_parts
    # 'rb': hashing must see raw bytes; text mode would corrupt the
    # digest (and raise on Python 3, where md5 rejects str).
    with open(filename, 'rb') as ifs:
        while True:
            buf = ifs.read(chunk_size)
            if count % onepart == 0:
                print_processbar(count // onepart, total_parts)
            count += 1
            if not buf:
                break
            md.update(buf)
    return md.hexdigest() == target_hash
119147

120148

121-
def reader():
122-
data_dir = download_unzip()
123-
file_list = os.path.join(data_dir, 'val_list.txt')
124-
output_file = os.path.join(data_dir, 'int8_full_val.bin')
149+
def convert(file_list, data_dir, output_file):
    """Convert the listed validation images into one binary file.

    Output layout: one int64 image count, then num_images float32 images
    of 3 * DATA_DIM * DATA_DIM values each, then num_images int64 labels.

    Args:
        file_list: path to val_list.txt; each line is
            "<relative image path> <integer label>".
        data_dir: directory the relative image paths are resolved against.
        output_file: destination binary file (overwritten).
    """
    print('Converting 50000 images to binary file ...\n')
    with open(file_list) as flist:
        lines = [line.strip() for line in flist]
        num_images = len(lines)
        with open(output_file, "w+b") as ofs:
            # save num_images(int64_t) to file
            ofs.seek(0)
            num = np.array(int(num_images)).astype('int64')
            ofs.write(num.tobytes())
            per_parts = 1000
            # Floor division keeps the progress-bar total an int on Python 3.
            full_parts = FULL_IMAGES // per_parts
            print_processbar(0, full_parts)
            # Loop-invariant byte size of one serialized image, hoisted.
            img_bytes = SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3
            for idx, line in enumerate(lines):
                img_path, label = line.split()
                img_path = os.path.join(data_dir, img_path)
                # Skip missing images; their slots are simply left unwritten
                # (preserves original behavior).
                if not os.path.exists(img_path):
                    continue

                # save image(float32) to file
                img = process_image(
                    img_path, 'val', color_jitter=False, rotate=False)
                np_img = np.array(img)
                ofs.seek(SIZE_INT64 + img_bytes * idx)
                ofs.write(np_img.astype('float32').tobytes())
                ofs.flush()

                # save label(int64_t) to file
                np_label = np.array(int(label))
                ofs.seek(SIZE_INT64 + img_bytes * num_images +
                         idx * SIZE_INT64)
                ofs.write(np_label.astype('int64').tobytes())
                ofs.flush()
                if (idx + 1) % per_parts == 0:
                    print_processbar((idx + 1) // per_parts, full_parts)
    print("Conversion finished.")
188+
189+
190+
def run_convert():
    """Download, extract and convert the full ImageNet validation set.

    Loops until the output binary exists with the expected size
    (FULL_SIZE_BYTES) and MD5 (TARGET_HASH), re-running the
    download/extract/convert pipeline, up to try_limit attempts.

    Raises:
        RuntimeError: when the binary is still broken after try_limit tries.
    """
    print('Start to download and convert 50000 images to binary file...')
    cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download')
    extract_folder = os.path.join(cache_folder, 'full_data')
    data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
    file_list = os.path.join(data_dir, 'val_list.txt')
    zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
    output_file = os.path.join(cache_folder, 'int8_full_val.bin')
    retry = 0
    try_limit = 3

    while not (os.path.exists(output_file) and
               os.path.getsize(output_file) == FULL_SIZE_BYTES and
               check_integrity(output_file, TARGET_HASH)):
        if os.path.exists(output_file):
            # Bug fix: the original message had no {0} placeholder, so
            # .format(output_file) was a no-op; report the path.
            sys.stderr.write(
                "\n\nThe existing binary file {0} is broken. Start to generate new one...\n\n".
                format(output_file))
            os.remove(output_file)
        if retry < try_limit:
            retry += 1
        else:
            raise RuntimeError(
                "Can not convert the dataset to binary file with try limit {0}".
                format(try_limit))
        download_concat(cache_folder, zip_path)
        extract(zip_path, extract_folder)
        convert(file_list, data_dir, output_file)
    print("\nSuccess! The binary file can be found at {0}".format(output_file))
159219

160220

161221
# Script entry point: build the preprocessed ImageNet validation binary.
if __name__ == '__main__':
    run_convert()

0 commit comments

Comments
 (0)