Skip to content

Commit 1534053

Browse files
committed
Fixed all SVHN related bugs
1 parent 133f3bf commit 1534053

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

CollaborativeCoding/dataloaders/mnist_0_3.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,5 +93,4 @@ def __getitem__(self, index):
9393

9494
if self.transform:
9595
image = self.transform(image)
96-
9796
return image, label

CollaborativeCoding/dataloaders/svhn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def __init__(
3838
self.nr_channels = nr_channels
3939
self.transforms = transform
4040

41+
if not os.path.exists(
42+
os.path.join(self.data_path, f"svhn_{self.split}data.h5")
43+
):
44+
self._download_data(self.data_path)
45+
4146
assert os.path.exists(
4247
os.path.join(self.data_path, f"svhn_{self.split}data.h5")
4348
), f"File svhn_{self.split}data.h5 does not exists. Run download=True"
@@ -97,4 +102,4 @@ def __getitem__(self, index):
97102
if self.transforms is not None:
98103
img = self.transforms(img)
99104

100-
return img, lab
105+
return img, int(lab)

CollaborativeCoding/load_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
7979
test_indices = np.arange(len(test_labels))
8080

8181
# Filter the labels to only get indices of the wanted labels
82-
train_samples = train_indices[np.isin(train_labels, labels)]
83-
test_samples = test_indices[np.isin(test_labels, labels)]
82+
train_samples = train_indices[np.isin(train_labels, labels).flatten()]
83+
test_samples = test_indices[np.isin(test_labels, labels).flatten()]
8484

8585
train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])
8686

0 commit comments

Comments
 (0)