Skip to content

Commit 68b5616

Browse files
committed
Onehot encode labels in dataset
1 parent efe6894 commit 68b5616

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

utils/dataloaders/usps_0_6.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ def __getitem__(self, idx):
106106

107107
data = data.reshape(16, 16)
108108

109+
# one hot encode the target
110+
target = np.eye(self.num_classes, dtype=np.float32)[target]
111+
112+
# Add channel dimension
113+
data = np.expand_dims(data, axis=0)
114+
109115
if self.transform:
110116
data = self.transform(data)
111117

0 commit comments

Comments
 (0)