Skip to content

Commit 922efc0

Browse files
authored
Merge pull request #96 from SFI-Visual-Intelligence/solveig_dataloader
fixed remapping of the labels for usps 7-9
2 parents 8489dc1 + 18083b3 commit 922efc0

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def __init__(
6565
mask = np.isin(labels, [7, 8, 9])
6666
self.images = images[mask]
6767
self.labels = labels[mask]
68+
# map labels from (7,9) to (0,2) for CE loss
69+
self.label_shift = lambda x: x - 7
70+
self.label_restore = lambda x: x + 7
6871

6972
def __len__(self):
7073
"""
@@ -95,7 +98,7 @@ def __getitem__(self, id):
9598
# Convert to PIL Image (USPS images are typically grayscale 16x16)
9699
image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")
97100
label = int(self.labels[id]) # Convert label to integer
98-
101+
label = self.label_shift(label)
99102
if self.transform:
100103
image = self.transform(image)
101104

0 commit comments

Comments
 (0)