We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent efe6894 commit 68b5616Copy full SHA for 68b5616
utils/dataloaders/usps_0_6.py
@@ -106,6 +106,12 @@ def __getitem__(self, idx):
106
107
data = data.reshape(16, 16)
108
109
+ # one hot encode the target
110
+ target = np.eye(self.num_classes, dtype=np.float32)[target]
111
+
112
+ # Add channel dimension
113
+ data = np.expand_dims(data, axis=0)
114
115
if self.transform:
116
data = self.transform(data)
117
0 commit comments