Skip to content

Commit 78181b9

Browse files
authored
Merge pull request #91 from SFI-Visual-Intelligence/mag-branch
Mag branch
2 parents 85ac780 + 133f3bf commit 78181b9

File tree

15 files changed

+163
-64
lines changed

.gitignore

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__pycache__/
22
.ipynb_checkpoints/
33
Data/*
4+
data/*
45
Results/*
56
Experiments/*
67
_build/
@@ -14,9 +15,7 @@ doc/autoapi
1415

1516
#Magnus specific
1617
job*
17-
env2/*
18-
ruffian.sh
19-
localtest.sh
18+
local*
2019

2120
# Johanthings
2221
formatting.x

CollaborativeCoding/dataloaders/download.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _get_labels(path: Path) -> np.ndarray:
8888

8989
def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
9090
def download_svhn(path, train: bool = True):
91-
SVHN()
91+
SVHN(path, split="train" if train else "test", download=True)
9292

9393
parent_path = data_dir / "SVHN"
9494

CollaborativeCoding/dataloaders/mnist_0_3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

33
import numpy as np
4+
from PIL import Image
45
from torch.utils.data import Dataset
56

67
from .datasources import MNIST_SOURCE
@@ -87,7 +88,8 @@ def __getitem__(self, index):
8788
28, 28
8889
) # Read image data
8990

90-
image = np.expand_dims(image, axis=0) # Add channel dimension
91+
# image = np.expand_dims(image, axis=0) # Add channel dimension
92+
image = Image.fromarray(image.astype(np.uint8))
9193

9294
if self.transform:
9395
image = self.transform(image)

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

33
import numpy as np
4+
from PIL import Image
45
from torch.utils.data import Dataset
56

67
from .datasources import MNIST_SOURCE
@@ -28,11 +29,13 @@ def __init__(
2829
transform=None,
2930
nr_channels: int = 1,
3031
):
31-
super.__init__()
32+
super().__init__()
3233
self.data_path = data_path
3334
self.mnist_path = self.data_path / "MNIST"
3435
self.samples = sample_ids
3536
self.train = train
37+
self.transform = transform
38+
self.num_classes = 6
3639

3740
self.images_path = self.mnist_path / (
3841
MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
@@ -46,7 +49,7 @@ def __len__(self):
4649

4750
def __getitem__(self, idx):
4851
with open(self.labels_path, "rb") as labelfile:
49-
label_pos = 8 + self.sample[idx]
52+
label_pos = 8 + self.samples[idx]
5053
labelfile.seek(label_pos)
5154
label = int.from_bytes(labelfile.read(1), byteorder="big")
5255

@@ -57,7 +60,8 @@ def __getitem__(self, idx):
5760
28, 28
5861
)
5962

60-
image = np.expand_dims(image, axis=0) # Channel
63+
# image = np.expand_dims(image, axis=0) # Channel
64+
image = Image.fromarray(image.astype(np.uint8))
6165

6266
if self.transform:
6367
image = self.transform(image)

CollaborativeCoding/dataloaders/usps_0_6.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(
8383
sample_ids: list,
8484
train: bool = False,
8585
transform=None,
86+
nr_channels=1,
8687
):
8788
super().__init__()
8889

@@ -91,6 +92,7 @@ def __init__(
9192
self.transform = transform
9293
self.mode = "train" if train else "test"
9394
self.sample_ids = sample_ids
95+
self.nr_channels = nr_channels
9496

9597
def __len__(self):
9698
return len(self.sample_ids)
@@ -100,11 +102,18 @@ def __getitem__(self, id):
100102

101103
with h5.File(self.filepath, "r") as f:
102104
data = f[self.mode]["data"][index].astype(np.uint8)
103-
label = f[self.mode]["target"][index]
105+
label = int(f[self.mode]["target"][index])
104106

105-
data = Image.fromarray(data, mode="L")
107+
if self.nr_channels == 1:
108+
data = Image.fromarray(data, mode="L")
109+
elif self.nr_channels == 3:
110+
data = Image.fromarray(data, mode="RGB")
111+
else:
112+
raise ValueError("Invalid number of channels")
106113

107114
if self.transform:
108115
data = self.transform(data)
109116

117+
# label = torch.tensor(label).long()
118+
110119
return data, label

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3232
A transform function to apply to the images.
3333
"""
3434

35-
def __init__(self, data_path, sample_ids, train=False, transform=None, nr_channels=1):
35+
def __init__(
36+
self, data_path, sample_ids, train=False, transform=None, nr_channels=1
37+
):
3638
super().__init__()
3739
"""
3840
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -112,7 +114,8 @@ def main():
112114
indices = np.array([7, 8, 9])
113115
# Load the dataset
114116
dataset = USPSH5_Digit_7_9_Dataset(
115-
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git", sample_ids=indices,
117+
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
118+
sample_ids=indices,
116119
train=False,
117120
transform=transform,
118121
)

CollaborativeCoding/load_data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
6464
case "svhn":
6565
dataset = SVHNDataset
6666
train_labels, test_labels = downloader.svhn(data_dir=data_dir)
67-
labels = np.arange(10)
67+
labels = np.unique(train_labels)
6868
case "mnist_4-9":
6969
dataset = MNISTDataset4_9
7070
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
@@ -89,23 +89,23 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
8989
sample_ids=train_samples,
9090
train=True,
9191
transform=transform,
92-
nr_channels=kwargs.get("nr_channels"),
92+
nr_channels=kwargs.get("nr_channels", 1),
9393
)
9494

9595
val = dataset(
9696
data_path=data_dir,
9797
sample_ids=val_samples,
9898
train=True,
9999
transform=transform,
100-
nr_channels=kwargs.get("nr_channels"),
100+
nr_channels=kwargs.get("nr_channels", 1),
101101
)
102102

103103
test = dataset(
104104
data_path=data_dir,
105105
sample_ids=test_samples,
106106
train=False,
107107
transform=transform,
108-
nr_channels=kwargs.get("nr_channels"),
108+
nr_channels=kwargs.get("nr_channels", 1),
109109
)
110110

111111
return train, val, test

CollaborativeCoding/load_metric.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,28 @@ def getmetrics(self, str_prefix: str = None):
9494
def resetmetric(self):
9595
for key in self.metrics:
9696
self.metrics[key].__reset__()
97+
98+
99+
if __name__ == "__main__":
100+
import torch as th
101+
102+
metrics = ["entropy", "f1", "recall", "precision", "accuracy"]
103+
104+
class_sizes = [3, 6, 10]
105+
for class_size in class_sizes:
106+
y_true = th.rand((5, class_size)).argmax(dim=1)
107+
y_pred = th.rand((5, class_size))
108+
109+
metricwrapper = MetricWrapper(
110+
metric,
111+
num_classes=class_size,
112+
macro_averaging=True if class_size % 2 == 0 else False,
113+
)
114+
115+
metricwrapper(y_true, y_pred)
116+
metric = metricwrapper.getmetrics()
117+
assert metric is not None
118+
119+
metricwrapper.resetmetric()
120+
metric2 = metricwrapper.getmetrics()
121+
assert metric != metric2

CollaborativeCoding/metrics/F1.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,13 @@ def __returnmetric__(self):
159159
else:
160160
self.y_true = torch.cat(self.y_true)
161161
self.y_pred = torch.cat(self.y_pred)
162-
return self._micro_F1(self.y_true, self.y_pred) if not self.macro_averaging else self._macro_F1(self.y_true, self.y_pred)
162+
return (
163+
self._micro_F1(self.y_true, self.y_pred)
164+
if not self.macro_averaging
165+
else self._macro_F1(self.y_true, self.y_pred)
166+
)
163167

164168
def __reset__(self):
165169
self.y_true = []
166170
self.y_pred = []
167171
return None
168-
169-

CollaborativeCoding/metrics/recall.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23
import torch.nn as nn
34

@@ -57,26 +58,49 @@ def __init__(self, num_classes, macro_averaging=False):
5758
self.num_classes = num_classes
5859
self.macro_averaging = macro_averaging
5960

61+
self.__y_true = []
62+
self.__y_pred = []
63+
6064
def forward(self, true, logits):
6165
pred = logits.argmax(dim=-1)
6266
y_true = one_hot_encode(true, self.num_classes)
6367
y_pred = one_hot_encode(pred, self.num_classes)
6468

69+
self.__y_true.append(y_true)
70+
self.__y_pred.append(y_pred)
71+
72+
def compute(self, y_true, y_pred):
6573
if self.macro_averaging:
66-
recall = 0
67-
for i in range(self.num_classes):
68-
tp = (y_true[:, i] * y_pred[:, i]).sum()
69-
fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
70-
recall += tp / (tp + fn)
71-
recall /= self.num_classes
72-
else:
73-
recall = self.__compute(y_true, y_pred)
74+
return self.__compute_macro_averaging(y_true, y_pred)
75+
76+
return self.__compute_micro_averaging(y_true, y_pred)
77+
78+
def __compute_macro_averaging(self, y_true, y_pred):
79+
recall = 0
80+
for i in range(self.num_classes):
81+
tp = (y_true[:, i] * y_pred[:, i]).sum()
82+
fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
83+
recall += tp / (tp + fn)
84+
recall /= self.num_classes
7485

7586
return recall
7687

77-
def __compute(self, y_true, y_pred):
88+
def __compute_micro_averaging(self, y_true, y_pred):
7889
true_positives = (y_true * y_pred).sum()
7990
false_negatives = torch.sum(~y_pred[y_true.bool()].bool())
8091

8192
recall = true_positives / (true_positives + false_negatives)
8293
return recall
94+
95+
def __returnmetric__(self):
96+
if len(self.__y_true) == 0 and len(self.__y_pred) == 0:
97+
return np.nan
98+
99+
y_true = torch.cat(self.__y_true, dim=0)
100+
y_pred = torch.cat(self.__y_pred, dim=0)
101+
102+
return self.compute(y_true, y_pred)
103+
104+
def __reset__(self):
105+
self.__y_true = []
106+
self.__y_pred = []

0 commit comments

Comments
 (0)