Skip to content

Commit f56a16c

Browse files
committed
update script
1 parent c8f7ce5 commit f56a16c

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

quantization/image_classification/trt/resnet50/e2e_tensorrt_resnet_example.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import onnxruntime
1010
from onnxruntime.quantization import CalibrationDataReader, create_calibrator, write_calibration_table
1111

12+
# onnxruntime.set_default_logger_severity(0)
1213

1314
class ImageNetDataReader(CalibrationDataReader):
1415
def __init__(self,
@@ -126,12 +127,13 @@ def preprocess_imagenet(self, images_folder, height, width, start_index=0, size_
126127
return: list of matrices characterizing multiple images
127128
'''
128129
def preprocess_images(input, channels=3, height=224, width=224):
129-
image = input.resize((width, height), Image.ANTIALIAS)
130+
image = input.resize((width, height), Image.Resampling.LANCZOS) # Image.ANTIALIAS was removed in Pillow 10.0.0
130131
input_data = np.asarray(image).astype(np.float32)
131132
if len(input_data.shape) != 2:
132133
input_data = input_data.transpose([2, 0, 1])
133134
else:
134135
input_data = np.stack([input_data] * 3)
136+
# image normalization
135137
mean = np.array([0.079, 0.05, 0]) + 0.406
136138
std = np.array([0.005, 0, 0.001]) + 0.224
137139
for channel in range(input_data.shape[0]):
@@ -153,7 +155,8 @@ def preprocess_images(input, channels=3, height=224, width=224):
153155

154156
for image_name in batch_filenames:
155157
image_filepath = images_folder + '/' + image_name
156-
img = Image.open(image_filepath)
158+
# Note: There is one image ILSVRC2012_val_00019877.JPEG which has 4 channels, so here we convert it to RGB with 3 channels for all images
159+
img = Image.open(image_filepath).convert("RGB")
157160
image_data = preprocess_images(img)
158161
image_data = np.expand_dims(image_data, 0)
159162
unconcatenated_batch_data.append(image_data)
@@ -163,7 +166,7 @@ def preprocess_images(input, channels=3, height=224, width=224):
163166
return batch_data, batch_filenames, image_size_list
164167

165168
def get_synset_id(self, image_folder, offset, dataset_size):
166-
ilsvrc2012_meta = scipy.io.loadmat(image_folder + "/devkit/data/meta.mat")
169+
ilsvrc2012_meta = scipy.io.loadmat(image_folder + "/ILSVRC2012_devkit_t12/data/meta.mat")
167170
id_to_synset = {}
168171
for i in range(1000):
169172
id = int(ilsvrc2012_meta["synsets"][i, 0][0][0][0])
@@ -178,7 +181,7 @@ def get_synset_id(self, image_folder, offset, dataset_size):
178181
index = index + 1
179182
file.close()
180183

181-
file = open(image_folder + "/devkit/data/ILSVRC2012_validation_ground_truth.txt", "r")
184+
file = open(image_folder + "/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt", "r")
182185
id = file.read().strip().split("\n")
183186
id = list(map(int, id))
184187
file.close()
@@ -318,11 +321,14 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
318321
calibration_table_generation_enable = True # Enable/Disable INT8 calibration
319322

320323
# TensorRT EP INT8 settings
321-
os.environ["ORT_TENSORRT_FP16_ENABLE"] = "1" # Enable FP16 precision
322-
os.environ["ORT_TENSORRT_INT8_ENABLE"] = "1" # Enable INT8 precision
323-
os.environ["ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
324-
os.environ["ORT_TENSORRT_ENGINE_CACHE_ENABLE"] = "1" # Enable engine caching
325-
execution_provider = ["TensorrtExecutionProvider"]
324+
execution_provider = [
325+
('TensorrtExecutionProvider', {
326+
'trt_int8_enable': True,
327+
'trt_fp16_enable': True,
328+
'trt_engine_cache_enable': True,
329+
'trt_int8_calibration_table_name': 'calibration.flatbuffers', # The implicit quantization is deprecated in TRT 10
330+
})
331+
]
326332

327333
# Convert static batch to dynamic batch
328334
[new_model_path, input_name] = convert_model_batch_to_dynamic(model_path)
@@ -343,7 +349,7 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
343349
model_path=augmented_model_path,
344350
input_name=input_name)
345351
calibrator.collect_data(data_reader)
346-
write_calibration_table(calibrator.compute_range())
352+
write_calibration_table(calibrator.compute_data())
347353

348354
# Run prediction in Tensorrt EP
349355
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,

0 commit comments

Comments
 (0)