Skip to content

Commit 0148ce7

Browse files
committed
Fixed a bug
1 parent 03cfaf7 commit 0148ce7

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

action/dataset.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,9 @@ def __init__(
395395
rcc_params=(224,),
396396
sparse_sample=False,
397397
labels = None,
398-
is_trimmed=True):
398+
is_trimmed=True,
399+
topk_predictions = 5
400+
):
399401
super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)
400402

401403
self.transform = transform
@@ -416,6 +418,7 @@ def __init__(
416418
for noun in self.nouns:
417419
self.valid_gts.append(f'{verb} {noun}')
418420
self.labels = labels
421+
self.topk_predictions = topk_predictions
419422

420423
def __getitem__(self, i):
421424
frames, label = self.get_raw_item(
@@ -440,21 +443,25 @@ def __getitem__(self, i):
440443

441444
verb, noun = self.verbs[int(verb)], self.nouns[int(noun)]
442445

443-
letters = ['A', 'B', 'C', 'D', 'E']
444-
options = [0,1,2,3,4]
446+
letters = [chr(65+i) for i in range(26)][:self.topk_predictions]
447+
options = list(range(26))[:self.topk_predictions]
448+
449+
445450
wrong_answer_indices = np.random.choice(len(self.valid_gts), size = 5, replace = False)
446451
wrong_answers = [self.valid_gts[index] for index in wrong_answer_indices]
447452
for i in range(len(wrong_answers)):
448453
options[i] = f'{letters[i]}. {wrong_answers[i]}'
449-
correct_answer_index = np.random.choice(5, size=1, replace=False)[0]
454+
455+
correct_answer_index = np.random.choice(len(letters), size=1, replace=False)[0]
450456
correct_answer_letter = letters[correct_answer_index]
451457

452458
options[correct_answer_index] = f'{correct_answer_letter}. {verb} {noun}'
453459

454460
data = {
455461
'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
456462
'option': {0: options},
457-
'answer': {0: correct_answer_letter}
463+
'answer': {0: correct_answer_letter},
464+
'answer_name': {0: f'{verb} {noun}'}
458465
}
459466

460467
return frames, data
@@ -481,7 +488,8 @@ def get_downstream_dataset(transform, crop_size, args, subset='train', label_map
481488
threads=args.decode_threads,
482489
fast_rcc=args.fused_decode_crop, rcc_params=(crop_size, ),
483490
is_trimmed=not args.dataset == 'charades_ego',
484-
labels = labels
491+
labels = labels,
492+
topk_predictions = args.topk_predictions
485493
)
486494
else:
487495
assert ValueError("subset should be either 'train' or 'val'")
@@ -677,7 +685,7 @@ def get_topk_predictions(prediction_file, idx, k):
677685
total_samples = 0
678686

679687
if args.action_predictions:
680-
valid_letters = [chr(65+i) for i in range(26)][args.topk_predictions]
688+
valid_letters = [chr(65+i) for i in range(26)][:args.topk_predictions]
681689
else:
682690
valid_letters = ['A', 'B', 'C', 'D', 'E']
683691

@@ -712,7 +720,8 @@ def get_topk_predictions(prediction_file, idx, k):
712720

713721
if args.action_predictions:
714722
mc_data = get_topk_predictions(args.action_predictions, idx, args.topk_predictions)
715-
723+
724+
716725
pred = llava_inference(frames, tokenizer, model, image_processor, max_length, mc_data, num_frames=args.llava_num_frames)
717726

718727
# if valid letter is found in the prediction, then we will use that as the prediction
@@ -726,6 +735,7 @@ def get_topk_predictions(prediction_file, idx, k):
726735
if not found:
727736
pred = 'N/A'
728737

738+
729739
preds.append(pred)
730740

731741
# Update running corrects and total samples

0 commit comments

Comments
 (0)