We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 8489dc1 + 18083b3 commit 922efc0Copy full SHA for 922efc0
CollaborativeCoding/dataloaders/uspsh5_7_9.py
@@ -65,6 +65,9 @@ def __init__(
65
mask = np.isin(labels, [7, 8, 9])
66
self.images = images[mask]
67
self.labels = labels[mask]
68
+ # map labels from (7,9) to (0,2) for CE loss
69
+ self.label_shift = lambda x: x - 7
70
+ self.label_restore = lambda x: x + 7
71
72
def __len__(self):
73
"""
@@ -95,7 +98,7 @@ def __getitem__(self, id):
95
98
# Convert to PIL Image (USPS images are typically grayscale 16x16)
96
99
image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")
97
100
label = int(self.labels[id]) # Convert label to integer
-
101
+ label = self.label_shift(label)
102
if self.transform:
103
image = self.transform(image)
104
0 commit comments