@@ -667,7 +667,7 @@ def evaluate_on_EK100(eval_args, model= None, tokenizer= None, max_length= None,
667667 for idx , (frames , mc_data , option_names ) in tqdm (enumerate (val_dataloader )):
668668 gt = mc_data ['answer' ][0 ][0 ]
669669 gt_name = mc_data ['answer_name' ][0 ][0 ]
670- gts . append ( gt_name )
670+
671671 if eval_args .action_predictions :
672672 mc_data , avaion_pred , target = get_topk_predictions (predictions , idx , eval_args .topk_predictions )
673673 target = target .replace (':' , ' ' )
@@ -676,7 +676,7 @@ def evaluate_on_EK100(eval_args, model= None, tokenizer= None, max_length= None,
676676 # we don't want to evaluate the whole thing
677677 if finish_early and idx > 9 :
678678 break
679-
679+
680680 pred = llava_inference (frames , tokenizer , model , image_processor , max_length , mc_data , clip_length = eval_args .clip_length , num_frames = eval_args .llava_num_frames )
681681
682682 # if valid letter is found in the prediction, then we will use that as the prediction
@@ -704,7 +704,7 @@ def evaluate_on_EK100(eval_args, model= None, tokenizer= None, max_length= None,
704704 pred_name = option_names [pred_index ][0 ]
705705 else :
706706 pred_name = 'N/A'
707-
707+ gts . append ( gt_name )
708708 preds .append (pred_name )
709709
710710 # Update running corrects and total samples
0 commit comments