Skip to content

Commit 05fe0a4

Browse files
authored
Add 'load_label' parameter for image classification models (#185)
* add 'load_label' parameter for image classification models * move load_label flag to initializer
1 parent d84d6d6 commit 05fe0a4

File tree

5 files changed

+26
-12
lines changed

5 files changed

+26
-12
lines changed

models/image_classification_mobilenet/demo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@
3131
{:d}: TIM-VX + NPU,
3232
{:d}: CANN + NPU
3333
'''.format(*[x for x in range(len(backend_target_pairs))]))
34+
parser.add_argument('--top_k', type=int, default=1,
35+
help='Usage: Get top k predictions.')
3436
args = parser.parse_args()
3537

3638
if __name__ == '__main__':
3739
backend_id = backend_target_pairs[args.backend_target][0]
3840
target_id = backend_target_pairs[args.backend_target][1]
41+
top_k = args.top_k
3942
# Instantiate MobileNet
40-
model = MobileNet(modelPath=args.model, backendId=backend_id, targetId=target_id)
43+
model = MobileNet(modelPath=args.model, topK=top_k, backendId=backend_id, targetId=target_id)
4144

4245
# Read image and get a 224x224 crop from a 256x256 resized
4346
image = cv.imread(args.input)

models/image_classification_mobilenet/mobilenet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ class MobileNet:
66
Works with MobileNet V1 & V2.
77
'''
88

9-
def __init__(self, modelPath, topK=1, backendId=0, targetId=0):
9+
def __init__(self, modelPath, topK=1, loadLabel=True, backendId=0, targetId=0):
1010
self.model_path = modelPath
1111
assert topK >= 1
1212
self.top_k = topK
13+
self.load_label = loadLabel
1314
self.backend_id = backendId
1415
self.target_id = targetId
1516

@@ -64,7 +65,7 @@ def _postprocess(self, output_blob):
6465
for o in output_blob:
6566
class_id_list = o.argsort()[::-1][:self.top_k]
6667
batched_class_id_list.append(class_id_list)
67-
if len(self._labels) > 0:
68+
if len(self._labels) > 0 and self.load_label:
6869
batched_predicted_labels = []
6970
for class_id_list in batched_class_id_list:
7071
predicted_labels = []

models/image_classification_ppresnet/demo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,16 @@
3737
{:d}: TIM-VX + NPU,
3838
{:d}: CANN + NPU
3939
'''.format(*[x for x in range(len(backend_target_pairs))]))
40+
parser.add_argument('--top_k', type=int, default=1,
41+
help='Usage: Get top k predictions.')
4042
args = parser.parse_args()
4143

4244
if __name__ == '__main__':
4345
backend_id = backend_target_pairs[args.backend_target][0]
4446
target_id = backend_target_pairs[args.backend_target][1]
47+
top_k = args.top_k
4548
# Instantiate ResNet
46-
model = PPResNet(modelPath=args.model, backendId=backend_id, targetId=target_id)
49+
model = PPResNet(modelPath=args.model, topK=top_k, backendId=backend_id, targetId=target_id)
4750

4851
# Read image and get a 224x224 crop from a 256x256 resized
4952
image = cv.imread(args.input)

models/image_classification_ppresnet/ppresnet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
import cv2 as cv
1010

1111
class PPResNet:
12-
def __init__(self, modelPath, topK=1, backendId=0, targetId=0):
12+
def __init__(self, modelPath, topK=1, loadLabel=True, backendId=0, targetId=0):
1313
self._modelPath = modelPath
1414
assert topK >= 1
1515
self._topK = topK
16+
self._load_label = loadLabel
1617
self._backendId = backendId
1718
self._targetId = targetId
1819

@@ -69,7 +70,7 @@ def _postprocess(self, outputBlob):
6970
for ob in outputBlob:
7071
class_id_list = ob.argsort()[::-1][:self._topK]
7172
batched_class_id_list.append(class_id_list)
72-
if len(self._labels) > 0:
73+
if len(self._labels) > 0 and self._load_label:
7374
batched_predicted_labels = []
7475
for class_id_list in batched_class_id_list:
7576
predicted_labels = []

tools/eval/eval.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,38 @@
2525
name="MobileNet",
2626
topic="image_classification",
2727
modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv1_2022apr.onnx"),
28-
topK=5),
28+
topK=5,
29+
loadLabel=False),
2930
mobilenetv1_q=dict(
3031
name="MobileNet",
3132
topic="image_classification",
3233
modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv1_2022apr_int8.onnx"),
33-
topK=5),
34+
topK=5,
35+
loadLabel=False),
3436
mobilenetv2=dict(
3537
name="MobileNet",
3638
topic="image_classification",
3739
modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv2_2022apr.onnx"),
38-
topK=5),
40+
topK=5,
41+
loadLabel=False),
3942
mobilenetv2_q=dict(
4043
name="MobileNet",
4144
topic="image_classification",
4245
modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv2_2022apr_int8.onnx"),
43-
topK=5),
46+
topK=5,
47+
loadLabel=False),
4448
ppresnet=dict(
4549
name="PPResNet",
4650
topic="image_classification",
4751
modelPath=os.path.join(root_dir, "models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx"),
48-
topK=5),
52+
topK=5,
53+
loadLabel=False),
4954
ppresnet_q=dict(
5055
name="PPResNet",
5156
topic="image_classification",
5257
modelPath=os.path.join(root_dir, "models/image_classification_ppresnet/image_classification_ppresnet50_2022jan_int8.onnx"),
53-
topK=5),
58+
topK=5,
59+
loadLabel=False),
5460
yunet=dict(
5561
name="YuNet",
5662
topic="face_detection",

0 commit comments

Comments
 (0)