
Commit 5454769

fixed known bugs
1 parent 658acdb commit 5454769

File tree

action/dataset.py
action/llava_ov_inference.py

2 files changed: +52, -20 lines


action/dataset.py

Lines changed: 51 additions & 19 deletions
@@ -440,22 +440,24 @@ def __getitem__(self, i):
         frames = self.transform(frames)

         verb, noun = label.split(':')
-
         verb, noun = self.verbs[int(verb)], self.nouns[int(noun)]

         letters = [chr(65+i) for i in range(26)][:self.topk_predictions]
         options = list(range(26))[:self.topk_predictions]
-
+        option_names = []
         # wrong answer can come from any valid gt
         wrong_answer_indices = np.random.choice(len(self.valid_gts), size = 5, replace = False)
         wrong_answers = [self.valid_gts[index] for index in wrong_answer_indices]
         for i in range(len(wrong_answers)):
             options[i] = f'{letters[i]}. {wrong_answers[i]}'
+            option_names.append(wrong_answers[i])

         # correct answer must come from the available letters
         correct_answer_index = np.random.choice(len(letters), size=1, replace=False)[0]
         correct_answer_letter = letters[correct_answer_index]

+        option_names[correct_answer_index] = f'{verb} {noun}'
+
         options[correct_answer_index] = f'{correct_answer_letter}. {verb} {noun}'

         data = {
@@ -467,7 +469,7 @@ def __getitem__(self, i):
             'answer_name': {0: f'{verb} {noun}'}
         }

-        return frames, data
+        return frames, data, option_names


 def get_downstream_dataset(transform, crop_size, args, subset='train', label_mapping=None, labels = None):
@@ -644,18 +646,19 @@ def get_topk_predictions(data, idx, k):

     letters = [chr(65+i) for i in range(26)][:k]
     options = list(range(26))[:k]
-
-    predictions = data[str(idx)]['predictions'][:k]

+    predictions = data[str(idx)]['predictions'][:k]
+    predictions = [e.replace(':', ' ') for e in predictions]
+
     for i in range(len(options)):
         options[i] = f'{letters[i]}. {predictions[i]}'
-
-
+
     mc_data = {
         'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
         'option': {0: options}
-    }
-    return mc_data
+    }
+
+    return mc_data, predictions


 if __name__ == '__main__':
@@ -712,21 +715,29 @@ def get_topk_predictions(data, idx, k):

     tokenizer, model, image_processor, max_length = prepare_llava()

-    with open(args.action_predictions, 'r') as f:
-        predictions = json.load(f)
+    if args.action_predictions:
+        with open(args.action_predictions, 'r') as f:
+            predictions = json.load(f)
+
+    avaion_correct = 0

-    for idx, (frames, mc_data) in tqdm(enumerate(val_dataloader)):
+    for idx, (frames, mc_data, option_names) in tqdm(enumerate(val_dataloader)):

         gt = mc_data['answer'][0][0]
+
+        gt_name = mc_data['answer_name'][0][0]

-        gts.append(gt)
+        gts.append(gt_name)

         if args.action_predictions:
-            mc_data = get_topk_predictions(predictions, idx, args.topk_predictions)
-
+            mc_data, avaion_pred = get_topk_predictions(predictions, idx, args.topk_predictions)
+            if gt_name == avaion_pred[0]:
+                avaion_correct+=1
+            else:
+                pass

         pred = llava_inference(frames, tokenizer, model, image_processor, max_length, mc_data, num_frames=args.llava_num_frames)
-
+
         # if valid letter is found in the prediction, then we will use that as the prediction
         found = False

@@ -738,17 +749,38 @@ def get_topk_predictions(data, idx, k):
         if not found:
             pred = 'N/A'

-
-        preds.append(pred)
+        if args.action_predictions:
+            if pred in valid_letters:
+                pred_index = valid_letters.index(pred)
+                pred_name = avaion_pred[pred_index]
+            else:
+                pred_name = 'N/A'
+        else:
+            if pred in valid_letters:
+                pred_index = valid_letters.index(pred)
+                pred_name = option_names[pred_index][0]
+            else:
+                pred_name = 'N/A'
+
+        print ('gt_name', gt_name)
+        print ('pred_name', pred_name)
+
+        preds.append(pred_name)

         # Update running corrects and total samples
-        running_corrects += (pred == gt)
+        running_corrects += (pred_name == gt_name)
         total_samples += 1

         # Calculate and log running mean accuracy
         running_accuracy = running_corrects / total_samples
+
         logger.info(f'Running accuracy after {total_samples} samples: {running_accuracy:.4f}')

+        if args.action_predictions:
+            avaion_accuracy = avaion_correct / total_samples
+            logger.info(f'Running avaion accuracy after {total_samples} samples: {avaion_accuracy:.4f}')
+
+
     gts = np.array(gts)
     preds = np.array(preds)
     # get final accuracy
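
For context, the core fix in this file is that accuracy is now scored on option names rather than on raw letters: the letter LLaVA predicts is mapped back to the option text before comparison with the ground-truth name. Below is a minimal standalone sketch (not part of the commit) of that mapping; the variable names valid_letters, option_names, gt_name, pred, and pred_name mirror the diff above, but the values are invented purely for illustration.

    # Sketch of the letter-to-name accuracy check introduced in this commit.
    # All values below are made up; only the logic mirrors the diff.
    valid_letters = ['A', 'B', 'C', 'D', 'E']
    option_names = ['open drawer', 'cut onion', 'wash plate', 'close fridge', 'take knife']
    gt_name = 'cut onion'

    pred = 'B'  # letter extracted from the LLaVA output
    if pred in valid_letters:
        # map the predicted letter to the option text it refers to
        pred_name = option_names[valid_letters.index(pred)]
    else:
        pred_name = 'N/A'

    print(pred_name == gt_name)  # True: this sample would count toward running_corrects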

action/llava_ov_inference.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def llava_inference(video_frames, tokenizer, model, image_processor, max_length,
     option = mc_data['option'][0]

     question = f"{DEFAULT_IMAGE_TOKEN}\n{question}:{option}"
-
+
     conv = copy.deepcopy(conv_templates[conv_template])
     conv.append_message(conv.roles[0], question)
     conv.append_message(conv.roles[1], None)
