Skip to content

Commit a29a9a2

Browse files
committed
Added some arguments
1 parent 2376196 commit a29a9a2

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import numpy as np
2+
3+
a = np.random.rand(28,28,3)
4+
a = np.mean(a, axis=2, keepdims=True)
5+
print(a.shape)

utils/arg_parser.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,21 @@ def get_args():
6969
nargs="+",
7070
help="Which metric to use for evaluation",
7171
)
72+
73+
parser.add_argument(
74+
'--imagesize',
75+
type=int,
76+
default=28,
77+
help='Imagesize'
78+
)
79+
80+
parser.add_argument(
81+
'--nr_channels',
82+
type=int,
83+
default=1,
84+
choices=[1,3],
85+
help='Number of image channels'
86+
)
7287

7388
# Training specific values
7489
parser.add_argument(

utils/dataloaders/svhn.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
2+
import numpy as np
33
from scipy.io import loadmat
44
from torch.utils.data import Dataset
55
from torchvision.datasets import SVHN
@@ -10,7 +10,7 @@ def __init__(
1010
self, datapath: str,
1111
transforms=None,
1212
download_data=True,
13-
split="train"
13+
nr_channels=3
1414
):
1515
"""
1616
Initializes the SVHNDataset object.
@@ -23,18 +23,19 @@ def __init__(
2323
AssertionError: If the split is not 'train' or 'test'.
2424
"""
2525
super().__init__()
26-
assert split == "train" or split == "test"
26+
# assert split == "train" or split == "test"
2727

2828
if download_data:
29-
self._download_data(datapath, split)
29+
self._download_data(datapath)
3030

31-
data = loadmat(os.path.join(datapath, f"{split}_32x32.mat"))
31+
data = loadmat(os.path.join(datapath, f"train_32x32.mat"))
3232

3333
# Images on the form N x H x W x C
3434
self.images = data["X"].transpose(3, 1, 0, 2)
3535
self.labels = data["y"].flatten()
3636
self.labels[self.labels == 10] = 0
37-
37+
38+
self.nr_channels = nr_channels
3839
self.transforms = transforms
3940

4041
def _download_data(self, path: str, split: str):
@@ -45,7 +46,7 @@ def _download_data(self, path: str, split: str):
4546
split (str): The dataset split to download, either 'train' or 'test'.
4647
"""
4748
print(f"Downloading SVHN data into {path}")
48-
SVHN(path, split=split, download=True)
49+
SVHN(path, split='train', download=True)
4950

5051
def __len__(self):
5152
"""
@@ -65,6 +66,9 @@ def __getitem__(self, index):
6566
"""
6667
img, lab = self.images[index], self.labels[index]
6768

69+
if self.nr_channels == 1:
70+
img = np.mean(img, axis=2, keepdims=True)
71+
6872
if self.transforms is not None:
6973
img = self.transforms(img)
7074

0 commit comments

Comments
 (0)