Skip to content

Commit f86e028

Browse files
committed
nr_channels as argument to conform with load_data formalism
1 parent eb96039 commit f86e028

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class MNISTDataset4_9(Dataset):
2020
Whether to train the model or not, by default False
2121
"""
2222

23-
def __init__(self, data_path: Path, sample_ids: np.ndarray, train: bool = False):
23+
def __init__(self, data_path: Path, sample_ids: np.ndarray, train: bool = False, transform = None, nr_channels: int = 1):
2424
super.__init__()
2525
self.data_path = data_path
2626
self.mnist_path = self.data_path / "MNIST"
@@ -51,5 +51,8 @@ def __getitem__(self, idx):
5151
)
5252

5353
image = np.expand_dims(image, axis=0) # Channel
54+
55+
if self.transform:
56+
image = self.transform(image)
5457

5558
return image, label

0 commit comments

Comments
 (0)