Skip to content

Commit 0b01aee

Browse files
authored
Merge pull request #102 from SFI-Visual-Intelligence/johan/devbranch
Doc update
2 parents 5a50a07 + efa032e commit 0b01aee

File tree

4 files changed

+35
-24
lines changed

4 files changed

+35
-24
lines changed

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ class MNISTDataset4_9(Dataset):
1919
Array of indices spcifying which samples to load. This determines the samples used by the dataloader.
2020
train : bool, optional
2121
Whether to train the model or not, by default False
22+
transorm : callable, optional
23+
Transform to apply to the images, by default None
24+
nr_channels : int, optional
25+
Number of channels in the images, by default 1
2226
"""
2327

2428
def __init__(

CollaborativeCoding/metrics/precision.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55

66
class Precision(nn.Module):
7-
"""Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives.
7+
"""Metric module for precision. Can calculate both the micro- and macro-averaged precision.
88
99
Parameters
1010
----------
1111
num_classes : int
1212
Number of classes in the dataset.
13-
micro_averaging : bool
14-
Wheter to compute the micro or macro precision (default False)
13+
macro_averaging : bool
14+
Performs macro-averaging if True, otherwise micro-averaging.
1515
"""
1616

1717
def __init__(self, num_classes: int, macro_averaging: bool = False):
@@ -23,19 +23,15 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
2323
self.y_pred = []
2424

2525
def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
26-
"""Compute precision of model
26+
"""Add true and predicted values to the class-global lists.
2727
2828
Parameters
2929
----------
3030
y_true : torch.tensor
3131
True labels
32-
y_pred : torch.tensor
32+
logits : torch.tensor
3333
Predicted labels
3434
35-
Returns
36-
-------
37-
torch.tensor
38-
Precision score
3935
"""
4036
y_pred = logits.argmax(dim=-1)
4137

@@ -100,6 +96,13 @@ def _macro_avg_precision(
10096
return torch.nanmean(tp / (tp + fp))
10197

10298
def __returnmetric__(self):
99+
"""Return the micro- or macro-averaged precision.
100+
101+
Returns
102+
-------
103+
torch.tensor
104+
Micro- or macro-averaged precision
105+
"""
103106
if self.y_true == [] and self.y_pred == []:
104107
return np.nan
105108
elif self.y_true == [] or self.y_pred == []:

CollaborativeCoding/models/johan_model.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,32 @@
44
Multi-layer perceptron model for image classification.
55
"""
66

7-
# class NeuronLayer(nn.Module):
8-
# def __init__(self, in_features, out_features):
9-
# super().__init__()
10-
11-
# self.fc = nn.Linear(in_features, out_features)
12-
# self.relu = nn.ReLU()
13-
14-
# def forward(self, x):
15-
# x = self.fc(x)
16-
# x = self.relu(x)
17-
# return x
18-
197

208
class JohanModel(nn.Module):
219
"""Small MLP model for image classification.
2210
2311
Parameters
2412
----------
25-
in_features : int
26-
Numer of input features.
13+
image_shape : tuple(int, int, int)
14+
Shape of the input image (C, H, W).
2715
num_classes : int
2816
Number of classes in the dataset.
17+
18+
Processing Images
19+
-----------------
20+
Input: (N, C, H, W)
21+
N: Batch size
22+
C: Number of input channels
23+
H: Height of the input image
24+
W: Width of the input image
25+
26+
Example:
27+
Grayscale images (like MNIST) have C = 1.
28+
Input shape: (N, 1, 28, 28)
29+
fc1 Output shape: (N, 77)
30+
fc2 Output shape: (N, 77)
31+
fc3 Output shape: (N, 77)
32+
fc4 Output shape: (N, num_classes)
2933
"""
3034

3135
def __init__(self, image_shape, num_classes):

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4-
import wandb
54
from torch.utils.data import DataLoader
65
from torchvision import transforms
76
from tqdm import tqdm
87

8+
import wandb
99
from CollaborativeCoding import (
1010
MetricWrapper,
1111
createfolders,

0 commit comments

Comments
 (0)