Skip to content

Commit 80e89eb

Browse files
authored
Merge pull request #119 from SFI-Visual-Intelligence/Jan-doc
Saving model and metrics locally
2 parents 4d51869 + d7dbe28 commit 80e89eb

File tree

9 files changed

+289
-48
lines changed

9 files changed

+289
-48
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ doc/autoapi
1212

1313
*.DS_Store
1414

15+
# Jan
16+
job.yaml
17+
sync.sh
1518

1619
#Magnus specific
1720
job*

CollaborativeCoding/dataloaders/download.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
class Downloader:
1818
"""
19-
Class to download and load the USPS dataset.
19+
Class used to verify availability and potentially download implemented datasets.
2020
2121
Methods
2222
-------
2323
mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
24-
Download the MNIST dataset and save it as an HDF5 file to `data_dir`.
24+
Checks the availability of mnist dataset. If not present downloads it into MNIST folder in `data_dir`.
2525
svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
2626
Download the SVHN dataset and save it as an HDF5 file to `data_dir`.
2727
usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
@@ -42,6 +42,10 @@ class Downloader:
4242
"""
4343

4444
def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
45+
"""
46+
Check the availability of mnist dataset. If not present downloads it into MNIST folder in `data_dir`.
47+
"""
48+
4549
def _chech_is_downloaded(path: Path) -> bool:
4650
path = path / "MNIST"
4751
if path.exists():

CollaborativeCoding/dataloaders/mnist_0_3.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@
1010
class MNISTDataset0_3(Dataset):
1111
"""
1212
A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
13-
Parameters
13+
14+
Args
1415
----------
1516
data_path : Path
16-
The root directory where the MNIST data is stored.
17+
The root directory where the MNIST folder with data is stored.
1718
sample_ids : list
1819
A list of indices specifying which samples to load.
1920
train : bool, optional
2021
If True, load training data, otherwise load test data. Default is False.
2122
transform : callable, optional
2223
A function/transform to apply to the images. Default is None.
24+
2325
Attributes
2426
----------
25-
data_path : Path
26-
The root directory where the MNIST data is stored.
2727
mnist_path : Path
2828
The directory where the MNIST dataset is located within the root directory.
2929
idx : list
@@ -40,6 +40,7 @@ class MNISTDataset0_3(Dataset):
4040
The path to the label file (train or test) based on the `train` flag.
4141
length : int
4242
The number of samples in the dataset.
43+
4344
Methods
4445
-------
4546
__len__()
@@ -58,8 +59,7 @@ def __init__(
5859
):
5960
super().__init__()
6061

61-
self.data_path = data_path
62-
self.mnist_path = self.data_path / "MNIST"
62+
self.mnist_path = data_path / "MNIST"
6363
self.idx = sample_ids
6464
self.train = train
6565
self.transform = transform

CollaborativeCoding/load_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def filter_labels(samples: list, wanted_labels: list) -> list:
1717

1818
def load_data(dataset: str, *args, **kwargs) -> tuple:
1919
"""
20-
load the dataset based on the dataset name.
20+
Load the dataset based on the dataset name.
2121
2222
Args
2323
----

CollaborativeCoding/load_metric.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,32 @@ class MetricWrapper(nn.Module):
99
This class allows you to compute several metrics simultaneously on given
1010
true and predicted labels. It supports a variety of common metrics and
1111
provides methods to accumulate results and reset the state.
12+
1213
Args
1314
----
1415
num_classes : int
1516
The number of classes in the classification task.
1617
metrics : list[str]
1718
A list of metric names to be evaluated.
19+
macro_averaging : bool
20+
Whether to compute macro-averaged metrics for multi-class classification.
21+
1822
Attributes
1923
----------
2024
metrics : dict
2125
A dictionary mapping metric names to their corresponding functions.
2226
num_classes : int
2327
The number of classes for the classification task.
28+
2429
Methods
2530
-------
2631
__call__(y_true, y_pred)
27-
Computes the specified metrics on the provided true and predicted labels.
32+
Passes the true and predicted labels to the metric functions.
2833
getmetrics(str_prefix: str = None)
29-
Retrieves the computed metrics, optionally prefixed with a string.
34+
Retrieves the dictionary of computed metrics, optionally all keys can be prefixed with a string.
3035
resetmetric()
3136
Resets the state of all metric computations.
37+
3238
Examples
3339
--------
3440
>>> from CollaborativeCoding import MetricWrapperProposed

CollaborativeCoding/metrics/accuracy.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,45 @@
44

55

