Skip to content

Commit 9e7f179

Browse files
authored
Add ResNet preprocessing model (#594)
* Add ResNet preprocessing model Signed-off-by: Joaquin Anton <[email protected]> * Support sequence in tests Signed-off-by: Joaquin Anton <[email protected]> --------- Signed-off-by: Joaquin Anton <[email protected]>
1 parent 8e893eb commit 9e7f179

File tree

8 files changed

+198
-26
lines changed

8 files changed

+198
-26
lines changed

ONNX_HUB_MANIFEST.json

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4683,6 +4683,46 @@
46834683
"model_with_data_bytes": 95237476
46844684
}
46854685
},
4686+
{
4687+
"model": "ResNet-preproc",
4688+
"model_path": "vision/classification/resnet/preproc/resnet-preproc-v1-18.onnx",
4689+
"onnx_version": "1.13.1",
4690+
"opset_version": 18,
4691+
"metadata": {
4692+
"model_sha": "9cda24af90b4cd2ced4167fa36a41956ea0ce5e55c6ae475614a097cb89762c7",
4693+
"model_bytes": 1129,
4694+
"tags": [
4695+
"vision",
4696+
"classification",
4697+
"resnet",
4698+
"preprocessing"
4699+
],
4700+
"io_ports": {
4701+
"inputs": [
4702+
{
4703+
"name": "images",
4704+
"shape": [],
4705+
"type": "seq(tensor(uint8))"
4706+
}
4707+
],
4708+
"outputs": [
4709+
{
4710+
"name": "preproc_data",
4711+
"shape": [
4712+
"B",
4713+
3,
4714+
224,
4715+
224
4716+
],
4717+
"type": "tensor(float)"
4718+
}
4719+
]
4720+
},
4721+
"model_with_data_path": "vision/classification/resnet/preproc/resnet-preproc-v1-18.tar.gz",
4722+
"model_with_data_sha": "216b89c1676c8a5a2dfc0ee1736b179b0777f9ba845ee6dd955d4ff684f29a3c",
4723+
"model_with_data_bytes": 883999
4724+
}
4725+
},
46864726
{
46874727
"model": "ShuffleNet-v1",
46884728
"model_path": "vision/classification/shufflenet/model/shufflenet-3.onnx",

vision/classification/imagenet_preprocess.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,50 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import numpy as np
4+
from PIL import Image
35
import mxnet
46
from mxnet.gluon.data.vision import transforms
57

8+
def preprocess(image):
9+
# resize so that the shorter side is 256, maintaining aspect ratio
10+
def image_resize(image, min_len):
11+
image = Image.fromarray(image)
12+
ratio = float(min_len) / min(image.size[0], image.size[1])
13+
if image.size[0] > image.size[1]:
14+
new_size = (int(round(ratio * image.size[0])), min_len)
15+
else:
16+
new_size = (min_len, int(round(ratio * image.size[1])))
17+
image = image.resize(new_size, Image.BILINEAR)
18+
return np.array(image)
19+
image = image_resize(image, 256)
20+
21+
# Crop centered window 224x224
22+
def crop_center(image, crop_w, crop_h):
23+
h, w, c = image.shape
24+
start_x = w//2 - crop_w//2
25+
start_y = h//2 - crop_h//2
26+
return image[start_y:start_y+crop_h, start_x:start_x+crop_w, :]
27+
image = crop_center(image, 224, 224)
28+
29+
# transpose
30+
image = image.transpose(2, 0, 1)
31+
32+
# convert the input data into the float32 input
33+
img_data = image.astype('float32')
34+
35+
# normalize
36+
mean_vec = np.array([0.485, 0.456, 0.406])
37+
stddev_vec = np.array([0.229, 0.224, 0.225])
38+
norm_img_data = np.zeros(img_data.shape).astype('float32')
39+
for i in range(img_data.shape[0]):
40+
norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
41+
42+
# add batch channel
43+
norm_img_data = norm_img_data.reshape(1, 3, 224, 224).astype('float32')
44+
return norm_img_data
45+
646
# Pre-processing function for ImageNet models
7-
def preprocess(img):
47+
def preprocess_mxnet(img):
848
'''
949
Preprocessing required on the images for inference with mxnet gluon
1050
The function takes path to an image and returns processed tensor

vision/classification/resnet/README.md

Lines changed: 81 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ ResNet v2 uses pre-activation function whereas ResNet v1 uses post-activation f
3737
|ResNet50-qdq | [24.6 MB](model/resnet50-v1-12-qdq.onnx) | [16.8 MB](model/resnet50-v1-12-qdq.tar.gz) | 1.10.0 | 12 |74.43 | |
3838
> Compared with the fp32 ResNet50, int8 ResNet50's Top-1 accuracy drop ratio is 0.27%, Top-5 accuracy drop ratio is 0.01% and performance improvement is 1.82x.
3939
>
40-
> Note the performance depends on the test hardware.
41-
>
40+
> Note the performance depends on the test hardware.
41+
>
4242
> Performance data here is collected with Intel® Xeon® Platinum 8280 Processor, 1s 4c per instance, CentOS Linux 8.3, data batch size is 1.
4343
4444
|Model |Download |Download (with sample test data)| ONNX version |Opset version|
@@ -68,24 +68,88 @@ All pre-trained models expect input images normalized in the same way, i.e. mini
6868
The inference was done using jpeg image.
6969

7070
### Preprocessing
71-
The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. The transformation should preferably happen at preprocessing.
7271

73-
The following code shows how to preprocess a NCHW tensor:
72+
The image needs to be preprocessed before fed to the network.
73+
The first step is to extract a 224x224 crop from the center of the image. For this, the image is first scaled to a minimum size of 256x256, while keeping aspect ratio. That is, the shortest side of the image is resized to 256 and the other side is scaled accordingly to maintain the original aspect ratio. After that, the image is normalized with mean = 255*[0.485, 0.456, 0.406] and std = 255*[0.229, 0.224, 0.225]. Last step is to transpose it from HWC to CHW layout.
74+
75+
The described preprocessing steps can be represented with an ONNX model:
76+
```python
77+
import onnx
78+
from onnx import parser
79+
from onnx import checker
80+
81+
resnet_preproc = parser.parse_model('''
82+
<
83+
ir_version: 8,
84+
opset_import: [ "" : 18, "local" : 1 ],
85+
metadata_props: [ "preprocessing_fn" : "local.preprocess"]
86+
>
87+
resnet_preproc_g (seq(uint8[?, ?, 3]) images) => (float[B, 3, 224, 224] preproc_data)
88+
{
89+
preproc_data = local.preprocess(images)
90+
}
91+
92+
<
93+
opset_import: [ "" : 18 ],
94+
domain: "local",
95+
doc_string: "Preprocessing function."
96+
>
97+
preprocess (input_batch) => (output_tensor) {
98+
tmp_seq = SequenceMap <
99+
body = sample_preprocessing(uint8[?, ?, 3] sample_in) => (float[3, 224, 224] sample_out) {
100+
target_size = Constant <value = int64[2] {256, 256}> ()
101+
image_resized = Resize <mode = \"linear\",
102+
antialias = 1,
103+
axes = [0, 1],
104+
keep_aspect_ratio_policy = \"not_smaller\"> (sample_in, , , target_size)
105+
106+
target_crop = Constant <value = int64[2] {224, 224}> ()
107+
image_sliced = CenterCropPad <axes = [0, 1]> (image_resized, target_crop)
108+
109+
kMean = Constant <value = float[3] {123.675, 116.28, 103.53}> ()
110+
kStddev = Constant <value = float[3] {58.395, 57.12, 57.375}> ()
111+
im_norm_tmp1 = Cast <to = 1> (image_sliced)
112+
im_norm_tmp2 = Sub (im_norm_tmp1, kMean)
113+
im_norm = Div (im_norm_tmp2, kStddev)
114+
115+
sample_out = Transpose <perm = [2, 0, 1]> (im_norm)
116+
}
117+
> (input_batch)
118+
output_tensor = ConcatFromSequence < axis = 0, new_axis = 1 >(tmp_seq)
119+
}
120+
121+
''')
122+
checker.check_model(resnet_preproc)
123+
```
124+
125+
* ResNet preprocessing:
126+
127+
|Model |Download |Download (with sample test data)| ONNX version |Opset version|
128+
|-------------|:--------------|:--------------|:--------------|:--------------|
129+
|ResNet-preproc| [4.0KB](preproc/resnet-preproc-v1-18.onnx) | [864 KB](preproc/resnet-preproc-v1-18.tar.gz) | 1.13.1 | 18|
74130

131+
132+
To prepend the data preprocessing steps to the model, we can use the ONNX compose utils:
75133
```python
76-
import numpy
77-
78-
def preprocess(img_data):
79-
mean_vec = np.array([0.485, 0.456, 0.406])
80-
stddev_vec = np.array([0.229, 0.224, 0.225])
81-
norm_img_data = np.zeros(img_data.shape).astype('float32')
82-
for i in range(img_data.shape[0]):
83-
# for each pixel in each channel, divide the value by 255 to get value between [0, 1] and then normalize
84-
norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
85-
return norm_img_data
134+
135+
import onnx
136+
from onnx import version_converter
137+
from onnx import checker
138+
139+
network_model = onnx.version_converter.convert_version(network_model, 18)
140+
network_model.ir_version = 8
141+
checker.check_model(network_model)
142+
143+
model_w_preproc = onnx.compose.merge_models(
144+
preprocessing_model, network_model,
145+
io_map=[('preproc_data', 'data')]
146+
)
147+
checker.check_model(model_w_preproc)
148+
86149
```
87150

88-
Check [imagenet_preprocess.py](../imagenet_preprocess.py) for additional sample code.
151+
152+
Check [imagenet_preprocess.py](../imagenet_preprocess.py) for some reference Python and MxNet implementations.
89153

90154
### Output
91155
The model outputs image scores for each of the [1000 classes of ImageNet](../synset.txt).
@@ -113,7 +177,7 @@ We used MXNet as framework with gluon APIs to perform validation. Use the notebo
113177
ResNet50-int8 and ResNet50-qdq are obtained by quantizing ResNet50-fp32 model. We use [Intel® Neural Compressor](https://github.com/intel/neural-compressor) with onnxruntime backend to perform quantization. View the [instructions](https://github.com/intel/neural-compressor/blob/master/examples/onnxrt/image_recognition/onnx_model_zoo/resnet50/quantization/ptq/README.md) to understand how to use Intel® Neural Compressor for quantization.
114178

115179
### Environment
116-
onnx: 1.7.0
180+
onnx: 1.7.0
117181
onnxruntime: 1.6.0+
118182

119183
### Prepare model
@@ -153,6 +217,7 @@ In European Conference on Computer Vision, pp. 630-645. Springer, Cham, 2016.
153217
* [airMeng](https://github.com/airMeng) (Intel)
154218
* [ftian1](https://github.com/ftian1) (Intel)
155219
* [hshen14](https://github.com/hshen14) (Intel)
220+
* [jantonguirao](https://github.com/jantonguirao) (NVIDIA)
156221

157222
## License
158223
Apache 2.0
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:9cda24af90b4cd2ced4167fa36a41956ea0ce5e55c6ae475614a097cb89762c7
3+
size 1129
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:216b89c1676c8a5a2dfc0ee1736b179b0777f9ba845ee6dd955d4ff684f29a3c
3+
size 883999

workflow_scripts/generate_onnx_hub_manifest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,11 @@ def get_file_info(row, field, target_models=None):
124124
def get_model_tags(row):
125125
source_dir = split(row["source_file"])[0]
126126
raw_tags = source_dir.split("/")
127-
return [tag.replace("_", " ") for tag in raw_tags]
127+
tags = [tag.replace("_", " ") for tag in raw_tags]
128+
model_file = row['model_path'].contents[0].attrs["href"]
129+
if 'preproc' in model_file.split("/"):
130+
tags.append('preprocessing')
131+
return tags
128132

129133

130134
def get_model_ports(source_file, metadata, model_name):

workflow_scripts/onnx_test_data_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,23 @@
1111
import numpy as np
1212
import onnx
1313
from onnx import numpy_helper
14+
from onnx.onnx_data_pb2 import SequenceProto
1415

1516

1617
def read_tensorproto_pb_file(filename):
1718
"""Return tuple of tensor name and numpy.ndarray of the data from a pb file containing a TensorProto."""
18-
1919
tensor = onnx.load_tensor(filename)
2020
np_array = numpy_helper.to_array(tensor)
2121
return tensor.name, np_array
2222

23+
def read_sequenceproto_pb_file(filename):
24+
"""Return tuple of sequence name and list of numpy.ndarray of the data from a pb file containing a SequenceProto."""
25+
seq = SequenceProto()
26+
with open(filename, 'rb') as f:
27+
seq.ParseFromString(f.read())
28+
list_of_arrays = numpy_helper.to_list(seq)
29+
return seq.name, list_of_arrays
30+
2331

2432
def dump_tensorproto_pb_file(filename):
2533
"""Dump the data from a pb file containing a TensorProto."""

workflow_scripts/ort_test_dir_utils.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def save_data(prefix, name_data_map, model_info):
157157
save_data("output", name_output_map, model_outputs)
158158

159159

160-
def read_test_dir(dir_name):
160+
def read_test_dir(dir_name, input_types, output_types):
161161
"""
162162
Read the input and output .pb files from the provided directory.
163163
Input files should have a prefix of 'input_'
@@ -169,15 +169,22 @@ def read_test_dir(dir_name):
169169

170170
inputs = {}
171171
outputs = {}
172+
172173
input_files = glob.glob(os.path.join(dir_name, "input_*.pb"))
173174
output_files = glob.glob(os.path.join(dir_name, "output_*.pb"))
174175

175-
for i in input_files:
176-
name, data = onnx_test_data_utils.read_tensorproto_pb_file(i)
176+
for i, filename in enumerate(input_files):
177+
if 'seq' in input_types[i]:
178+
name, data = onnx_test_data_utils.read_sequenceproto_pb_file(filename)
179+
else:
180+
name, data = onnx_test_data_utils.read_tensorproto_pb_file(filename)
177181
inputs[name] = data
178182

179-
for o in output_files:
180-
name, data = onnx_test_data_utils.read_tensorproto_pb_file(o)
183+
for i, filename in enumerate(output_files):
184+
if 'seq' in output_files[i]:
185+
name, data = onnx_test_data_utils.read_sequenceproto_pb_file(filename)
186+
else:
187+
name, data = onnx_test_data_utils.read_tensorproto_pb_file(filename)
181188
outputs[name] = data
182189

183190
return inputs, outputs
@@ -217,12 +224,14 @@ def run_test_dir(model_or_dir):
217224
test_dirs = [d for d in glob.glob(os.path.join(model_dir, "test*")) if os.path.isdir(d)]
218225
if not test_dirs:
219226
raise ValueError("No directories with name starting with 'test' were found in {}.".format(model_dir))
220-
221227
sess = ort.InferenceSession(model_path)
222228

229+
input_types = [inp.type for inp in sess.get_inputs()]
230+
output_types = [out.type for out in sess.get_outputs()]
231+
223232
for d in test_dirs:
224233
print(d)
225-
inputs, expected_outputs = read_test_dir(d)
234+
inputs, expected_outputs = read_test_dir(d, input_types, output_types)
226235

227236
if expected_outputs:
228237
output_names = list(expected_outputs.keys())

0 commit comments

Comments
 (0)