
Commit 661b622

Merge pull request #36 from SFI-Visual-Intelligence/christian/sphinx-autoapi
Add autoapi to crawl our modules and populate the docs, closes #14
2 parents 891f09b + ebadcdf commit 661b622

9 files changed: +176 -18


doc/about.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,3 +1,3 @@
 # About this code
 
-Work in progress ...
+Work is still in progress ...
```

doc/conf.py

Lines changed: 4 additions & 0 deletions
````diff
@@ -7,8 +7,12 @@
 
 extensions = [
     "myst_parser",  # in order to use markdown
+    "autoapi.extension",  # in order to generate API documentation
 ]
 
+# search this directory for Python files
+autoapi_dirs = ["../utils"]
+
 myst_enable_extensions = [
     "colon_fence",  # ::: can be used instead of ``` for better rendering
 ]
````
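That is the whole integration: `autoapi_dirs` points the crawler at the `utils` package, and reference pages are generated at build time with no per-module stubs. For anyone tuning the output later, a hedged sketch of commonly used sphinx-autoapi settings (illustrative values, not part of this commit):

```python
# doc/conf.py -- optional sphinx-autoapi knobs (illustrative, not in this commit)
autoapi_type = "python"            # language of the crawled sources
autoapi_root = "autoapi"           # where generated pages land in the doc tree
autoapi_add_toctree_entry = True   # link the generated pages from the root toctree
autoapi_options = [
    "members",              # document module/class members
    "undoc-members",        # include objects without docstrings
    "show-inheritance",     # display base classes
    "show-module-summary",  # add a summary table per module
]
```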

main.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -109,7 +109,7 @@ def main():
             metrics(y, preds)
 
             break
-        print(metrics.__getmetrics__())
+        print(metrics.accumulate())
         print("Dry run completed successfully.")
         exit(0)
 
@@ -135,8 +135,8 @@ def main():
             preds = th.argmax(logits, dim=1)
             metrics(y, preds)
 
-        wandb.log(metrics.__getmetrics__(str_prefix="Train "))
-        metrics.__resetvalues__()
+        wandb.log(metrics.accumulate(str_prefix="Train "))
+        metrics.reset()
 
         evalloss = []
         # Eval loop start
@@ -151,8 +151,8 @@ def main():
             preds = th.argmax(logits, dim=1)
             metrics(y, preds)
 
-        wandb.log(metrics.__getmetrics__(str_prefix="Evaluation "))
-        metrics.__resetvalues__()
+        wandb.log(metrics.accumulate(str_prefix="Evaluation "))
+        metrics.reset()
 
         wandb.log(
             {
```
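`__getmetrics__` and `__resetvalues__` looked like protocol dunders but were ordinary methods; Python reserves double-underscore names for its own protocols, so renaming them to `accumulate` and `reset` is the right call. The per-epoch cadence the updated loop follows, as a minimal sketch (model, loaders, and `metrics` are stand-ins taken from the surrounding code, not a runnable script on their own):

```python
# Minimal sketch of the log-and-reset cadence per epoch (names from main.py).
for epoch in range(epochs):
    for x, y in trainloader:
        logits = model(x)
        preds = th.argmax(logits, dim=1)
        metrics(y, preds)                               # buffer batch scores
    wandb.log(metrics.accumulate(str_prefix="Train "))  # average and log
    metrics.reset()                                     # clear before the eval loop
```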

utils/dataloaders/datasources.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -1,3 +1,10 @@
+"""This module contains the data sources for the datasets used in the experiments.
+
+The data sources are defined as dictionaries with the following keys:
+- train: A list containing the URL, filename, and MD5 hash of the training data.
+- test: A list containing the URL, filename, and MD5 hash of the test data.
+"""
+
 USPS_SOURCE = {
     "train": [
         "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
```

utils/dataloaders/mnist_0_3.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 class MNISTDataset0_3(Dataset):
     """
     A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
+
     Parameters
     ----------
     data_path : Path
@@ -20,6 +21,7 @@ class MNISTDataset0_3(Dataset):
         A function/transform that takes in an image and returns a transformed version. Default is None.
     download : bool, optional
         If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
+
     Attributes
     ----------
     data_path : Path
@@ -40,6 +42,7 @@ class MNISTDataset0_3(Dataset):
         Indices of the labels that are less than 4.
     length : int
         The number of samples in the dataset.
+
     Methods
     -------
     _parse_labels(train)
```
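These three hunks only insert blank lines before the `Parameters`, `Attributes`, and `Methods` headers. NumPy-style docstring parsers generally want a blank line before a section underline, which starts to matter now that autoapi renders these docstrings. A minimal illustration on a hypothetical class:

```python
class Example:
    """One-line summary, followed by a blank line before each section.

    Parameters
    ----------
    data_path : Path
        Directory the data is read from.
    """
```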

utils/dataloaders/usps_0_6.py

Lines changed: 47 additions & 7 deletions
```diff
@@ -26,7 +26,7 @@ class USPSDataset0_6(Dataset):
     Args
     ----
     data_path : pathlib.Path
-        Path to the USPS dataset file.
+        Path to the data directory.
     train : bool, optional
         Mode of the dataset.
     transform : callable, optional
@@ -60,18 +60,29 @@ class USPSDataset0_6(Dataset):
 
     Examples
     --------
+    >>> from torchvision import transforms
     >>> from src.datahandlers import USPSDataset0_6
-    >>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train")
+    >>> transform = transforms.Compose([
+    ...     transforms.Resize((16, 16)),
+    ...     transforms.ToTensor(),
+    ... ])
+    >>> dataset = USPSDataset0_6(
+    ...     data_path="data",
+    ...     transform=transform,
+    ...     download=True,
+    ...     train=True,
+    ... )
     >>> len(dataset)
     5460
     >>> data, target = dataset[0]
     >>> data.shape
-    (16, 16)
+    (1, 16, 16)
     >>> target
-    6
+    tensor([1., 0., 0., 0., 0., 0., 0.])
     """
 
     filename = "usps.h5"
+    num_classes = 7
 
     def __init__(
         self,
@@ -85,7 +96,6 @@ def __init__(
         path = data_path if isinstance(data_path, Path) else Path(data_path)
         self.filepath = path / self.filename
         self.transform = transform
-        self.num_classes = 7  # 0-6
         self.mode = "train" if train else "test"
 
         # Download the dataset if it does not exist in a temporary directory
@@ -116,7 +126,24 @@ def _dataset_ok(self):
         return True
 
     def download(self, url, filepath, checksum, mode):
-        """Download the USPS dataset."""
+        """Download the USPS dataset, and save it as an HDF5 file.
+
+        Args
+        ----
+        url : str
+            URL to download the dataset from.
+        filepath : pathlib.Path
+            Path to save the downloaded dataset.
+        checksum : str
+            MD5 checksum of the downloaded file.
+        mode : str
+            Mode of the dataset, either train or test.
+
+        Raises
+        ------
+        ValueError
+            If the checksum of the downloaded file does not match the expected checksum.
+        """
 
         def reporthook(blocknum, blocksize, totalsize):
             """Report download progress."""
@@ -164,7 +191,20 @@ def reporthook(blocknum, blocksize, totalsize):
 
     @staticmethod
     def check_integrity(filepath, checksum):
-        """Check the integrity of the USPS dataset file."""
+        """Check the integrity of the USPS dataset file.
+
+        Args
+        ----
+        filepath : pathlib.Path
+            Path to the USPS dataset file.
+        checksum : str
+            MD5 checksum of the dataset file.
+
+        Returns
+        -------
+        bool
+            True if the checksum of the file matches the expected checksum, False otherwise.
+        """
 
         file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
```

utils/load_data.py

Lines changed: 30 additions & 1 deletion
```diff
@@ -4,6 +4,35 @@
 
 
 def load_data(dataset: str, *args, **kwargs) -> Dataset:
+    """
+    Load the dataset based on the dataset name.
+
+    Args
+    ----
+    dataset : str
+        Name of the dataset to load.
+    *args : list
+        Additional arguments for the dataset class.
+    **kwargs : dict
+        Additional keyword arguments for the dataset class.
+
+    Returns
+    -------
+    dataset : torch.utils.data.Dataset
+        Dataset object.
+
+    Raises
+    ------
+    NotImplementedError
+        If the dataset is not implemented.
+
+    Examples
+    --------
+    >>> from utils import load_data
+    >>> dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
+    >>> len(dataset)
+    5460
+    """
     match dataset.lower():
         case "usps_0-6":
             return USPSDataset0_6(*args, **kwargs)
@@ -12,4 +41,4 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
         case "usps_7-9":
             return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
         case _:
-            raise ValueError(f"Dataset: {dataset} not implemented.")
+            raise NotImplementedError(f"Dataset: {dataset} not implemented.")
```
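The switch from `ValueError` to `NotImplementedError` (here and in `load_model` below) signals that the argument was understood but no such dataset exists yet. A small caller-side sketch of how that surfaces:

```python
# Caller-side handling of the new exception type (illustrative).
try:
    dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
except NotImplementedError as err:
    print(f"Unknown dataset requested: {err}")
```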

utils/load_metric.py

Lines changed: 43 additions & 2 deletions
```diff
@@ -7,7 +7,48 @@
 
 
 class MetricWrapper(nn.Module):
+    """
+    Wrapper class that runs multiple metrics on the same data.
+
+    Args
+    ----
+    metrics : list[str]
+        List of metrics to run on the data.
+
+    Attributes
+    ----------
+    metrics : dict
+        Dictionary containing the metric functions.
+    tmp_scores : dict
+        Dictionary containing the temporary scores of the metrics.
+
+    Methods
+    -------
+    __call__(y_true, y_pred)
+        Call the metric functions on the true and predicted labels.
+    accumulate()
+        Get the average scores of the metrics.
+    reset()
+        Reset the temporary scores of the metrics.
+
+    Examples
+    --------
+    >>> from utils import MetricWrapper
+    >>> metrics = MetricWrapper("entropy", "f1", "precision")
+    >>> y_true = [0, 1, 0, 1]
+    >>> y_pred = [0, 1, 1, 0]
+    >>> metrics(y_true, y_pred)
+    >>> metrics.accumulate()
+    {'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
+    >>> metrics.reset()
+    >>> metrics.accumulate()
+    {'entropy': [], 'f1': [], 'precision': []}
+    """
+
     def __init__(self, *metrics, num_classes):
         super().__init__()
         self.metrics = {}
         self.num_classes = num_classes
@@ -50,7 +91,7 @@ def __call__(self, y_true, y_pred):
         for key in self.metrics:
             self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))
 
-    def __getmetrics__(self, str_prefix: str = None):
+    def accumulate(self, str_prefix: str = None):
         return_metrics = {}
         for key in self.metrics:
             if str_prefix is not None:
@@ -60,6 +101,6 @@ def __getmetrics__(self, str_prefix: str = None):
 
         return return_metrics
 
-    def __resetvalues__(self):
+    def reset(self):
         for key in self.tmp_scores:
             self.tmp_scores[key] = []
```
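The middle of `accumulate` falls outside the hunks, so the diff only pins down its frame: build `return_metrics` from the buffered `tmp_scores`, optionally prefixing each key (which is how main.py gets its `Train ` and `Evaluation ` labels). A plausible reading of the full method, with the aggregation step marked as our assumption:

```python
import torch as th


def accumulate(self, str_prefix: str = None):
    """Reduce the buffered batch scores, optionally prefixing each key."""
    return_metrics = {}
    for key in self.metrics:
        name = str_prefix + key if str_prefix is not None else key
        # Assumed aggregation: the elided lines likely reduce the buffered
        # per-batch scores to a single value, e.g. their mean.
        return_metrics[name] = th.mean(th.tensor(self.tmp_scores[key])).item()
    return return_metrics
```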

utils/load_model.py

Lines changed: 36 additions & 2 deletions
```diff
@@ -4,6 +4,37 @@
 
 
 def load_model(modelname: str, *args, **kwargs) -> nn.Module:
+    """
+    Load the model based on the model name.
+
+    Args
+    ----
+    modelname : str
+        Name of the model to load.
+    *args : list
+        Additional arguments for the model class.
+    **kwargs : dict
+        Additional keyword arguments for the model class.
+
+    Returns
+    -------
+    model : torch.nn.Module
+        Model object.
+
+    Raises
+    ------
+    NotImplementedError
+        If the model is not implemented.
+
+    Examples
+    --------
+    >>> from utils import load_model
+    >>> model = load_model("magnusmodel", num_classes=10)
+    >>> model
+    MagnusModel(
+      (fc1): Linear(in_features=784, out_features=100, bias=True)
+      (fc2): Linear(in_features=100, out_features=10, bias=True)
+    )
+    """
     match modelname.lower():
         case "magnusmodel":
             return MagnusModel(*args, **kwargs)
@@ -14,6 +45,9 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
         case "solveigmodel":
             return SolveigModel(*args, **kwargs)
         case _:
-            raise ValueError(
-                f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
+            errmsg = (
+                f"Model: {modelname} not implemented. "
+                "Check the documentation for implemented models, "
+                "or check your spelling."
             )
+            raise NotImplementedError(errmsg)
```
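Beyond the exception type, the rewritten message drops an embedded `\n` in favor of parenthesized implicit string concatenation, and fixes a copy-paste slip: the old text said "implemented metrics" in a model loader. An illustrative check of the failure path:

```python
# Illustrative failure path for an unknown model name.
try:
    model = load_model("nosuchmodel", num_classes=10)
except NotImplementedError as err:
    print(err)  # Model: nosuchmodel not implemented. Check the documentation ...
```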
