Skip to content

Commit 2e202c9

Browse files
committed
Update docstrings for documentation (still lots to do)
Added a bit more information to some of the docstrings, but there is still a lot of room for improvement ;)
1 parent d6128d7 commit 2e202c9

File tree

6 files changed

+163
-12
lines changed

6 files changed

+163
-12
lines changed

utils/dataloaders/datasources.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
"""This module contains the data sources for the datasets used in the experiments.
2+
3+
The data sources are defined as dictionaries with the following keys:
4+
- train: A list containing the URL, filename, and MD5 hash of the training data.
5+
- test: A list containing the URL, filename, and MD5 hash of the test data.
6+
"""
7+
18
USPS_SOURCE = {
29
"train": [
310
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",

utils/dataloaders/mnist_0_3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class MNISTDataset0_3(Dataset):
1111
"""
1212
A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
13+
1314
Parameters
1415
----------
1516
data_path : Path
@@ -20,6 +21,7 @@ class MNISTDataset0_3(Dataset):
2021
A function/transform that takes in an image and returns a transformed version. Default is None.
2122
download : bool, optional
2223
If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
24+
2325
Attributes
2426
----------
2527
data_path : Path
@@ -40,6 +42,7 @@ class MNISTDataset0_3(Dataset):
4042
Indices of the labels that are less than 4.
4143
length : int
4244
The number of samples in the dataset.
45+
4346
Methods
4447
-------
4548
_parse_labels(train)

utils/dataloaders/usps_0_6.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class USPSDataset0_6(Dataset):
2626
Args
2727
----
2828
data_path : pathlib.Path
29-
Path to the USPS dataset file.
29+
Path to the data directory.
3030
train : bool, optional
3131
Mode of the dataset.
3232
transform : callable, optional
@@ -60,18 +60,29 @@ class USPSDataset0_6(Dataset):
6060
6161
Examples
6262
--------
63+
>>> from torchvision import transforms
6364
>>> from utils.dataloaders.usps_0_6 import USPSDataset0_6
64-
>>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train")
65+
>>> transform = transforms.Compose([
66+
... transforms.Resize((16, 16)),
67+
... transforms.ToTensor()
68+
... ])
69+
>>> dataset = USPSDataset0_6(
70+
... data_path="data",
71+
...     transform=transform,
72+
... download=True,
73+
... train=True,
74+
... )
6575
>>> len(dataset)
6676
5460
6777
>>> data, target = dataset[0]
6878
>>> data.shape
69-
(16, 16)
79+
(1, 16, 16)
7080
>>> target
71-
6
81+
tensor([1., 0., 0., 0., 0., 0., 0.])
7282
"""
7383

7484
filename = "usps.h5"
85+
num_classes = 7
7586

7687
def __init__(
7788
self,
@@ -85,7 +96,6 @@ def __init__(
8596
path = data_path if isinstance(data_path, Path) else Path(data_path)
8697
self.filepath = path / self.filename
8798
self.transform = transform
88-
self.num_classes = 7 # 0-6
8999
self.mode = "train" if train else "test"
90100

91101
# Download the dataset if it does not exist in a temporary directory
@@ -116,7 +126,24 @@ def _dataset_ok(self):
116126
return True
117127

118128
def download(self, url, filepath, checksum, mode):
119-
"""Download the USPS dataset."""
129+
"""Download the USPS dataset, and save it as an HDF5 file.
130+
131+
Args
132+
----
133+
url : str
134+
URL to download the dataset from.
135+
filepath : pathlib.Path
136+
Path to save the downloaded dataset.
137+
checksum : str
138+
MD5 checksum of the downloaded file.
139+
mode : str
140+
Mode of the dataset, either train or test.
141+
142+
Raises
143+
------
144+
ValueError
145+
If the checksum of the downloaded file does not match the expected checksum.
146+
"""
120147

121148
def reporthook(blocknum, blocksize, totalsize):
122149
"""Report download progress."""
@@ -164,7 +191,20 @@ def reporthook(blocknum, blocksize, totalsize):
164191

165192
@staticmethod
166193
def check_integrity(filepath, checksum):
167-
"""Check the integrity of the USPS dataset file."""
194+
"""Check the integrity of the USPS dataset file.
195+
196+
Args
197+
----
198+
filepath : pathlib.Path
199+
Path to the USPS dataset file.
200+
checksum : str
201+
MD5 checksum of the dataset file.
202+
203+
Returns
204+
-------
205+
bool
206+
True if the checksum of the file matches the expected checksum, False otherwise.
207+
"""
168208

169209
file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
170210

utils/load_data.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,35 @@
44

55

66
def load_data(dataset: str, *args, **kwargs) -> Dataset:
7+
"""
8+
Load the dataset based on the dataset name.
9+
10+
Args
11+
----
12+
dataset : str
13+
Name of the dataset to load.
14+
*args : list
15+
Additional arguments for the dataset class.
16+
**kwargs : dict
17+
Additional keyword arguments for the dataset class.
18+
19+
Returns
20+
-------
21+
dataset : torch.utils.data.Dataset
22+
Dataset object.
23+
24+
Raises
25+
------
26+
NotImplementedError
27+
If the dataset is not implemented.
28+
29+
Examples
30+
--------
31+
>>> from utils import load_data
32+
>>> dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
33+
>>> len(dataset)
34+
5460
35+
"""
736
match dataset.lower():
837
case "usps_0-6":
938
return USPSDataset0_6(*args, **kwargs)
@@ -12,4 +41,4 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
1241
case "usps_7-9":
1342
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
1443
case _:
15-
raise ValueError(f"Dataset: {dataset} not implemented.")
44+
raise NotImplementedError(f"Dataset: {dataset} not implemented.")

utils/load_metric.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,44 @@
77

88

99
class MetricWrapper(nn.Module):
10+
"""
11+
Wrapper class that runs multiple metrics on the same data.
12+
13+
Args
14+
----
15+
*metrics : str
16+
Names of the metrics to run on the data, passed as separate positional arguments.
17+
18+
Attributes
19+
----------
20+
metrics : dict
21+
Dictionary containing the metric functions.
22+
tmp_scores : dict
23+
Dictionary containing the temporary scores of the metrics.
24+
25+
Methods
26+
-------
27+
__call__(y_true, y_pred)
28+
Call the metric functions on the true and predicted labels.
29+
accumulate()
30+
Get the average scores of the metrics.
31+
reset()
32+
Reset the temporary scores of the metrics.
33+
34+
Examples
35+
--------
36+
>>> from utils import MetricWrapper
37+
>>> metrics = MetricWrapper("entropy", "f1", "precision")
38+
>>> y_true = [0, 1, 0, 1]
39+
>>> y_pred = [0, 1, 1, 0]
40+
>>> metrics(y_true, y_pred)
41+
>>> metrics.accumulate()
42+
{'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
43+
>>> metrics.reset()
44+
>>> metrics.tmp_scores
45+
{'entropy': [], 'f1': [], 'precision': []}
46+
"""
47+
1048
def __init__(self, *metrics):
1149
super().__init__()
1250
self.metrics = {}
@@ -49,13 +87,13 @@ def __call__(self, y_true, y_pred):
4987
for key in self.metrics:
5088
self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))
5189

52-
def __getmetrics__(self):
90+
def accumulate(self):
5391
return_metrics = {}
5492
for key in self.metrics:
5593
return_metrics[key] = np.mean(self.tmp_scores[key])
5694

5795
return return_metrics
5896

59-
def __resetvalues__(self):
97+
def reset(self):
6098
for key in self.tmp_scores:
6199
self.tmp_scores[key] = []

utils/load_model.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,37 @@
44

55

66
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
7+
"""
8+
Load the model based on the model name.
9+
10+
Args
11+
----
12+
modelname : str
13+
Name of the model to load.
14+
*args : list
15+
Additional arguments for the model class.
16+
**kwargs : dict
17+
Additional keyword arguments for the model class.
18+
19+
Returns
20+
-------
21+
model : torch.nn.Module
22+
Model object.
23+
24+
Raises
25+
------
26+
NotImplementedError
27+
If the model is not implemented.
28+
29+
Examples
30+
--------
31+
>>> from utils import load_model
32+
>>> model = load_model("magnusmodel", num_classes=10)
33+
>>> model
34+
MagnusModel(
35+
(fc1): Linear(in_features=784, out_features=100, bias=True)
36+
(fc2): Linear(in_features=100, out_features=10, bias=True)
37+
"""
738
match modelname.lower():
839
case "magnusmodel":
940
return MagnusModel(*args, **kwargs)
@@ -14,6 +45,9 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
1445
case "solveigmodel":
1546
return SolveigModel(*args, **kwargs)
1647
case _:
17-
raise ValueError(
18-
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
48+
errmsg = (
49+
f"Model: {modelname} not implemented. "
50+
"Check the documentation for implemented models, "
51+
"or check your spelling."
1952
)
53+
raise NotImplementedError(errmsg)

0 commit comments

Comments
 (0)