Commit e85761c
Update imagenet evaluation to batch design
1 parent 6f2f1ce

2 files changed (+176, -73 lines)

Lines changed: 172 additions & 70 deletions
@@ -1,37 +1,95 @@
 import numpy as np
+import pickle
 from sotabenchapi.core import BenchmarkResult, check_inputs
 import tqdm

 from sotabencheval.utils import AverageMeter
 from .utils import top_k_accuracy_score

-class ImageNet:
+
+
+class ImageNetEvaluator(object):
+    """`ImageNet <https://www.sotabench.com/benchmark/imagenet>`_ benchmark.
+
+    Examples:
+        Evaluate a ResNeXt model from the torchvision repository:
+
+        .. code-block:: python
+
+            import numpy as np
+            import PIL
+            import torch
+            from sotabencheval.image_classification import ImageNetEvaluator
+            from torchvision.models.resnet import resnext101_32x8d
+            import torchvision.transforms as transforms
+            from torchvision.datasets import ImageNet
+            from torch.utils.data import DataLoader
+
+            model = resnext101_32x8d(pretrained=True)
+
+            # Define the transforms needed to convert ImageNet data to the
+            # expected model input
+            normalize = transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+            )
+            input_transform = transforms.Compose([
+                transforms.Resize(256, PIL.Image.BICUBIC),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ])
+
+            test_dataset = ImageNet(
+                './data',
+                split="val",
+                transform=input_transform,
+                target_transform=None,
+                download=True,
+            )
+
+            test_loader = DataLoader(
+                test_dataset,
+                batch_size=128,
+                shuffle=False,
+                num_workers=4,
+                pin_memory=True,
+            )
+
+            device = torch.device('cuda')
+            model = model.to(device)
+            model.eval()
+
+            evaluator = ImageNetEvaluator(
+                paper_model_name='ResNeXt-101-32x8d',
+                paper_arxiv_id='1611.05431')
+
+            with torch.no_grad():
+                for i, (input, target) in enumerate(test_loader):
+                    input = input.to(device=device, non_blocking=True)
+                    target = target.to(device=device, non_blocking=True)
+                    output = model(input)
+
+                    image_ids = [
+                        img[0].split('/')[-1].replace('.JPEG', '')
+                        for img in test_loader.dataset.imgs[
+                            i * test_loader.batch_size:(i + 1) * test_loader.batch_size]
+                    ]
+
+                    evaluator.update(dict(zip(image_ids, list(output.cpu().numpy()))))
+
+            print(evaluator.get_results())
+
+            evaluator.save()
+    """
+
     task = "Image Classification"

-    @classmethod
-    @check_inputs
-    def benchmark(
-        cls,
-        results_dict,
-        data_root: str = "./.data/vision/imagenet",
-        paper_model_name: str = None,
-        paper_arxiv_id: str = None,
-        paper_pwc_id: str = None,
-        paper_results: dict = None,
-        pytorch_hub_url: str = None,
-        model_description=None,
-    ) -> BenchmarkResult:
+    def __init__(self,
+                 paper_model_name: str = None,
+                 paper_arxiv_id: str = None,
+                 paper_pwc_id: str = None,
+                 paper_results: dict = None,
+                 pytorch_hub_url: str = None,
+                 model_description=None,):
         """Benchmarking function.

         Args:
-            results_dict (dict): dict with keys as image IDs and values as a 1D 1000 x 1 np.ndarrays
-                of logits. For example: {'ILSVRC2012_val_00000293': array([1.27443619e+01, ...]), ...}. There
-                should be 5000 key/value pairs for the validation set.
-            data_root (str): The location of the ImageNet dataset - change this
-                parameter when evaluating locally if your ImageNet data is
-                located in a different folder (or alternatively if you want to
-                download to an alternative location).
-            model_description (str, optional): Optional model description.
             paper_model_name (str, optional): The name of the model from the
                 paper - if you want to link your build to a machine learning
                 paper. See the ImageNet benchmark page for model names,
@@ -56,64 +114,108 @@ def benchmark(
             pytorch_hub_url (str, optional): Optional linking to PyTorch Hub
                 url if your model is linked there; e.g:
                 'nvidia_deeplearningexamples_waveglow'.
+            model_description (str, optional): Optional model description.
         """

-        print("Benchmarking on ImageNet...")
+        self.paper_model_name = paper_model_name
+        self.paper_arxiv_id = paper_arxiv_id
+        self.paper_pwc_id = paper_pwc_id
+        self.paper_results = paper_results
+        self.pytorch_hub_url = pytorch_hub_url
+        self.model_description = model_description

-        config = locals()
+        self.top1 = None
+        self.top5 = None

-        try:
-            test_dataset = cls.dataset(
-                data_root,
-                split="val",
-                transform=cls.input_transform,
-                target_transform=None,
-                download=True,
-            )
-        except Exception:
-            test_dataset = cls.dataset(
-                data_root,
-                split="val",
-                transform=cls.input_transform,
-                target_transform=None,
-                download=False,
-            )
+        with open('imagenet_val_targets.pkl', 'rb') as handle:
+            self.targets = pickle.load(handle)
+
+        self.outputs = {}
+        self.results = None
+
+    def update(self, output_dict: dict):
+        """
+        Update the evaluator with new results.
+
+        :param output_dict (dict): Keys are image IDs, and each value should be a 1D np.ndarray of size 1000
+            containing logits for that image ID.
+        :return: void - updates self.outputs with the new IDs and predictions
+
+        Examples:
+            Update the evaluator with two results:
+
+            .. code-block:: python
+
+                my_evaluator.update({'ILSVRC2012_val_00000293': np.array([1.04243, ...]),
+                                     'ILSVRC2012_val_00000294': np.array([-2.3677, ...])})
+        """
+
+        self.outputs = dict(list(self.outputs.items()) + list(output_dict.items()))
+
+    def get_results(self):
+        """
+        Gets the results for the evaluator. This method only runs if predictions for all 5,000 ImageNet validation
+        images are available; otherwise it raises an error that reports the missing or unmatched IDs.
+
+        :return: dict with Top 1 and Top 5 Accuracy
+        """
+
+        if set(self.targets.keys()) != set(self.outputs.keys()):
+            missing_ids = set(self.targets.keys()) - set(self.outputs.keys())
+            unmatched_ids = set(self.outputs.keys()) - set(self.targets.keys())

-        top1 = AverageMeter()
-        top5 = AverageMeter()
+            if len(unmatched_ids) > 0:
+                raise AttributeError('''There are {mis_no} missing and {un_no} unmatched image IDs\n\n'''
+                                     '''Missing IDs are {missing}\n\n'''
+                                     '''Unmatched IDs are {unmatched}'''.format(mis_no=len(missing_ids),
+                                                                                un_no=len(unmatched_ids),
+                                                                                missing=missing_ids,
+                                                                                unmatched=unmatched_ids))
+            else:
+                raise AttributeError('''There are {mis_no} missing image IDs\n\n'''
+                                     '''Missing IDs are {missing}'''.format(mis_no=len(missing_ids),
+                                                                            missing=missing_ids))

-        for i, (_, target) in enumerate(tqdm.tqdm(test_dataset)):
-            image_id = test_dataset.imgs[i][0].split('/')[-1].replace('.JPEG', '')
-            output = results_dict[image_id]
-            target = target.cpu().numpy()
+        # Do the calculation only if we have all the results...
+        self.top1 = AverageMeter()
+        self.top5 = AverageMeter()

+        for i, dict_key in enumerate(tqdm.tqdm(self.targets.keys())):
+            output = self.outputs[dict_key]
+            target = self.targets[dict_key]
             prec1 = top_k_accuracy_score(y_true=target, y_pred=np.array([output]), k=1)
             prec5 = top_k_accuracy_score(y_true=target, y_pred=np.array([output]), k=5)
-            top1.update(prec1, 1)
-            top5.update(prec5, 1)
-
-        final_results = {
-            'Top 1 Accuracy': prec1.avg,
-            'Top 5 Accuracy': prec5.avg
-        }
-
-        print(
-            " * Acc@1 {top1:.3f} Acc@5 {top5:.3f}".format(
-                top1=final_results["Top 1 Accuracy"],
-                top5=final_results["Top 5 Accuracy"],
-            )
-        )
+            self.top1.update(prec1, 1)
+            self.top5.update(prec5, 1)
+
+        self.results = {'Top 1 Accuracy': self.top1.avg, 'Top 5 Accuracy': self.top5.avg}
+
+        return self.results
+
+    def save(self):
+        """
+        Calculate results and then put them into a BenchmarkResult object.
+
+        On the sotabench.com server, this will produce a JSON serialisation of the results and record
+        them on the platform.
+
+        :return: BenchmarkResult object with results and metadata
+        """
+
+        if not self.results:
+            self.get_results()

         return BenchmarkResult(
-            task=cls.task,
-            config=config,
-            dataset=cls.dataset.__name__,
-            results=final_results,
-            pytorch_hub_id=pytorch_hub_url,
-            model=paper_model_name,
-            model_description=model_description,
-            arxiv_id=paper_arxiv_id,
-            pwc_id=paper_pwc_id,
-            paper_results=paper_results,
+            task=self.task,
+            config={},
+            dataset='ImageNet',
+            results=self.results,
+            pytorch_hub_id=self.pytorch_hub_url,
+            model=self.paper_model_name,
+            model_description=self.model_description,
+            arxiv_id=self.paper_arxiv_id,
+            pwc_id=self.paper_pwc_id,
+            paper_results=self.paper_results,
             run_hash=None,
         )
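The batch design lets predictions be fed to the evaluator incrementally instead of as one results_dict. Below is a minimal sketch of that flow (not part of the commit), using two illustrative image IDs and random logits in place of real model outputs; it assumes imagenet_val_targets.pkl is available, since __init__ loads it.

    import numpy as np
    from sotabencheval.image_classification import ImageNetEvaluator

    evaluator = ImageNetEvaluator(
        paper_model_name='ResNeXt-101-32x8d',
        paper_arxiv_id='1611.05431',
    )

    # Each update() merges one batch of {image_id: logits} pairs into
    # evaluator.outputs; the dict(list(...) + list(...)) merge behaves like
    # evaluator.outputs.update(batch).
    batch_ids = ['ILSVRC2012_val_00000293', 'ILSVRC2012_val_00000294']  # illustrative IDs
    logits = np.random.randn(len(batch_ids), 1000)  # stand-in for model outputs
    evaluator.update(dict(zip(batch_ids, list(logits))))

    # get_results() raises AttributeError until every ID in the targets pickle
    # has a matching prediction; save() calls it automatically when needed and
    # wraps the metrics in a BenchmarkResult.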

sotabencheval/image_classification/utils.py

Lines changed: 4 additions & 3 deletions

@@ -1,9 +1,10 @@
 import numpy as np

 def top_k_accuracy_score(y_true, y_pred, k=5, normalize=True):
-    """Top k Accuracy classification score.
-    """
-    assert(y_true.shape == 1)  # should be 1D, each index is obs true label
+    """Top k Accuracy classification score."""
+
+    if len(y_true.shape) == 2:
+        y_true = y_true[0]  # should be one-dimensional

     num_obs, num_labels = y_pred.shape
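The removed assert compared a shape tuple against the integer 1, so it failed on every call; the replacement flattens a 2D y_true instead. A short usage sketch of the patched helper (assuming, as the evaluator's call sites suggest, that normalize=True returns the fraction of observations whose true label is among the k highest-scoring classes):

    import numpy as np
    from sotabencheval.image_classification.utils import top_k_accuracy_score

    y_pred = np.array([[0.1, 0.3, 0.6],   # top-2 classes: 2, 1
                       [0.6, 0.3, 0.1]])  # top-2 classes: 0, 1
    y_true = np.array([2, 2])

    print(top_k_accuracy_score(y_true, y_pred, k=2))  # expected 0.5: only the
                                                      # first label is in its top-2

    # A 2D y_true such as np.array([[2, 2]]) is now reduced to its first row
    # rather than tripping the old, always-failing assert.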