66
class Accuracy(nn.Module):
7+
"""
8+
Computes the accuracy of a model's predictions.
9+
10+
Args
11+
----------
12+
num_classes : int
13+
The number of classes in the classification task.
14+
macro_averaging : bool, optional
15+
If True, computes macro-average accuracy. Otherwise, computes micro-average accuracy. Default is False.
16+
17+
18+
Methods
19+
-------
20+
forward(y_true, y_pred)
21+
Stores the true and predicted labels. Typically called for each batch during the forward pass of a model.
22+
_macro_acc()
23+
Computes the macro-average accuracy.
24+
_micro_acc()
25+
Computes the micro-average accuracy.
26+
__returnmetric__()
27+
Returns the computed accuracy based on the averaging method for all stored predictions.
28+
__reset__()
29+
Resets the stored true and predicted labels.
30+
31+
Examples
32+
--------
33+
>>> y_true = torch.tensor([0, 1, 2, 3, 3])
34+
>>> y_pred = torch.tensor([0, 1, 2, 3, 0])
35+
>>> accuracy = Accuracy(num_classes=4)
36+
>>> accuracy(y_true, y_pred)
37+
>>> accuracy.__returnmetric__()
38+
0.8
39+
>>> accuracy.__reset__()
40+
>>> accuracy.macro_averaging = True
41+
>>> accuracy(y_true, y_pred)
42+
>>> accuracy.__returnmetric__()
43+
0.875
44+
"""
45+
746
def __init__(self, num_classes, macro_averaging=False):
847
super().__init__()
948
self.num_classes = num_classes
@@ -13,19 +52,14 @@ def __init__(self, num_classes, macro_averaging=False):
1352

1453
def forward(self, y_true, y_pred):
1554
"""
16-
Compute the accuracy of the model.
55+
Store the true and predicted labels.
1756
1857
Parameters
1958
----------
2059
y_true : torch.Tensor
2160
True labels.
2261
y_pred : torch.Tensor
23-
Predicted labels.
24-
25-
Returns
26-
-------
27-
float
28-
Accuracy score.
62+
Predicted labels. Either a 1D tensor of shape (batch_size,) or a 2D tensor of shape (batch_size, num_classes).
2963
"""
3064
if y_pred.dim() > 1:
3165
y_pred = y_pred.argmax(dim=1)
@@ -34,14 +68,7 @@ def forward(self, y_true, y_pred):
3468

3569
def _macro_acc(self):
3670
"""
37-
Compute the macro-average accuracy.
38-
39-
Parameters
40-
----------
41-
y_true : torch.Tensor
42-
True labels.
43-
y_pred : torch.Tensor
44-
Predicted labels.
71+
Compute the macro-average accuracy on the stored predictions.
4572
4673
Returns
4774
-------
@@ -63,14 +90,7 @@ def _macro_acc(self):
6390

6491
def _micro_acc(self):
6592
"""
66-
Compute the micro-average accuracy.
67-
68-
Parameters
69-
----------
70-
y_true : torch.Tensor
71-
True labels.
72-
y_pred : torch.Tensor
73-
Predicted labels.
93+
Compute the micro-average accuracy on the stored predictions.
7494
7595
Returns
7696
-------
@@ -80,6 +100,14 @@ def _micro_acc(self):
80100
return (self.y_true == self.y_pred).float().mean().item()
81101

82102
def __returnmetric__(self):
103+
"""
104+
Return the computed accuracy based on the averaging method for all stored predictions.
105+
106+
Returns
107+
-------
108+
float
109+
Computed accuracy score.
110+
"""
83111
if self.y_true == [] or self.y_pred == []:
84112
return np.nan
85113
if isinstance(self.y_true, list):
@@ -92,6 +120,9 @@ def __returnmetric__(self):
92120
return self._micro_acc() if not self.macro_averaging else self._macro_acc()
93121

94122
def __reset__(self):
123+
"""
124+
Reset the stored true and predicted labels.
125+
"""
95126
self.y_true = []
96127
self.y_pred = []
97128
return None

CollaborativeCoding/models/jan_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44

55
class JanModel(nn.Module):
6-
"""A simple MLP network model for image classification tasks.
6+
"""A simple MLP network model for image classification tasks. Two hidden layers with 100 neurons.
77
88
Args
99
----
10-
in_channels : int
11-
Number of input channels.
10+
image_shape : tuple(int, int, int)
11+
Shape of the input image (C, H, W).
1212
num_classes : int
1313
Number of classes in the dataset.
1414

main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4-
import wandb
54
from torch.utils.data import DataLoader
65
from torchvision import transforms
76
from tqdm import tqdm
87

8+
import wandb
99
from CollaborativeCoding import (
1010
MetricWrapper,
1111
createfolders,
@@ -132,6 +132,7 @@ def main():
132132
wandb.init(
133133
entity="ColabCode",
134134
project=args.run_name,
135+
dir=args.resultfolder,
135136
tags=[args.modelname, args.dataset],
136137
config=args,
137138
)
@@ -178,6 +179,9 @@ def main():
178179
train_metrics.resetmetric()
179180
val_metrics.resetmetric()
180181

182+
if args.savemodel:
183+
th.save(model, args.modelfolder / f"{args.modelname}_run:{args.run_name}.pth")
184+
181185
testloss = []
182186
model.eval()
183187
with th.no_grad():

0 commit comments

Comments
 (0)