 import logging
 from llava.utils import rank0_print
 from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions
+import copy
+import random
+from collections import Counter
 
 def datetime2sec(str):
     hh, mm, ss = str.split(':')
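
# A minimal sketch of the conversion datetime2sec presumably completes in its
# truncated body (an assumption; the variable names here are illustrative):
ts = '00:01:30.5'
hh, mm, ss = ts.split(':')
total_seconds = int(hh) * 3600 + int(mm) * 60 + float(ss)  # -> 90.5
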
@@ -370,40 +372,93 @@ def prepare_llava(pretrained):
 
     return tokenizer, model, image_processor, max_length
 
-
-def get_topk_predictions(data, idx, k):
+def get_topk_predictions(data, idx, k):
 
     letters = [chr(65 + i) for i in range(26)][:k]
     options = list(range(26))[:k]
 
     predictions = data[str(idx)]['predictions'][:k]
-
     predictions = parse_avion_predictions(predictions)
 
     for i in range(len(options)):
         options[i] = f'{letters[i]}. {predictions[i]}'
 
     mc_data = {
         'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the letter that has the correct answer.'},
-        'option': {0: options}
+        'options': {0: options},
+        'valid_letters': letters,
+        'avion_pred': predictions[0]
     }
+
+    return mc_data
+
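# A usage sketch with hypothetical inputs (the prediction strings are made up,
# and the expected output assumes parse_avion_predictions yields readable
# action names):
fake_data = {'0': {'predictions': ['take plate', 'wash plate', 'put plate']}}
mc = get_topk_predictions(fake_data, idx=0, k=3)
# mc['options'][0] -> ['A. take plate', 'B. wash plate', 'C. put plate']
# mc['avion_pred'] -> 'take plate'
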
+def ensemble_llava_evaluation(gt_name,
+                              frames,
+                              tokenizer,
+                              model,
+                              image_processor,
+                              mc_data,
+                              clip_length,
+                              num_frames,
+                              temperature=0,
+                              ensemble_k=1,
+                              is_test=False
+                              ):
+    """
+    Test how consistent the model is when the positions of the answer options
+    are shuffled. This should be run with a higher temperature, so ensembling
+    the sampled predictions may also improve performance.
+    """
+
+    # shuffle the options; note that this mutates mc_data['options'][0]
+    # in place, which is the list llava_inference reads
+    options = mc_data['options'][0]
+    letters = mc_data['valid_letters']
+    avion_pred = mc_data['avion_pred']
+    # each option is in the format '{letter}. {answer}'
+    preds = []
+    for _ in range(ensemble_k):
+        # shuffle the answers, then relabel them so the letters stay in
+        # alphabetical order while the answers move between letters
+        random.shuffle(options)
+        for idx, (option, letter) in enumerate(zip(options, letters)):
+            sep = option.index('.')
+            options[idx] = f'{letter}.{option[sep + 1:]}'
+        rank0_print('generated new option sequence')
+        rank0_print(options)
+
+        pred = llava_inference(frames,
+                               tokenizer,
+                               model,
+                               image_processor,
+                               mc_data,
+                               clip_length=clip_length,
+                               num_frames=num_frames,
+                               temperature=temperature,
+                               is_test=is_test
+                               )
+
+        rank0_print('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
+        # strip the leading '{letter}.' so only the answer text is voted on
+        sep = pred.index('.')
+        pred = pred[sep + 1:].strip()
+        preds.append(pred)
+
+    # majority vote over the sampled predictions
+    counter = Counter(preds)
+    rank0_print('inspecting the counter', counter)
+    rank0_print('most common', counter.most_common(1)[0][0])
+
+    return match_answer(counter.most_common(1)[0][0], gt_name)
+
 
-    return mc_data, predictions[0]
 
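# A self-contained sketch of the shuffle-and-relabel plus majority-vote logic
# above (no model involved; the option strings are made up):
import random
from collections import Counter

opts = ['A. take plate', 'B. wash plate', 'C. put plate']
letters_abc = ['A', 'B', 'C']
random.shuffle(opts)
for i, (opt, letter) in enumerate(zip(opts, letters_abc)):
    opts[i] = f'{letter}.{opt[opt.index(".") + 1:]}'
# the letters stay in order but the answers behind them have moved, e.g.
# ['A. wash plate', 'B. take plate', 'C. put plate']

votes = Counter(['wash plate', 'take plate', 'wash plate'])
assert votes.most_common(1)[0][0] == 'wash plate'
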
 def evaluate_on_EK100(eval_args,
                       model=None,
                       tokenizer=None,
                       image_processor=None):
 
-    if image_processor is None:
+    # only pull the processor from the model when one wasn't supplied
+    if image_processor is None and model is not None:
         image_processor = model.get_vision_tower().image_processor
 
     gpu_val_transform_ls = []
-
     val_transform_gpu = torch.nn.Sequential(*gpu_val_transform_ls)
-
     crop_size = 336
-
     labels, mapping_vn2act, verb_maps, noun_maps = generate_label_map(Path(eval_args.val_metadata).parent)
 
     val_dataset = VideoMultiChoiceDataset(
@@ -468,7 +523,8 @@ def evaluate_on_EK100(eval_args,
         gt_name = mc_data['gt_answer_name'][0][0]
 
         if eval_args.action_predictions:
-            mc_data, avion_pred = get_topk_predictions(predictions, idx, eval_args.topk_predictions)
+            mc_data = get_topk_predictions(predictions, idx, eval_args.topk_predictions)
+            avion_pred = mc_data['avion_pred']
             if gt_name == avion_pred:
                 avaion_correct += 1
 
@@ -477,18 +533,30 @@ def evaluate_on_EK100(eval_args,
         if finish_early and idx > 999:
             break
 
-        pred = llava_inference(frames, tokenizer, model, image_processor, mc_data, clip_length=eval_args.clip_length, num_frames=eval_args.llava_num_frames)
+        # pred = llava_inference(frames, tokenizer, model, image_processor, mc_data, clip_length=eval_args.clip_length, num_frames=eval_args.llava_num_frames)
 
-        # if valid letter is found in the prediction, then we will use that as the prediction
-        rank0_print('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
+        # # if a valid letter is found in the prediction, we will use it as the prediction
+        # rank0_print('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
 
         # Update running corrects and total samples
-        running_corrects += (match_answer(pred, gt_name))
+        # five sampled option orderings at a high temperature, majority-voted
+        running_corrects += ensemble_llava_evaluation(gt_name,
+                                                      frames,
+                                                      tokenizer,
+                                                      model,
+                                                      image_processor,
+                                                      mc_data,
+                                                      eval_args.clip_length,
+                                                      eval_args.llava_num_frames,
+                                                      temperature=2.0,
+                                                      ensemble_k=5,
+                                                      is_test=not finish_early)
+
         total_samples += 1
 
         # Calculate and log running mean accuracy
         running_accuracy = running_corrects / total_samples
 
+        logger.info(f'running accuracy: {running_accuracy:.4f}')
         if eval_args.action_predictions:
             avaion_accuracy = avaion_correct / total_samples
 