We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent af3d8bc commit 09fa0d0Copy full SHA for 09fa0d0
CollaborativeCoding/dataloaders/mnist_4_9.py
@@ -43,7 +43,12 @@ def __init__(
43
self.labels_path = self.mnist_path / (
44
MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
45
)
46
-
+
47
+ # Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
48
+ self.label_shift = lambda x: x-4
49
+ self.label_restore = lambda x: x+4
50
51
52
def __len__(self):
53
return len(self.samples)
54
@@ -66,4 +71,4 @@ def __getitem__(self, idx):
66
71
if self.transform:
67
72
image = self.transform(image)
68
73
69
- return image, label
74
+ return image, self.label_shift(label)
0 commit comments