Skip to content

Commit 75b1801

Browse files
authored
Merge pull request #61 from SFI-Visual-Intelligence/mag-branch
Mondays sync
2 parents 2ac02eb + fa89423 commit 75b1801

File tree

4 files changed

+47
-23
lines changed

4 files changed

+47
-23
lines changed

utils/arg_parser.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def get_args():
3535

3636
parser.add_argument(
3737
"--download-data",
38-
action="store_true",
38+
type=bool,
39+
default=False,
3940
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
4041
)
4142

@@ -69,6 +70,21 @@ def get_args():
6970
nargs="+",
7071
help="Which metric to use for evaluation",
7172
)
73+
74+
parser.add_argument(
75+
'--imagesize',
76+
type=int,
77+
default=28,
78+
help='Imagesize'
79+
)
80+
81+
parser.add_argument(
82+
'--nr_channels',
83+
type=int,
84+
default=1,
85+
choices=[1,3],
86+
help='Number of image channels'
87+
)
7288

7389
# Training specific values
7490
parser.add_argument(

utils/dataloaders/svhn.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,56 @@
11
import os
2-
2+
import numpy as np
33
from scipy.io import loadmat
44
from torch.utils.data import Dataset
55
from torchvision.datasets import SVHN
66

77

88
class SVHNDataset(Dataset):
99
def __init__(
10-
self, datapath: str,
11-
transforms=None,
12-
download_data=True,
13-
split="train"
10+
self,
11+
data_path: str,
12+
train: bool,
13+
transform=None,
14+
download:bool=True,
15+
nr_channels=3
1416
):
1517
"""
1618
Initializes the SVHNDataset object.
1719
Args:
18-
datapath (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
20+
data_path (str): Path to where the data lies. If download is set to True, this is where the data will be downloaded.
1921
transforms: Torch composite of transformations which are to be applied to the dataset images.
20-
download_data (bool): If True, downloads the dataset to the specified datapath.
22+
download (bool): If True, downloads the dataset to the specified data_path.
2123
split (str): The dataset split to use, either 'train' or 'test'.
2224
Raises:
2325
AssertionError: If the split is not 'train' or 'test'.
2426
"""
2527
super().__init__()
26-
assert split == "train" or split == "test"
27-
28-
if download_data:
29-
self._download_data(datapath, split)
28+
# assert split == "train" or split == "test"
29+
self.split = 'train' if train else 'test'
30+
31+
if download:
32+
self._download_data(data_path)
3033

31-
data = loadmat(os.path.join(datapath, f"{split}_32x32.mat"))
34+
data = loadmat(os.path.join(data_path, f"{self.split}_32x32.mat"))
3235

3336
# Images on the form N x H x W x C
3437
self.images = data["X"].transpose(3, 1, 0, 2)
3538
self.labels = data["y"].flatten()
3639
self.labels[self.labels == 10] = 0
40+
41+
self.nr_channels = nr_channels
42+
self.transforms = transform
3743

38-
self.transforms = transforms
39-
40-
def _download_data(self, path: str, split: str):
44+
def _download_data(self, path: str):
4145
"""
4246
Downloads the SVHN dataset.
4347
Args:
4448
path (str): The directory where the dataset will be downloaded.
4549
split (str): The dataset split to download, either 'train' or 'test'.
4650
"""
4751
print(f"Downloading SVHN data into {path}")
48-
SVHN(path, split=split, download=True)
52+
53+
SVHN(path, split=self.split, download=True)
4954

5055
def __len__(self):
5156
"""
@@ -65,6 +70,9 @@ def __getitem__(self, index):
6570
"""
6671
img, lab = self.images[index], self.labels[index]
6772

73+
if self.nr_channels == 1:
74+
img = np.mean(img, axis=2, keepdims=True)
75+
6876
if self.transforms is not None:
6977
img = self.transforms(img)
7078

utils/load_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
4141
case "usps_7-9":
4242
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
4343
case "svhn":
44-
raise SVHNDataset(*args, **kwargs)
44+
return SVHNDataset(*args, **kwargs)
4545
case "mnist_4-9":
4646
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
4747
case _:

utils/models/magnus_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,24 @@
22

33

44
class MagnusModel(nn.Module):
5-
def __init__(self, imagesize: int, imagechannels: int, n_classes: int = 10):
5+
def __init__(self, image_shape: int, num_classes: int, imagechannels: int):
66
"""
77
Magnus model contains the model for Magnus' part of the homeexam.
88
This class contains a neural network consisting of three linear layers of 133 neurons each,
99
with ReLU activation between each layer.
1010
1111
Args
1212
----
13-
imagesize (int): Expected size of input image. This is needed to scale first layer input
13+
image_shape (int): Expected size of input image. This is needed to scale first layer input
1414
imagechannels (int): Expected number of image channels. This is needed to scale first layer input
15-
n_classes (int): Number of classes we are to provide.
15+
num_classes (int): Number of classes we are to provide.
1616
1717
Returns
1818
-------
1919
MagnusModel (nn.Module): Neural network as described above in this docstring.
2020
"""
2121
super().__init__()
22-
self.imagesize = imagesize
22+
self.image_shape = image_shape
2323
self.imagechannels = imagechannels
2424

2525
self.layer1 = nn.Sequential(*([
@@ -31,7 +31,7 @@ def __init__(self, imagesize: int, imagechannels: int, n_classes: int = 10):
3131
nn.ReLU()
3232
]))
3333
self.layer3 = nn.Sequential(*([
34-
nn.Linear(133, n_classes),
34+
nn.Linear(133, num_classes),
3535
nn.ReLU()
3636
]))
3737

0 commit comments

Comments
 (0)