Skip to content

Commit f654631

Browse files
authored
Merge pull request #55 from SFI-Visual-Intelligence/mag-branch
Dataloader and Main file
2 parents b93ee66 + 8c8435e commit f654631

File tree

4 files changed

+92
-19
lines changed

4 files changed

+92
-19
lines changed

.gitignore

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
__pycache__/
22
.ipynb_checkpoints/
3-
Data/
4-
Results/
5-
Experiments/
3+
Data/*
4+
Results/*
5+
Experiments/*
66
_build/
7-
bin/
8-
wandb/
7+
bin/*
8+
wandb/*
99
wandb_api.py
1010

1111
#Magnus specific

utils/dataloaders/svhn.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,66 @@
11
from torch.utils.data import Dataset
2+
from scipy.io import loadmat
3+
import os
4+
from torchvision.datasets import SVHN
25

3-
4-
class SVHN(Dataset):
5-
def __init__(self):
6+
class SVHNDataset(Dataset):
7+
def __init__(self,
8+
datapath: str,
9+
transforms=None,
10+
download_data=True,
11+
split='train'):
12+
"""
13+
Initializes the SVHNDataset object.
14+
Args:
15+
datapath (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
16+
transforms: Torch composite of transformations which are to be applied to the dataset images.
17+
download_data (bool): If True, downloads the dataset to the specified datapath.
18+
split (str): The dataset split to use, either 'train' or 'test'.
19+
Raises:
20+
AssertionError: If the split is not 'train' or 'test'.
21+
"""
622
super().__init__()
7-
23+
assert split == 'train' or split == 'test'
24+
25+
if download_data:
26+
self._download_data(datapath, split)
27+
28+
data = loadmat(os.path.join(datapath, f'{split}_32x32.mat'))
29+
30+
# Images on the form N x H x W x C
31+
self.images = data['X'].transpose(3, 1, 0, 2)
32+
self.labels = data['y'].flatten()
33+
self.labels[self.labels == 10] = 0
34+
35+
self.transforms = transforms
36+
def _download_data(self, path: str, split: str):
37+
"""
38+
Downloads the SVHN dataset.
39+
Args:
40+
path (str): The directory where the dataset will be downloaded.
41+
split (str): The dataset split to download, either 'train' or 'test'.
42+
"""
43+
print(f'Downloading SVHN data into {path}')
44+
SVHN(path, split=split, download=True)
45+
846
def __len__(self):
9-
return
10-
47+
"""
48+
Returns the number of samples in the dataset.
49+
Returns:
50+
int: The number of samples.
51+
"""
52+
return len(self.labels)
1153
def __getitem__(self, index):
12-
return
54+
"""
55+
Retrieves the image and label at the specified index.
56+
Args:
57+
index (int): The index of the sample to retrieve.
58+
Returns:
59+
tuple: A tuple containing the image and its corresponding label.
60+
"""
61+
img, lab = self.images[index], self.labels[index]
62+
63+
if self.transforms is not None:
64+
img = self.transforms(img)
65+
66+
return img, lab

utils/metrics/EntropyPred.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
1+
import numpy as np
12
import torch.nn as nn
3+
from scipy.stats import entropy
24

35

46
class EntropyPrediction(nn.Module):
5-
def __init__(self):
7+
def __init__(self, averages: str = 'average'):
8+
"""
9+
Initializes the EntropyPrediction module.
10+
Args:
11+
averages (str): Specifies the method of aggregation for entropy values.
12+
Must be either 'average' or 'sum'.
13+
Raises:
14+
AssertionError: If the averages parameter is not 'average' or 'sum'.
15+
"""
616
super().__init__()
7-
17+
18+
assert averages == 'average' or averages == 'sum'
19+
self.averages = averages
20+
self.stored_entropy_values = []
21+
822
def __call__(self, y_true, y_false_logits):
9-
return
10-
11-
def __reset__(self):
12-
pass
23+
"""
24+
Computes the entropy between true labels and predicted logits, storing the results.
25+
Args:
26+
y_true: The true labels.
27+
y_false_logits: The predicted logits.
28+
Side Effects:
29+
Appends the computed entropy values to the stored_entropy_values list.
30+
"""
31+
entropy_values = entropy(y_true, qk=y_false_logits)
32+
return entropy_values

utils/models/magnus_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(self,
2323
MagnusModel (nn.Module): Neural network as described above in this docstring.
2424
"""
2525

26-
2726
super().__init__()
2827
self.imagesize = imagesize
2928
self.imagechannels = imagechannels

0 commit comments

Comments
 (0)