Skip to content

Commit 21e51fd

Browse files
committed
ruffed, isorted
1 parent 536aafb commit 21e51fd

File tree

2 files changed

+28
-26
lines changed

2 files changed

+28
-26
lines changed

utils/dataloaders/svhn.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from torch.utils.data import Dataset
1+
import os
2+
23
from scipy.io import loadmat
3-
import os
4+
from torch.utils.data import Dataset
45
from torchvision.datasets import SVHN
56

7+
68
class SVHNDataset(Dataset):
7-
def __init__(self,
8-
datapath: str,
9-
transforms=None,
10-
download_data=True,
11-
split='train'):
9+
def __init__(
10+
self, datapath: str, transforms=None, download_data=True, split="train"
11+
):
1212
"""
1313
Initializes the SVHNDataset object.
1414
Args:
@@ -20,36 +20,38 @@ def __init__(self,
2020
AssertionError: If the split is not 'train' or 'test'.
2121
"""
2222
super().__init__()
23-
assert split == 'train' or split == 'test'
24-
23+
assert split == "train" or split == "test"
24+
2525
if download_data:
2626
self._download_data(datapath, split)
27-
28-
data = loadmat(os.path.join(datapath, f'{split}_32x32.mat'))
29-
27+
28+
data = loadmat(os.path.join(datapath, f"{split}_32x32.mat"))
29+
3030
# Images on the form N x H x W x C
31-
self.images = data['X'].transpose(3, 1, 0, 2)
32-
self.labels = data['y'].flatten()
31+
self.images = data["X"].transpose(3, 1, 0, 2)
32+
self.labels = data["y"].flatten()
3333
self.labels[self.labels == 10] = 0
34-
34+
3535
self.transforms = transforms
36+
3637
def _download_data(self, path: str, split: str):
3738
"""
3839
Downloads the SVHN dataset.
3940
Args:
4041
path (str): The directory where the dataset will be downloaded.
4142
split (str): The dataset split to download, either 'train' or 'test'.
4243
"""
43-
print(f'Downloading SVHN data into {path}')
44-
SVHN(path, split=split, download=True)
45-
44+
print(f"Downloading SVHN data into {path}")
45+
SVHN(path, split=split, download=True)
46+
4647
def __len__(self):
4748
"""
4849
Returns the number of samples in the dataset.
4950
Returns:
5051
int: The number of samples.
5152
"""
5253
return len(self.labels)
54+
5355
def __getitem__(self, index):
5456
"""
5557
Retrieves the image and label at the specified index.
@@ -59,8 +61,8 @@ def __getitem__(self, index):
5961
tuple: A tuple containing the image and its corresponding label.
6062
"""
6163
img, lab = self.images[index], self.labels[index]
62-
64+
6365
if self.transforms is not None:
6466
img = self.transforms(img)
65-
66-
return img, lab
67+
68+
return img, lab

utils/metrics/EntropyPred.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44

55

66
class EntropyPrediction(nn.Module):
7-
def __init__(self, averages: str = 'average'):
7+
def __init__(self, averages: str = "average"):
88
"""
99
Initializes the EntropyPrediction module.
1010
Args:
11-
averages (str): Specifies the method of aggregation for entropy values.
11+
averages (str): Specifies the method of aggregation for entropy values.
1212
Must be either 'average' or 'sum'.
1313
Raises:
1414
AssertionError: If the averages parameter is not 'average' or 'sum'.
1515
"""
1616
super().__init__()
17-
18-
assert averages == 'average' or averages == 'sum'
17+
18+
assert averages == "average" or averages == "sum"
1919
self.averages = averages
2020
self.stored_entropy_values = []
21-
21+
2222
def __call__(self, y_true, y_false_logits):
2323
"""
2424
Computes the entropy between true labels and predicted logits, storing the results.

0 commit comments

Comments
 (0)