Skip to content

Commit 09fa0d0

Browse files
committed
updated mnist
1 parent af3d8bc commit 09fa0d0

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ def __init__(
4343
self.labels_path = self.mnist_path / (
4444
MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
4545
)
46-
46+
47+
# Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
48+
self.label_shift = lambda x: x-4
49+
self.label_restore = lambda x: x+4
50+
51+
4752
def __len__(self):
4853
return len(self.samples)
4954

@@ -66,4 +71,4 @@ def __getitem__(self, idx):
6671
if self.transform:
6772
image = self.transform(image)
6873

69-
return image, label
74+
return image, self.label_shift(label)

0 commit comments

Comments
 (0)