Skip to content

Commit a121b98

Browse files
committed
add more docstrings; type checking
1 parent 33f5548 commit a121b98

File tree

2 files changed

+59
-13
lines changed

2 files changed

+59
-13
lines changed

sotabencheval/core/cache.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,19 @@
44

55

66
def cache_value(value):
7+
"""
8+
Takes in a value and puts it in a format ready for hashing + caching
9+
10+
Why? In sotabench we hash the output after the first batch as an indication of whether the model has changed or not.
11+
If the model hasn't changed, then we don't run the whole evaluation on the server - but return the same results
12+
as before. This speeds up evaluation - making "continuous evaluation" more feasible...it also means lower
13+
GPU costs for us :).
14+
15+
We apply some rounding and reformatting so small low precision changes do not change the hash.
16+
17+
:param value: example model output
18+
:return: formatted value (rounded and ready for hashing)
19+
"""
720
if isinstance(value, (str, int, bool)) or value is None:
821
return value
922
elif isinstance(value, float):

sotabencheval/utils.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,25 @@ def update(self, val, n=1):
3636
def calculate_batch_hash(output):
    """Calculate the hash for the output of a batch.

    The output passed into this method is stringified, and a SHA-256 digest is
    taken of the contents. For example, it could be a list of predictions that
    is passed in.

    Args:
        output: data to be hashed

    Returns:
        str: hex digest of the stringified output
    """
    digest = hashlib.sha256(str(output).encode("utf-8"))
    return digest.hexdigest()
4648

4749

48-
def change_root_if_server(root, server_root):
50+
def change_root_if_server(root: str, server_root: str):
4951
"""
50-
:param root: string with a user-specified root
51-
:param server_root: string with a server root
52+
This method checks whether code is being executed on the sotabench server - if so it returns
53+
server_root, else root. Written as a method so the user doesn't have to fiddle with environmental
54+
variables.
55+
56+
:param root: (str) a user-specified root
57+
:param server_root: (str) a server root
5258
:return: server_root if SOTABENCH_SERVER env variable is set, else root
5359
"""
5460
check_server = os.environ.get("SOTABENCH_SERVER")
@@ -61,41 +67,68 @@ def change_root_if_server(root, server_root):
6167

6268
def is_server():
    """
    Checks whether code is being executed on the sotabench server; if so,
    returns True, else False.

    Uses the SOTABENCH_SERVER env variable to determine whether code is being
    run on the server.

    You can use this function for your control flow for server specific
    settings - e.g. the data paths.

    Examples:

    .. code-block:: python

        from sotabencheval.utils import is_server

        if is_server():
            DATA_ROOT = './.data/vision/imagenet'
        else: # local settings
            DATA_ROOT = '/home/ubuntu/my_data/'

    :return: bool - whether the code is being run on the server or not
    """
    # The comparison itself yields the bool; locally the variable is unset,
    # so .get() returns None and this evaluates to False.
    return os.environ.get("SOTABENCH_SERVER") == 'true'
7394

7495

75-
def set_env_on_server(env_name: str, value):
    """
    If run on the sotabench server, sets an environment variable with a given
    name to value (cast to str); does nothing when run locally.

    :param env_name: (str) environment variable name
    :param value: value to set if executed on sotabench
    :return: bool - whether code is being run on the server
    """
    on_server = is_server()
    if on_server:
        # Environment values must be strings, hence the explicit cast.
        os.environ[env_name] = str(value)
    return on_server
87108

88109

89-
def get_max_memory_allocated(device: str = 'cuda'):
    """
    Finds out the maximum memory allocated on the given device, then clears
    the max-memory-allocated counter.

    This currently only works for PyTorch models.

    TODO: Support TensorFlow and MXNet.

    :param device: (str) - name of device (Torch style) -> e.g. 'cuda'
    :return: if torch is in the environment, max memory allocated, else None
    """
    try:
        import torch
    except ImportError:
        # No torch available: memory tracking is unsupported, signal with None.
        return None
    peak_memory = torch.cuda.max_memory_allocated(device=device)
    # Reset the counter so the next call reports a fresh peak.
    torch.cuda.reset_max_memory_allocated(device=device)
    return peak_memory
97128

98-
# below utilities are taken from the torchvision repository
129+
# Below the utilities have been taken directly from the torchvision repository
130+
# Contains helper functions for unzipping and making directories
131+
# https://github.com/pytorch/vision/tree/master/torchvision
99132

100133

101134
def makedir_exist_ok(dirpath):

0 commit comments

Comments (0)