Skip to content

Commit 8f7b588

Browse files
authored
Merge pull request #16529 from lidanqing-intel/lidanqing/preprocess-data
preprocess with PIL the full val dataset and save binary
2 parents 5b24002 + 0d65699 commit 8f7b588

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
2+
#
3+
# licensed under the apache license, version 2.0 (the "license");
4+
# you may not use this file except in compliance with the license.
5+
# you may obtain a copy of the license at
6+
#
7+
# http://www.apache.org/licenses/license-2.0
8+
#
9+
# unless required by applicable law or agreed to in writing, software
10+
# distributed under the license is distributed on an "as is" basis,
11+
# without warranties or conditions of any kind, either express or implied.
12+
# see the license for the specific language governing permissions and
13+
# limitations under the license.
14+
import unittest
15+
import os
16+
import numpy as np
17+
import time
18+
import sys
19+
import random
20+
import functools
21+
import contextlib
22+
from PIL import Image, ImageEnhance
23+
import math
24+
from paddle.dataset.common import download
25+
26+
random.seed(0)
27+
np.random.seed(0)
28+
29+
DATA_DIM = 224
30+
31+
SIZE_FLOAT32 = 4
32+
SIZE_INT64 = 8
33+
34+
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
35+
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
36+
37+
38+
def resize_short(img, target_size):
39+
percent = float(target_size) / min(img.size[0], img.size[1])
40+
resized_width = int(round(img.size[0] * percent))
41+
resized_height = int(round(img.size[1] * percent))
42+
img = img.resize((resized_width, resized_height), Image.LANCZOS)
43+
return img
44+
45+
46+
def crop_image(img, target_size, center):
47+
width, height = img.size
48+
size = target_size
49+
if center == True:
50+
w_start = (width - size) / 2
51+
h_start = (height - size) / 2
52+
else:
53+
w_start = np.random.randint(0, width - size + 1)
54+
h_start = np.random.randint(0, height - size + 1)
55+
w_end = w_start + size
56+
h_end = h_start + size
57+
img = img.crop((w_start, h_start, w_end, h_end))
58+
return img
59+
60+
61+
def process_image(img_path, mode, color_jitter, rotate):
62+
img = Image.open(img_path)
63+
img = resize_short(img, target_size=256)
64+
img = crop_image(img, target_size=DATA_DIM, center=True)
65+
if img.mode != 'RGB':
66+
img = img.convert('RGB')
67+
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
68+
img -= img_mean
69+
img /= img_std
70+
return img
71+
72+
73+
def download_unzip():
74+
int8_download = 'int8/download'
75+
76+
target_name = 'data'
77+
78+
cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
79+
int8_download)
80+
81+
target_folder = os.path.join(cache_folder, target_name)
82+
83+
data_urls = []
84+
data_md5s = []
85+
86+
data_urls.append(
87+
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa'
88+
)
89+
data_md5s.append('60f6525b0e1d127f345641d75d41f0a8')
90+
data_urls.append(
91+
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab'
92+
)
93+
data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5')
94+
95+
file_names = []
96+
97+
for i in range(0, len(data_urls)):
98+
download(data_urls[i], cache_folder, data_md5s[i])
99+
file_names.append(data_urls[i].split('/')[-1])
100+
101+
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
102+
103+
if not os.path.exists(zip_path):
104+
cat_command = 'cat'
105+
for file_name in file_names:
106+
cat_command += ' ' + os.path.join(cache_folder, file_name)
107+
cat_command += ' > ' + zip_path
108+
os.system(cat_command)
109+
print('Data is downloaded at {0}\n').format(zip_path)
110+
111+
if not os.path.exists(target_folder):
112+
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, zip_path)
113+
os.system(cmd)
114+
print('Data is unzipped at {0}\n'.format(target_folder))
115+
116+
data_dir = os.path.join(target_folder, 'ILSVRC2012')
117+
print('ILSVRC2012 full val set at {0}\n'.format(data_dir))
118+
return data_dir
119+
120+
121+
def reader():
122+
data_dir = download_unzip()
123+
file_list = os.path.join(data_dir, 'val_list.txt')
124+
output_file = os.path.join(data_dir, 'int8_full_val.bin')
125+
with open(file_list) as flist:
126+
lines = [line.strip() for line in flist]
127+
num_images = len(lines)
128+
if not os.path.exists(output_file):
129+
print(
130+
'Preprocessing to binary file...<num_images><all images><all labels>...\n'
131+
)
132+
with open(output_file, "w+b") as of:
133+
#save num_images(int64_t) to file
134+
of.seek(0)
135+
num = np.array(int(num_images)).astype('int64')
136+
of.write(num.tobytes())
137+
for idx, line in enumerate(lines):
138+
img_path, label = line.split()
139+
img_path = os.path.join(data_dir, img_path)
140+
if not os.path.exists(img_path):
141+
continue
142+
143+
#save image(float32) to file
144+
img = process_image(
145+
img_path, 'val', color_jitter=False, rotate=False)
146+
np_img = np.array(img)
147+
of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3
148+
* idx)
149+
of.write(np_img.astype('float32').tobytes())
150+
151+
#save label(int64_t) to file
152+
label_int = (int)(label)
153+
np_label = np.array(label_int)
154+
of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3
155+
* num_images + idx * SIZE_INT64)
156+
of.write(np_label.astype('int64').tobytes())
157+
158+
print('The preprocessed binary file path {}\n'.format(output_file))
159+
160+
161+
if __name__ == '__main__':
162+
reader()

0 commit comments

Comments
 (0)