Skip to content

Commit 9b5f6cb

Browse files
committed
Merge branch 'main' into christian/update-dataloader-recall
2 parents 9f4c389 + 78181b9 commit 9b5f6cb

File tree

4 files changed, with 15 additions and 13 deletions.

CollaborativeCoding/dataloaders/mnist_0_3.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

33
import numpy as np
4+
from PIL import Image
45
from torch.utils.data import Dataset
56

67
from .datasources import MNIST_SOURCE
@@ -87,7 +88,8 @@ def __getitem__(self, index):
8788
28, 28
8889
) # Read image data
8990

90-
image = np.expand_dims(image, axis=0) # Add channel dimension
91+
# image = np.expand_dims(image, axis=0) # Add channel dimension
92+
image = Image.fromarray(image.astype(np.uint8))
9193

9294
if self.transform:
9395
image = self.transform(image)

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

33
import numpy as np
4+
from PIL import Image
45
from torch.utils.data import Dataset
56

67
from .datasources import MNIST_SOURCE
@@ -28,11 +29,13 @@ def __init__(
2829
transform=None,
2930
nr_channels: int = 1,
3031
):
31-
super.__init__()
32+
super().__init__()
3233
self.data_path = data_path
3334
self.mnist_path = self.data_path / "MNIST"
3435
self.samples = sample_ids
3536
self.train = train
37+
self.transform = transform
38+
self.num_classes = 6
3639

3740
self.images_path = self.mnist_path / (
3841
MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
@@ -46,7 +49,7 @@ def __len__(self):
4649

4750
def __getitem__(self, idx):
4851
with open(self.labels_path, "rb") as labelfile:
49-
label_pos = 8 + self.sample[idx]
52+
label_pos = 8 + self.samples[idx]
5053
labelfile.seek(label_pos)
5154
label = int.from_bytes(labelfile.read(1), byteorder="big")
5255

@@ -57,7 +60,8 @@ def __getitem__(self, idx):
5760
28, 28
5861
)
5962

60-
image = np.expand_dims(image, axis=0) # Channel
63+
# image = np.expand_dims(image, axis=0) # Channel
64+
image = Image.fromarray(image.astype(np.uint8))
6165

6266
if self.transform:
6367
image = self.transform(image)

CollaborativeCoding/load_data.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -78,10 +78,6 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
7878
train_indices = np.arange(len(train_labels))
7979
test_indices = np.arange(len(test_labels))
8080

81-
print(train_indices.shape)
82-
print(np.asarray(train_labels).shape)
83-
print(labels.shape)
84-
8581
# Filter the labels to only get indices of the wanted labels
8682
train_samples = train_indices[np.isin(train_labels, labels)]
8783
test_samples = test_indices[np.isin(test_labels, labels)]

tests/test_dataloaders.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
from CollaborativeCoding.dataloaders import (
99
MNISTDataset0_3,
10+
MNISTDataset4_9,
1011
SVHNDataset,
1112
USPSDataset0_6,
1213
USPSH5_Digit_7_9_Dataset,
@@ -21,11 +22,12 @@
2122
("usps_7-9", USPSH5_Digit_7_9_Dataset),
2223
("mnist_0-3", MNISTDataset0_3),
2324
("svhn", SVHNDataset),
24-
# TODO: Add more datasets here
25+
("mnist_4-9", MNISTDataset4_9),
2526
],
2627
)
2728
def test_load_data(data_name, expected):
28-
dataset = load_data(
29+
print(data_name)
30+
dataset, _, _ = load_data(
2931
data_name,
3032
data_dir=Path("data"),
3133
transform=transforms.ToTensor(),
@@ -34,6 +36,4 @@ def test_load_data(data_name, expected):
3436
assert len(dataset) > 0
3537
assert isinstance(dataset[0], tuple)
3638
assert isinstance(dataset[0][0], torch.Tensor)
37-
assert isinstance(
38-
dataset[0][1], (int, torch.Tensor, np.ndarray)
39-
) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency.
39+
assert isinstance(dataset[0][1], int)

0 commit comments

Comments
 (0)