Skip to content

Commit fa89423

Browse files
committed
...
1 parent a29a9a2 commit fa89423

File tree

5 files changed

+24
-24
lines changed

5 files changed

+24
-24
lines changed

test.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

utils/arg_parser.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def get_args():
3535

3636
parser.add_argument(
3737
"--download-data",
38-
action="store_true",
38+
type=bool,
39+
default=False,
3940
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
4041
)
4142

utils/dataloaders/svhn.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,46 +7,50 @@
77

88
class SVHNDataset(Dataset):
99
def __init__(
10-
self, datapath: str,
11-
transforms=None,
12-
download_data=True,
10+
self,
11+
data_path: str,
12+
train: bool,
13+
transform=None,
14+
download:bool=True,
1315
nr_channels=3
1416
):
1517
"""
1618
Initializes the SVHNDataset object.
1719
Args:
18-
datapath (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
20+
data_path (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
1921
transforms: Torch composite of transformations which are to be applied to the dataset images.
20-
download_data (bool): If True, downloads the dataset to the specified datapath.
22+
download_data (bool): If True, downloads the dataset to the specified data_path.
2123
split (str): The dataset split to use, either 'train' or 'test'.
2224
Raises:
2325
AssertionError: If the split is not 'train' or 'test'.
2426
"""
2527
super().__init__()
2628
# assert split == "train" or split == "test"
29+
self.split = 'train' if train else 'test'
30+
31+
if download:
32+
self._download_data(data_path)
2733

28-
if download_data:
29-
self._download_data(datapath)
30-
31-
data = loadmat(os.path.join(datapath, f"train_32x32.mat"))
34+
data = loadmat(os.path.join(data_path, f"{self.split}_32x32.mat"))
3235

3336
# Images on the form N x H x W x C
3437
self.images = data["X"].transpose(3, 1, 0, 2)
3538
self.labels = data["y"].flatten()
3639
self.labels[self.labels == 10] = 0
3740

3841
self.nr_channels = nr_channels
39-
self.transforms = transforms
42+
self.transforms = transform
4043

41-
def _download_data(self, path: str, split: str):
44+
def _download_data(self, path: str):
4245
"""
4346
Downloads the SVHN dataset.
4447
Args:
4548
path (str): The directory where the dataset will be downloaded.
4649
split (str): The dataset split to download, either 'train' or 'test'.
4750
"""
4851
print(f"Downloading SVHN data into {path}")
49-
SVHN(path, split='train', download=True)
52+
53+
SVHN(path, split=self.split, download=True)
5054

5155
def __len__(self):
5256
"""

utils/load_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
4141
case "usps_7-9":
4242
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
4343
case "svhn":
44-
raise SVHNDataset(*args, **kwargs)
44+
return SVHNDataset(*args, **kwargs)
4545
case "mnist_4-9":
4646
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
4747
case _:

utils/models/magnus_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,24 @@
22

33

44
class MagnusModel(nn.Module):
5-
def __init__(self, imagesize: int, imagechannels: int, n_classes: int = 10):
5+
def __init__(self, image_shape: int, num_classes: int, imagechannels: int):
66
"""
77
Magnus model contains the model for Magnus' part of the homeexam.
88
This class contains a neural network consisting of three linear layers of 133 neurons each,
99
with ReLU activation between each layer.
1010
1111
Args
1212
----
13-
imagesize (int): Expected size of input image. This is needed to scale first layer input
13+
image_shape (int): Expected size of input image. This is needed to scale first layer input
1414
imagechannels (int): Expected number of image channels. This is needed to scale first layer input
15-
n_classes (int): Number of classes we are to provide.
15+
num_classes (int): Number of classes we are to provide.
1616
1717
Returns
1818
-------
1919
MagnusModel (nn.Module): Neural network as described above in this docstring.
2020
"""
2121
super().__init__()
22-
self.imagesize = imagesize
22+
self.image_shape = image_shape
2323
self.imagechannels = imagechannels
2424

2525
self.layer1 = nn.Sequential(*([
@@ -31,7 +31,7 @@ def __init__(self, imagesize: int, imagechannels: int, n_classes: int = 10):
3131
nn.ReLU()
3232
]))
3333
self.layer3 = nn.Sequential(*([
34-
nn.Linear(133, n_classes),
34+
nn.Linear(133, num_classes),
3535
nn.ReLU()
3636
]))
3737

0 commit comments

Comments
 (0)