Skip to content

Commit 8b5970f

Browse files
committed
fixed another bug
1 parent 5454769 commit 8b5970f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

action/dataset.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,15 +445,21 @@ def __getitem__(self, i):
445445
letters = [chr(65+i) for i in range(26)][:self.topk_predictions]
446446
options = list(range(26))[:self.topk_predictions]
447447
option_names = []
448-
# wrong answer can come from any valid gt
449-
wrong_answer_indices = np.random.choice(len(self.valid_gts), size = 5, replace = False)
448+
449+
# randomly sample topk actions from valid gts
450+
451+
wrong_answer_indices = np.random.choice(len(self.valid_gts), size = args.topk_predictions, replace = False)
452+
450453
wrong_answers = [self.valid_gts[index] for index in wrong_answer_indices]
454+
451455
for i in range(len(wrong_answers)):
452456
options[i] = f'{letters[i]}. {wrong_answers[i]}'
453457
option_names.append(wrong_answers[i])
454458

455459
# correct answer must come from the available letters
460+
456461
correct_answer_index = np.random.choice(len(letters), size=1, replace=False)[0]
462+
457463
correct_answer_letter = letters[correct_answer_index]
458464

459465
option_names[correct_answer_index] = f'{verb} {noun}'

0 commit comments

Comments
 (0)