Skip to content

Commit f94b866

Browse files
committed
possible to refine action recognition predictions
1 parent 0cacc5f commit f94b866

File tree

3 files changed

+52
-9
lines changed

3 files changed

+52
-9
lines changed

action/dataset.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
sys.path[0] = os.path.dirname(sys.path[0])
1717
from action.llava_ov_inference import llava_inference
18+
import json
1819
import logging
1920

2021

@@ -607,6 +608,9 @@ def get_args_parser():
607608
# llm_size is a string such as '7b' or '5b'.
608609
parser.add_argument('--llm_size', default='7b', type=str, help='llm size')
609610
parser.add_argument('--llava_num_frames', default=16, type=int, help='number of frames for llava')
611+
## avaion refinement
612+
parser.add_argument('--action_predictions', default=None, type=str, help='path to action predictions')
613+
parser.add_argument('--topk_predictions', default = 5, type =int)
610614

611615
return parser
612616

@@ -625,6 +629,27 @@ def prepare_llava():
625629
return tokenizer, model, image_processor, max_length
626630

627631

632+
def get_topk_predictions(prediction_file, idx, k):
    """Build multiple-choice data from the first k predictions of one sample.

    Args:
        prediction_file: Path to a JSON file that maps the sample index (as a
            string) to a dict containing a 'predictions' list.
        idx: Index of the sample; it is looked up as ``str(idx)``.
        k: Maximum number of options to produce. At most 26 options are
            generated, one per letter 'A'..'Z'.

    Returns:
        A dict with 'question' and 'option' entries, each keyed by 0. The
        options are strings of the form ``'A. <prediction>'``.

    Raises:
        KeyError: If the sample or its 'predictions' entry is missing.
    """
    with open(prediction_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    predictions = data[str(idx)]['predictions'][:k]
    letters = [chr(ord('A') + i) for i in range(26)][:k]

    # zip stops at the shorter sequence, so a sample with fewer than k
    # predictions yields fewer options instead of raising IndexError.
    options = [
        f'{letter}. {prediction}'
        for letter, prediction in zip(letters, predictions)
    ]

    # The prompt wording is reproduced verbatim (including the duplicated
    # "the") so that model inputs stay identical to earlier runs.
    mc_data = {
        'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
        'option': {0: options},
    }
    return mc_data
651+
652+
628653
if __name__ == '__main__':
629654
from moviepy.editor import ImageSequenceClip
630655
import torchvision
@@ -651,10 +676,18 @@ def prepare_llava():
651676
running_corrects = 0
652677
total_samples = 0
653678

654-
valid_letters = ['A', 'B', 'C', 'D', 'E']
679+
# Letters the model may answer with. When refining external action
# predictions there is one letter per top-k option, so take a slice of the
# alphabet (a plain index would yield a single character).
if args.action_predictions:
    valid_letters = [chr(65+i) for i in range(26)][:args.topk_predictions]
else:
    valid_letters = ['A', 'B', 'C', 'D', 'E']
655683

684+
if not args.action_predictions:
685+
log_filename = f'llava_ov_{args.llava_num_frames}f_{args.llm_size}.log'
686+
else:
687+
log_filename = f'llava_ov_{args.llava_num_frames}f_{args.llm_size}_action_{args.topk_predictions}.log'
688+
656689
# Set up logging
657-
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename=f'llava_ov_{args.llava_num_frames}f_{args.llm_size}.log', filemode='w')
690+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename=log_filename, filemode='w')
658691

659692
console_handler = logging.StreamHandler(sys.stdout)
660693
console_handler.setLevel(logging.INFO)
@@ -671,11 +704,20 @@ def prepare_llava():
671704

672705
tokenizer, model, image_processor, max_length = prepare_llava()
673706

674-
for idx, (frames, gt) in tqdm(enumerate(val_dataloader)):
675-
pred = llava_inference(frames, tokenizer, model, image_processor, max_length, gt, num_frames=args.llava_num_frames)
707+
for idx, (frames, mc_data) in tqdm(enumerate(val_dataloader)):
708+
709+
gt = mc_data['answer'][0][0]
710+
711+
gts.append(gt)
712+
713+
if args.action_predictions:
714+
mc_data = get_topk_predictions(args.action_predictions, idx, args.topk_predictions)
715+
716+
pred = llava_inference(frames, tokenizer, model, image_processor, max_length, mc_data, num_frames=args.llava_num_frames)
676717

677718
# if valid letter is found in the prediction, then we will use that as the prediction
678719
found = False
720+
679721
for letter in valid_letters:
680722
if letter in pred:
681723
pred = letter
@@ -684,11 +726,10 @@ def prepare_llava():
684726
if not found:
685727
pred = 'N/A'
686728

687-
gts.append(gt['answer'][0][0])
688729
preds.append(pred)
689730

690731
# Update running corrects and total samples
691-
running_corrects += (pred == gt['answer'][0][0])
732+
running_corrects += (pred == gt)
692733
total_samples += 1
693734

694735
# Calculate and log running mean accuracy

action/llava_ov_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from decord import VideoReader, cpu
1515

1616

17-
def llava_inference(video_frames, tokenizer, model, image_processor, max_length, gt, num_frames=16):
17+
def llava_inference(video_frames, tokenizer, model, image_processor, max_length, mc_data, num_frames=16):
1818

1919
model.eval()
2020
device = "cuda"
@@ -27,8 +27,8 @@ def llava_inference(video_frames, tokenizer, model, image_processor, max_length,
2727

2828
conv_template = "qwen_1_5"
2929

30-
question = gt['question'][0]
31-
option = gt['option'][0]
30+
question = mc_data['question'][0]
31+
option = mc_data['option'][0]
3232

3333
question = f"{DEFAULT_IMAGE_TOKEN}\n{question}:{option}"
3434

run_EK100.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ python3 action/dataset.py \
44
--val-metadata /media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv \
55
--llm_size 7b \
66
--llava_num_frames 16 > kitchen_test.out 2>&1 \
7+
# --action_predictions action/avaion_predictions.json \
8+
# --topk_predictions 10

0 commit comments

Comments
 (0)