11import os
2-
2+ import numpy as np
33from scipy .io import loadmat
44from torch .utils .data import Dataset
55from torchvision .datasets import SVHN
@@ -10,7 +10,7 @@ def __init__(
1010 self , datapath : str ,
1111 transforms = None ,
1212 download_data = True ,
13- split = "train"
13+ nr_channels = 3
1414 ):
1515 """
1616 Initializes the SVHNDataset object.
@@ -23,18 +23,19 @@ def __init__(
2323 AssertionError: If the split is not 'train' or 'test'.
2424 """
2525 super ().__init__ ()
26- assert split == "train" or split == "test"
26+ # assert split == "train" or split == "test"
2727
2828 if download_data :
29- self ._download_data (datapath , split )
29+ self ._download_data (datapath )
3030
31- data = loadmat (os .path .join (datapath , f"{ split } _32x32 .mat" ))
31+ data = loadmat (os .path .join (datapath , f"train_32x32 .mat" ))
3232
3333 # Images on the form N x H x W x C
3434 self .images = data ["X" ].transpose (3 , 1 , 0 , 2 )
3535 self .labels = data ["y" ].flatten ()
3636 self .labels [self .labels == 10 ] = 0
37-
37+
38+ self .nr_channels = nr_channels
3839 self .transforms = transforms
3940
4041 def _download_data (self , path : str , split : str ):
@@ -45,7 +46,7 @@ def _download_data(self, path: str, split: str):
4546 split (str): The dataset split to download, either 'train' or 'test'.
4647 """
4748 print (f"Downloading SVHN data into { path } " )
48- SVHN (path , split = split , download = True )
49+ SVHN (path , split = 'train' , download = True )
4950
5051 def __len__ (self ):
5152 """
@@ -65,6 +66,9 @@ def __getitem__(self, index):
6566 """
6667 img , lab = self .images [index ], self .labels [index ]
6768
69+ if self .nr_channels == 1 :
70+ img = np .mean (img , axis = 2 , keepdims = True )
71+
6872 if self .transforms is not None :
6973 img = self .transforms (img )
7074
0 commit comments