Do you mean extracting the list of sample indices for each class? If so, you could do something simple like:

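from collections import defaultdict

# Maps each class label to the list of indices of samples with that label.
# defaultdict(list) creates an empty list the first time a label is seen;
# a plain {} would raise KeyError on the first .append().
classes = defaultdict(list)
for i, d in enumerate(ds):
    # d.y is assumed to be a one-element tensor holding the sample's label
    classes[d.y.item()].append(i)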

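For reference, here is a self-contained sketch of the same grouping that runs without any dataset library. It assumes the labels are already plain Python values; with a PyTorch-style dataset you would pass something like `(d.y.item() for d in ds)`. The helper name `indices_by_class` is hypothetical, chosen only for this illustration.

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, List


def indices_by_class(labels: Iterable[Hashable]) -> Dict[Hashable, List[int]]:
    """Group sample positions by class label (hypothetical helper)."""
    # defaultdict(list) starts an empty list for each new label
    classes: Dict[Hashable, List[int]] = defaultdict(list)
    for i, label in enumerate(labels):
        classes[label].append(i)
    # Return a plain dict so later lookups of missing labels raise KeyError
    # instead of silently inserting empty lists
    return dict(classes)


labels = [0, 1, 0, 2, 1, 0]
print(indices_by_class(labels))  # {0: [0, 2, 5], 1: [1, 4], 2: [3]}
```

Converting back to a plain `dict` at the end is a design choice: it keeps the convenient `defaultdict` behavior while building the index lists, but avoids accidentally creating empty entries when you later look up a class that has no samples.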
Answer selected by jiaruHithub
Category
Q&A
2 participants