Skip to content

Commit 658acdb

Browse files
committed
cleaner code
1 parent 0148ce7 commit 658acdb

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

action/dataset.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -446,12 +446,13 @@ def __getitem__(self, i):
446446
letters = [chr(65+i) for i in range(26)][:self.topk_predictions]
447447
options = list(range(26))[:self.topk_predictions]
448448

449-
449+
# wrong answers are sampled (without replacement) from the valid ground-truth labels in self.valid_gts
450450
wrong_answer_indices = np.random.choice(len(self.valid_gts), size = 5, replace = False)
451451
wrong_answers = [self.valid_gts[index] for index in wrong_answer_indices]
452452
for i in range(len(wrong_answers)):
453453
options[i] = f'{letters[i]}. {wrong_answers[i]}'
454-
454+
455+
# the letter for the correct answer is picked at random from the available option letters
455456
correct_answer_index = np.random.choice(len(letters), size=1, replace=False)[0]
456457
correct_answer_letter = letters[correct_answer_index]
457458

@@ -460,7 +461,9 @@ def __getitem__(self, i):
460461
data = {
461462
'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
462463
'option': {0: options},
464+
# letter of the correct option in the multiple-choice (mc) question
463465
'answer': {0: correct_answer_letter},
466+
# human-readable '<verb> <noun>' name of the correct answer, kept for inspection
464467
'answer_name': {0: f'{verb} {noun}'}
465468
}
466469

@@ -637,10 +640,7 @@ def prepare_llava():
637640
return tokenizer, model, image_processor, max_length
638641

639642

640-
def get_topk_predictions(prediction_file, idx, k):
641-
642-
with open(prediction_file, 'r') as f:
643-
data = json.load(f)
643+
def get_topk_predictions(data, idx, k):
644644

645645
letters = [chr(65+i) for i in range(26)][:k]
646646
options = list(range(26))[:k]
@@ -711,6 +711,9 @@ def get_topk_predictions(prediction_file, idx, k):
711711
pretrained = f"lmms-lab/llava-onevision-qwen2-{args.llm_size}-ov"
712712

713713
tokenizer, model, image_processor, max_length = prepare_llava()
714+
715+
with open(args.action_predictions, 'r') as f:
716+
predictions = json.load(f)
714717

715718
for idx, (frames, mc_data) in tqdm(enumerate(val_dataloader)):
716719

@@ -719,7 +722,7 @@ def get_topk_predictions(prediction_file, idx, k):
719722
gts.append(gt)
720723

721724
if args.action_predictions:
722-
mc_data = get_topk_predictions(args.action_predictions, idx, args.topk_predictions)
725+
mc_data = get_topk_predictions(predictions, idx, args.topk_predictions)
723726

724727

725728
pred = llava_inference(frames, tokenizer, model, image_processor, max_length, mc_data, num_frames=args.llava_num_frames)

0 commit comments

Comments
 (0)