@@ -1,4 +1,5 @@
 # Some of the processing logic here is based on the torchvision ImageNet dataset
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py

 import numpy as np
 import os
@@ -12,6 +13,7 @@
 from sotabencheval.utils import get_max_memory_allocated
 from sotabencheval.image_classification.utils import top_k_accuracy_score

+
 ARCHIVE_DICT = {
     'labels': {
         'url': 'https://github.com/paperswithcode/sotabench-eval/releases/download/0.01/imagenet_val_targets.pkl',
@@ -37,6 +39,7 @@ class ImageNetEvaluator(object):
         from torch.utils.data import DataLoader

         from sotabencheval.image_classification import ImageNetEvaluator
+        from sotabencheval.utils import is_server

         if is_server():
             DATA_ROOT = './.data/vision/imagenet'
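
This hunk patches the class docstring's usage example. For orientation, a minimal sketch of how that example typically continues, assuming a model that emits ImageNet logits (the model name and paper ID below are illustrative metadata, not part of this diff):

```python
# Hypothetical continuation of the docstring example above.
from sotabencheval.image_classification import ImageNetEvaluator
from sotabencheval.utils import is_server

# On the sotabench server the data root is fixed; locally, point it anywhere.
DATA_ROOT = './.data/vision/imagenet' if is_server() else './imagenet'

evaluator = ImageNetEvaluator(
    root=DATA_ROOT,
    model_name='ResNeXt-101-32x8d',   # example metadata only
    paper_arxiv_id='1611.05431',      # example metadata only
)
```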
@@ -107,11 +110,12 @@ def __init__(self,
                  paper_pwc_id: str = None,
                  paper_results: dict = None,
                  model_description=None,):
-        """Benchmarking function.
+        """Initializes an ImageNetEvaluator object.

         Args:
             root (string): Root directory of the ImageNet Dataset - where the
-                label data is located (or will be downloaded to).
+                label data is located (or will be downloaded to). Note that this does
+                not download the full ImageNet dataset, only the annotation information.
             model_name (str, optional): The name of the model from the
                 paper - if you want to link your build to a model from a
                 machine learning paper. See the ImageNet benchmark page for model names,
@@ -135,11 +139,12 @@ def __init__(self,
                 'Top 5 Accuracy'.
             model_description (str, optional): Optional model description.
         """
-
         root = self.root = os.path.expanduser(change_root_if_server(
             root=root,
             server_root="./.data/vision/imagenet"))

+        # Model metadata
+
         self.model_name = model_name
         self.paper_arxiv_id = paper_arxiv_id
         self.paper_pwc_id = paper_pwc_id
@@ -148,15 +153,19 @@ def __init__(self,

         self.top1 = AverageMeter()
         self.top5 = AverageMeter()
-
         self.load_targets()

         self.outputs = {}
         self.results = None
+
+        # Backend variables for hashing and caching
+
         self.first_batch_processed = False
         self.batch_hash = None
         self.cached_results = False

+        # Speed and memory metrics
+
         self.speed_mem_metrics = {}
         self.init_time = time.time()

@@ -167,9 +176,13 @@ def cache_exists(self):
         then sets self.results to cached results and returns True.

         You can use this property for control flow to break a for loop over a dataset
-        after the first iteration. This prevents rerunning the same calculation for the
+        after the first iteration. This prevents re-running the same calculation for the
         same model twice.

+        Q: Why should the user use this?
+        A: If you want fast "continuous evaluation" and don't want to rerun the same model
+        over and over each time you commit something new to your repository.
+
         Examples:
             Breaking a for loop for a PyTorch evaluation

@@ -192,9 +205,10 @@ def cache_exists(self):

             evaluator.save()  # uses the cached results

-        :return:
-        """
+        This logic is for the server; it will not break the loop if you evaluate locally.

+        :return: bool or None (if not in check mode)
+        """
         if not self.first_batch_processed:
             raise ValueError('No batches of data have been processed so no batch_hash exists')

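
To make the control flow concrete, a minimal sketch of the loop-breaking pattern this docstring describes; `model`, `data_loader`, and `to_output_dict` are placeholders, and only the evaluator calls come from this diff:

```python
# Sketch of breaking an evaluation loop on a server-side cache hit.
# `to_output_dict` stands in for the {image_id: logits} construction
# sketched further down alongside add().
import torch

with torch.no_grad():
    for image_ids, images in data_loader:
        output = model(images)
        evaluator.add(to_output_dict(image_ids, output))
        if evaluator.cache_exists:
            # Server only: results for this batch hash are already cached,
            # so the rest of the dataset can be skipped. Locally this check
            # never triggers.
            break

evaluator.save()  # uses the cached results when they exist
```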
@@ -217,7 +231,8 @@ def cache_exists(self):

     def load_targets(self):
         """
-        Downloads ImageNet labels and IDs and puts into self.root, then loads at self.targets
+        Downloads ImageNet labels and IDs, puts them into self.root, then loads them to self.targets
+
         :return: void - updates self.targets with the ImageNet validation data labels, and downloads if
             the pickled validation data is not in the root location
         """
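
The method body is collapsed in this diff. As a rough illustration of the download-then-unpickle pattern that ARCHIVE_DICT and this docstring imply (an assumption, not the library's actual implementation):

```python
# Illustrative sketch only - the real load_targets body is not shown in this diff.
import os
import pickle
from urllib.request import urlretrieve

def load_targets(self):
    archive = ARCHIVE_DICT['labels']
    pickle_path = os.path.join(self.root, os.path.basename(archive['url']))
    if not os.path.exists(pickle_path):      # download once, then reuse
        os.makedirs(self.root, exist_ok=True)
        urlretrieve(archive['url'], pickle_path)
    with open(pickle_path, 'rb') as f:
        self.targets = pickle.load(f)        # validation labels keyed by image ID
```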
@@ -233,7 +248,7 @@ def add(self, output_dict: dict):
         """
         Updates the evaluator with new results

-        :param output_dict (dict): Where keys are image IDs, and each value should be an 1D np.ndarray of size 1000
+        :param output_dict: (dict) Where keys are image IDs, and each value should be a 1D np.ndarray of size 1000
             containing logits for that image ID.
         :return: void - updates self.outputs with the new IDs and predictions

@@ -245,7 +260,6 @@ def add(self, output_dict: dict):
             my_evaluator.add({'ILSVRC2012_val_00000293': np.array([1.04243, ...]),
                               'ILSVRC2012_val_00000294': np.array([-2.3677, ...])})
         """
-
         if not output_dict:
             print('Empty output_dict; will not process')
             return
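
For a batched workflow, a hedged sketch of assembling the `output_dict` that `add` expects; `model`, `images`, and `image_ids` are illustrative placeholders, and the only contract taken from the docstring is key = image ID, value = 1D array of 1000 logits:

```python
# Sketch: turn one batch of model outputs into the dict that add() expects.
import numpy as np
import torch

with torch.no_grad():
    logits = model(images)            # shape: (batch_size, 1000)

output_dict = {
    image_id: row.cpu().numpy()       # 1D np.ndarray of size 1000
    for image_id, row in zip(image_ids, logits)
}
my_evaluator.add(output_dict)
```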
@@ -312,21 +326,28 @@ def get_results(self):
         return self.results

     def reset_time(self):
+        """
+        Simple method to reset the timer self.init_time. Often used before an evaluation
+        loop, so that model and data loading time is not counted in the evaluation time.
+
+        :return: void - resets self.init_time
+        """
         self.init_time = time.time()

     def save(self):
         """
-        Calculate results and then puts into a BenchmarkResult object
+        Calculate results and then put them into a BenchmarkResult object

-        On the sotabench.com server, this will produce a JSON file serialisation and results will be recorded
-        on the platform.
+        On the sotabench.com server, this will produce a JSON file serialisation in sotabench_results.json and
+        results will be recorded on the platform.

         :return: BenchmarkResult object with results and metadata
         """
-
         # recalculate to ensure no mistakes made during batch-by-batch metric calculation
         self.get_results()

+        # If this is the first time the model is run, then we record evaluation time information
+
         if not self.cached_results:
             exec_speed = (time.time() - self.init_time)
             self.speed_mem_metrics['Tasks / Evaluation Time'] = len(self.outputs) / exec_speed
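
Read together with reset_time above, this timing logic suggests the following calling pattern; `run_evaluation_loop` is an illustrative placeholder for the add() loop sketched earlier:

```python
# Sketch of the timing/caching interplay around save().
evaluator.reset_time()            # start the clock after model/dataset setup
run_evaluation_loop(evaluator)    # placeholder: calls evaluator.add() per batch
results = evaluator.save()        # recomputes metrics; on the server this also
                                  # writes sotabench_results.json and, on a first
                                  # (uncached) run, records speed metrics
```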