Skip to content

Commit d524aab

Browse files
fix a bug
1 parent d2c8475 commit d524aab

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

ddranking/utils/data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,4 +221,6 @@ def get_random_images(dataset, class_indices, n_images_per_class):
221221
for i, (image, label) in enumerate(subset_dataset):
222222
selected_images.append(image)
223223
selected_labels.append(label)
224+
selected_images = torch.stack(selected_images, dim=0)
225+
selected_labels = torch.tensor(selected_labels, dtype=torch.long)
224226
return selected_images, selected_labels

0 commit comments

Comments
 (0)