Skip to content

Commit d7e83d4

Browse files
committed
Fixed several tests after chaning the code
1 parent 78680a8 commit d7e83d4

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed

tests/test_dataloaders.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ def test_svhn_dataset():
4444
trans = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])
4545

4646
dataset = SVHNDataset(
47-
tempdir, train=True, transform=trans, download=True, nr_channels=1
47+
tempdir, train=False, transform=trans, download=True, nr_channels=1
4848
)
4949

5050
assert dataset.__len__() != 0
51-
assert os.path.exists(os.path.join(tempdir, "train_32x32.mat"))
51+
assert os.path.exists(os.path.join(tempdir, "test_32x32.mat")), f'No such file as test_32x32.mat. Try running download=True'
52+
assert os.path.exists(os.path.join(tempdir, "svhn_testdata.h5")), f'No such file as svhn_testdata.h5. Try running download=True'
5253

5354
img, label = dataset.__getitem__(0)
5455
assert len(img.size()) == 3 and img.size() == (1, 28, 28) and img.size(0) == 1
55-
assert len(label.size()) == 1

utils/dataloaders/svhn.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22

3+
import h5py
34
import numpy as np
5+
from PIL import Image
46
from scipy.io import loadmat
57
from torch.utils.data import Dataset
68
from torchvision.datasets import SVHN
@@ -27,22 +29,23 @@ def __init__(
2729
AssertionError: If the split is not 'train' or 'test'.
2830
"""
2931
super().__init__()
32+
self.data_path = data_path
3033
self.split = "train" if train else "test"
3134

3235
if download:
3336
self._download_data(data_path)
3437

35-
data = loadmat(os.path.join(data_path, f"{self.split}_32x32.mat"))
36-
37-
# Reform images to the form N x H x W x C
38-
self.images = data["X"].transpose(3, 1, 0, 2)
39-
self.labels = data["y"].flatten()
40-
41-
self.labels[self.labels == 10] = 0
42-
4338
self.nr_channels = nr_channels
4439
self.transforms = transform
40+
41+
42+
assert os.path.exists(os.path.join(self.data_path, f'svhn_{self.split}data.h5')), f'File svhn_{self.split}data.h5 does not exists. Run download=True'
43+
with h5py.File(os.path.join(self.data_path, f'svhn_{self.split}data.h5'), 'r') as h5f:
44+
self.labels = h5f['labels'][:]
45+
4546
self.num_classes = len(np.unique(self.labels))
47+
48+
4649

4750
def _download_data(self, path: str):
4851
"""
@@ -52,7 +55,17 @@ def _download_data(self, path: str):
5255
"""
5356
print(f"Downloading SVHN data into {path}")
5457
SVHN(path, split=self.split, download=True)
58+
data = loadmat(os.path.join(path, f'{self.split}_32x32.mat'))
5559

60+
images, labels = data['X'], data['y']
61+
images = images.transpose(3,1,0,2)
62+
labels[labels == 10] = 0
63+
labels = labels.flatten()
64+
65+
with h5py.File(os.path.join(self.data_path, f'svhn_{self.split}data.h5'), 'w') as h5f:
66+
h5f.create_dataset('images', data=images)
67+
h5f.create_dataset('labels', data=labels)
68+
5669
def __len__(self):
5770
"""
5871
Returns the number of samples in the dataset.
@@ -69,11 +82,15 @@ def __getitem__(self, index):
6982
Returns:
7083
tuple: A tuple containing the image and its corresponding label.
7184
"""
72-
img, lab = self.images[index], self.labels[index]
73-
85+
lab = self.labels[index]
86+
with h5py.File(os.path.join(self.data_path, f'svhn_{self.split}data.h5'), 'r') as h5f:
87+
img = Image.fromarray(h5f['images'][index])
88+
7489
if self.nr_channels == 1:
75-
img = np.mean(img, axis=2, keepdims=True)
90+
img = img.convert('L')
91+
7692
if self.transforms is not None:
7793
img = self.transforms(img)
7894

7995
return img, lab
96+

utils/metrics/EntropyPred.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,3 @@ def __returnmetric__(self):
6363
def __reset__(self):
6464
self.stored_entropy_values = []
6565

66-
if __name__ == '__main__':
67-
68-
pred_logits = th.rand(6, 5)
69-
true_lab = th.rand(6, 5)
70-
71-
metric = EntropyPrediction(averages="mean")
72-
metric2 = EntropyPrediction(averages="sum")
73-
74-
# Test for averaging metric consistency
75-
metric(true_lab, pred_logits)
76-
metric2(true_lab, pred_logits)
77-
assert (th.abs(th.sum(6 * metric.__returnmetric__() - metric2.__returnmetric__())) < 1e-5)

0 commit comments

Comments
 (0)