 import numpy as np
+import pickle
 from sotabenchapi.core import BenchmarkResult, check_inputs
 import tqdm
 
 from sotabencheval.utils import AverageMeter
 from .utils import top_k_accuracy_score
 
-class ImageNet:
+
+
+class ImageNetEvaluator(object):
+    """`ImageNet <https://www.sotabench.com/benchmark/imagenet>`_ benchmark.
+
+    Examples:
+        Evaluate a ResNeXt model from the torchvision repository:
+
+        .. code-block:: python
+
+            import numpy as np
+            import PIL.Image
+            import torch
+            from sotabencheval.image_classification import ImageNetEvaluator
+            from torchvision.models.resnet import resnext101_32x8d
+            import torchvision.transforms as transforms
+            from torchvision.datasets import ImageNet
+            from torch.utils.data import DataLoader
+
+            model = resnext101_32x8d(pretrained=True)
+
+            # Define the transforms needed to convert ImageNet data to the
+            # input expected by the model
+            normalize = transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+            )
+            input_transform = transforms.Compose([
+                transforms.Resize(256, PIL.Image.BICUBIC),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ])
+
+            test_dataset = ImageNet(
+                './data',
+                split="val",
+                transform=input_transform,
+                target_transform=None,
+                download=True,
+            )
+
+            test_loader = DataLoader(
+                test_dataset,
+                batch_size=128,
+                shuffle=False,
+                num_workers=4,
+                pin_memory=True,
+            )
+
+            device = torch.device('cuda')
+            model = model.to(device)
+            model.eval()
+
+            evaluator = ImageNetEvaluator(
+                paper_model_name='ResNeXt-101-32x8d',
+                paper_arxiv_id='1611.05431')
+
+            with torch.no_grad():
+                for i, (input, target) in enumerate(test_loader):
+                    input = input.to(device=device, non_blocking=True)
+                    target = target.to(device=device, non_blocking=True)
+                    output = model(input)
+
+                    image_ids = [
+                        img[0].split('/')[-1].replace('.JPEG', '')
+                        for img in test_loader.dataset.imgs[
+                            i * test_loader.batch_size:(i + 1) * test_loader.batch_size]
+                    ]
+
+                    evaluator.update(dict(zip(image_ids, list(output.cpu().numpy()))))
+
+            print(evaluator.get_results())
+
+            evaluator.save()
+    """
+
     task = "Image Classification"
 
-    @classmethod
-    @check_inputs
-    def benchmark(
-        cls,
-        results_dict,
-        data_root: str = "./.data/vision/imagenet",
-        paper_model_name: str = None,
-        paper_arxiv_id: str = None,
-        paper_pwc_id: str = None,
-        paper_results: dict = None,
-        pytorch_hub_url: str = None,
-        model_description=None,
-    ) -> BenchmarkResult:
+    def __init__(self,
+                 paper_model_name: str = None,
+                 paper_arxiv_id: str = None,
+                 paper_pwc_id: str = None,
+                 paper_results: dict = None,
+                 pytorch_hub_url: str = None,
+                 model_description=None):
2490 """Benchmarking function.
2591
2692 Args:
27- results_dict (dict): dict with keys as image IDs and values as a 1D 1000 x 1 np.ndarrays
28- of logits. For example: {'ILSVRC2012_val_00000293': array([1.27443619e+01, ...]), ...}. There
29- should be 5000 key/value pairs for the validation set.
30- data_root (str): The location of the ImageNet dataset - change this
31- parameter when evaluating locally if your ImageNet data is
32- located in a different folder (or alternatively if you want to
33- download to an alternative location).
34- model_description (str, optional): Optional model description.
3593 paper_model_name (str, optional): The name of the model from the
3694 paper - if you want to link your build to a machine learning
3795 paper. See the ImageNet benchmark page for model names,
@@ -56,64 +114,108 @@ def benchmark(
             pytorch_hub_url (str, optional): Optional linking to PyTorch Hub
                 url if your model is linked there; e.g:
                 'nvidia_deeplearningexamples_waveglow'.
+            model_description (str, optional): Optional model description.
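+
+        Examples:
+            A minimal instantiation, reusing the model name and arXiv ID from
+            the class-level example above (illustrative values, not a
+            requirement of the API):
+
+            .. code-block:: python
+
+                from sotabencheval.image_classification import ImageNetEvaluator
+
+                evaluator = ImageNetEvaluator(
+                    paper_model_name='ResNeXt-101-32x8d',
+                    paper_arxiv_id='1611.05431')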
59118 """
60119
-        print("Benchmarking on ImageNet...")
+        self.paper_model_name = paper_model_name
+        self.paper_arxiv_id = paper_arxiv_id
+        self.paper_pwc_id = paper_pwc_id
+        self.paper_results = paper_results
+        self.pytorch_hub_url = pytorch_hub_url
+        self.model_description = model_description
 
-        config = locals()
+        self.top1 = None
+        self.top5 = None
 
-        try:
-            test_dataset = cls.dataset(
-                data_root,
-                split="val",
-                transform=cls.input_transform,
-                target_transform=None,
-                download=True,
-            )
-        except Exception:
-            test_dataset = cls.dataset(
-                data_root,
-                split="val",
-                transform=cls.input_transform,
-                target_transform=None,
-                download=False,
-            )
+        # Ground-truth targets for the ImageNet validation set, keyed by
+        # image ID (e.g. 'ILSVRC2012_val_00000293')
+        with open('imagenet_val_targets.pkl', 'rb') as handle:
+            self.targets = pickle.load(handle)
+
+        self.outputs = {}
+        self.results = None
+
+    def update(self, output_dict: dict):
+        """
+        Update the evaluator with new results.
+
+        :param output_dict: dict where keys are image IDs and each value is a
+            1D np.ndarray of size 1000 containing the logits for that image ID.
+        :return: void - updates self.outputs with the new IDs and predictions
+
+        Examples:
+            Update the evaluator with two results:
+
+            .. code-block:: python
+
+                my_evaluator.update({'ILSVRC2012_val_00000293': np.array([1.04243, ...]),
+                                     'ILSVRC2012_val_00000294': np.array([-2.3677, ...])})
+        """
+
+        self.outputs.update(output_dict)
+
+    def get_results(self):
157+ """
158+ Gets the results for the evaluator. This method only runs if predictions for all 5,000 ImageNet validation
159+ images are available. Otherwise raises an error and informs you of the missing or unmatched IDs.
160+
161+ :return: dict with Top 1 and Top 5 Accuracy
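+
+        Examples:
+            A sketch of retrieving results once every validation image has a
+            prediction (``my_evaluator`` as in the update() example; the keys
+            below are the ones this method returns):
+
+            .. code-block:: python
+
+                results = my_evaluator.get_results()
+                print(results['Top 1 Accuracy'], results['Top 5 Accuracy'])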
162+ """
163+
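+        # Every validation image ID must have exactly one prediction before
+        # accuracy is computed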
+        if set(self.targets.keys()) != set(self.outputs.keys()):
+            missing_ids = set(self.targets.keys()) - set(self.outputs.keys())
+            unmatched_ids = set(self.outputs.keys()) - set(self.targets.keys())
 
-        top1 = AverageMeter()
-        top5 = AverageMeter()
+            if len(unmatched_ids) > 0:
+                raise AttributeError(
+                    'There are {mis_no} missing and {un_no} unmatched image IDs\n\n'
+                    'Missing IDs are {missing}\n\n'
+                    'Unmatched IDs are {unmatched}'.format(
+                        mis_no=len(missing_ids),
+                        un_no=len(unmatched_ids),
+                        missing=missing_ids,
+                        unmatched=unmatched_ids))
+            else:
+                raise AttributeError(
+                    'There are {mis_no} missing image IDs\n\n'
+                    'Missing IDs are {missing}'.format(
+                        mis_no=len(missing_ids),
+                        missing=missing_ids))
 
-        for i, (_, target) in enumerate(tqdm.tqdm(test_dataset)):
-            image_id = test_dataset.imgs[i][0].split('/')[-1].replace('.JPEG', '')
-            output = results_dict[image_id]
-            target = target.cpu().numpy()
+        # Do the calculation only if we have all the results...
+        self.top1 = AverageMeter()
+        self.top5 = AverageMeter()
 
+        for dict_key in tqdm.tqdm(self.targets.keys()):
+            output = self.outputs[dict_key]
+            target = self.targets[dict_key]
             prec1 = top_k_accuracy_score(y_true=target, y_pred=np.array([output]), k=1)
             prec5 = top_k_accuracy_score(y_true=target, y_pred=np.array([output]), k=5)
-            top1.update(prec1, 1)
-            top5.update(prec5, 1)
-
-        final_results = {
-            'Top 1 Accuracy': prec1.avg,
-            'Top 5 Accuracy': prec5.avg
-        }
-
-        print(
-            " * Acc@1 {top1:.3f} Acc@5 {top5:.3f}".format(
-                top1=final_results["Top 1 Accuracy"],
-                top5=final_results["Top 5 Accuracy"],
-            )
-        )
+            self.top1.update(prec1, 1)
+            self.top5.update(prec5, 1)
+
+        self.results = {
+            'Top 1 Accuracy': self.top1.avg,
+            'Top 5 Accuracy': self.top5.avg,
+        }
+
+        return self.results
+
+    def save(self):
+        """
+        Calculate the results and then put them into a BenchmarkResult object.
+
+        On the sotabench.com server, this will produce a JSON file serialisation
+        and the results will be recorded on the platform.
+
+        :return: BenchmarkResult object with results and metadata
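+
+        Examples:
+            The typical end-of-script call (a sketch; assumes the evaluator has
+            been updated with predictions for the full validation set):
+
+            .. code-block:: python
+
+                my_evaluator.save()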
204+ """
205+
206+ if not self .results :
207+ self .get_results ()
106208
         return BenchmarkResult(
-            task=cls.task,
-            config=config,
-            dataset=cls.dataset.__name__,
-            results=final_results,
-            pytorch_hub_id=pytorch_hub_url,
-            model=paper_model_name,
-            model_description=model_description,
-            arxiv_id=paper_arxiv_id,
-            pwc_id=paper_pwc_id,
-            paper_results=paper_results,
+            task=self.task,
+            config={},
+            dataset='ImageNet',
+            results=self.results,
+            pytorch_hub_id=self.pytorch_hub_url,
+            model=self.paper_model_name,
+            model_description=self.model_description,
+            arxiv_id=self.paper_arxiv_id,
+            pwc_id=self.paper_pwc_id,
+            paper_results=self.paper_results,
             run_hash=None,
         )