Skip to content

Commit 970e189

Browse files
committed
Formatting
1 parent 75b1801 commit 970e189

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

utils/dataloaders/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3", "SVHNDataset"]
1+
__all__ = [
2+
"USPSDataset0_6",
3+
"USPSH5_Digit_7_9_Dataset",
4+
"MNISTDataset0_3",
5+
"SVHNDataset",
6+
]
27

38
from .mnist_0_3 import MNISTDataset0_3
9+
from .svhn import SVHNDataset
410
from .usps_0_6 import USPSDataset0_6
511
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
6-
from .svhn import SVHNDataset

utils/dataloaders/svhn.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
23
import numpy as np
34
from scipy.io import loadmat
45
from torch.utils.data import Dataset
@@ -7,13 +8,13 @@
78

89
class SVHNDataset(Dataset):
910
def __init__(
10-
self,
11-
data_path: str,
11+
self,
12+
data_path: str,
1213
train: bool,
13-
transform=None,
14-
download:bool=True,
15-
nr_channels=3
16-
):
14+
transform=None,
15+
download: bool = True,
16+
nr_channels=3,
17+
):
1718
"""
1819
Initializes the SVHNDataset object.
1920
Args:
@@ -26,8 +27,8 @@ def __init__(
2627
"""
2728
super().__init__()
2829
# assert split == "train" or split == "test"
29-
self.split = 'train' if train else 'test'
30-
30+
self.split = "train" if train else "test"
31+
3132
if download:
3233
self._download_data(data_path)
3334

@@ -37,7 +38,7 @@ def __init__(
3738
self.images = data["X"].transpose(3, 1, 0, 2)
3839
self.labels = data["y"].flatten()
3940
self.labels[self.labels == 10] = 0
40-
41+
4142
self.nr_channels = nr_channels
4243
self.transforms = transform
4344

@@ -49,7 +50,7 @@ def _download_data(self, path: str):
4950
split (str): The dataset split to download, either 'train' or 'test'.
5051
"""
5152
print(f"Downloading SVHN data into {path}")
52-
53+
5354
SVHN(path, split=self.split, download=True)
5455

5556
def __len__(self):
@@ -72,7 +73,7 @@ def __getitem__(self, index):
7273

7374
if self.nr_channels == 1:
7475
img = np.mean(img, axis=2, keepdims=True)
75-
76+
7677
if self.transforms is not None:
7778
img = self.transforms(img)
7879

utils/load_data.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from torch.utils.data import Dataset
22

3-
from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset, SVHNDataset
3+
from .dataloaders import (
4+
MNISTDataset0_3,
5+
SVHNDataset,
6+
USPSDataset0_6,
7+
USPSH5_Digit_7_9_Dataset,
8+
)
49

510

611
def load_data(dataset: str, *args, **kwargs) -> Dataset:

0 commit comments

Comments
 (0)