Skip to content

Commit 9e1f72e

Browse files
Fix a bug
1 parent 9c70456 commit 9e1f72e

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ddranking/utils/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import kornia as K
77
from tqdm import tqdm
88
from torch import Tensor
9+
from torch.utils.data import Subset
910

1011

1112
class Config:
@@ -219,7 +220,7 @@ def get_random_images(dataset, class_indices, n_images_per_class):
219220

220221
subset_indices = []
221222
for indices in class_indices:
222-
subset_indices.extend(random.sample(indices, n))
223+
subset_indices.extend(random.sample(indices, n_images_per_class))
223224
subset_dataset = Subset(dataset, subset_indices)
224225

225226
selected_images, selected_labels = [], []

0 commit comments

Comments
 (0)