Skip to content

Commit b7bffa3

Browse files
committed
ruffisorted :'(
1 parent ba2212e commit b7bffa3

File tree

5 files changed

+32
-36
lines changed

5 files changed

+32
-36
lines changed

utils/arg_parser.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,15 @@ def get_args():
6868
nargs="+",
6969
help="Which metric to use for evaluation",
7070
)
71-
72-
parser.add_argument(
73-
'--imagesize',
74-
type=int,
75-
default=28,
76-
help='Imagesize'
77-
)
78-
71+
72+
parser.add_argument("--imagesize", type=int, default=28, help="Imagesize")
73+
7974
parser.add_argument(
80-
'--nr_channels',
75+
"--nr_channels",
8176
type=int,
8277
default=1,
83-
choices=[1,3],
84-
help='Number of image channels'
78+
choices=[1, 3],
79+
help="Number of image channels",
8580
)
8681

8782
# Training specific values

utils/dataloaders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@
88

99
from .download import Downloader
1010
from .mnist_0_3 import MNISTDataset0_3
11+
from .svhn import SVHNDataset
1112
from .usps_0_6 import USPSDataset0_6
1213
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
13-
from .svhn import SVHNDataset

utils/dataloaders/svhn.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
23
import numpy as np
34
from scipy.io import loadmat
45
from torch.utils.data import Dataset
@@ -7,13 +8,13 @@
78

89
class SVHNDataset(Dataset):
910
def __init__(
10-
self,
11-
data_path: str,
11+
self,
12+
data_path: str,
1213
train: bool,
13-
transform=None,
14-
download:bool=True,
15-
nr_channels=3
16-
):
14+
transform=None,
15+
download: bool = True,
16+
nr_channels=3,
17+
):
1718
"""
1819
Initializes the SVHNDataset object.
1920
Args:
@@ -26,8 +27,8 @@ def __init__(
2627
"""
2728
super().__init__()
2829
# assert split == "train" or split == "test"
29-
self.split = 'train' if train else 'test'
30-
30+
self.split = "train" if train else "test"
31+
3132
if download:
3233
self._download_data(data_path)
3334

@@ -37,7 +38,7 @@ def __init__(
3738
self.images = data["X"].transpose(3, 1, 0, 2)
3839
self.labels = data["y"].flatten()
3940
self.labels[self.labels == 10] = 0
40-
41+
4142
self.nr_channels = nr_channels
4243
self.transforms = transform
4344

@@ -49,7 +50,7 @@ def _download_data(self, path: str):
4950
split (str): The dataset split to download, either 'train' or 'test'.
5051
"""
5152
print(f"Downloading SVHN data into {path}")
52-
53+
5354
SVHN(path, split=self.split, download=True)
5455

5556
def __len__(self):
@@ -72,7 +73,7 @@ def __getitem__(self, index):
7273

7374
if self.nr_channels == 1:
7475
img = np.mean(img, axis=2, keepdims=True)
75-
76+
7677
if self.transforms is not None:
7778
img = self.transforms(img)
7879

utils/load_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from .dataloaders import (
55
Downloader,
66
MNISTDataset0_3,
7+
SVHNDataset,
78
USPSDataset0_6,
89
USPSH5_Digit_7_9_Dataset,
9-
SVHNDataset,
1010
)
1111

1212

utils/models/magnus_model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ def __init__(self, image_shape: int, num_classes: int, imagechannels: int):
2222
self.image_shape = image_shape
2323
self.imagechannels = imagechannels
2424

25-
self.layer1 = nn.Sequential(*([
26-
nn.Linear(self.imagechannels * self.imagesize * self.imagesize, 133),
27-
nn.ReLU(),
28-
]))
29-
self.layer2 = nn.Sequential(*([
30-
nn.Linear(133, 133),
31-
nn.ReLU()
32-
]))
33-
self.layer3 = nn.Sequential(*([
34-
nn.Linear(133, num_classes),
35-
nn.ReLU()
36-
]))
25+
self.layer1 = nn.Sequential(
26+
*(
27+
[
28+
nn.Linear(
29+
self.imagechannels * self.imagesize * self.imagesize, 133
30+
),
31+
nn.ReLU(),
32+
]
33+
)
34+
)
35+
self.layer2 = nn.Sequential(*([nn.Linear(133, 133), nn.ReLU()]))
36+
self.layer3 = nn.Sequential(*([nn.Linear(133, num_classes), nn.ReLU()]))
3737

3838
def forward(self, x):
3939
"""

0 commit comments

Comments
 (0)