1313from pathlib import Path
1414import sys
1515import os
16+ from action .llava_ov_inference import llava_inference
17+ import logging
1618sys .path [0 ] = os .path .dirname (sys .path [0 ])
1719
1820
@@ -484,13 +486,14 @@ def get_downstream_dataset(transform, crop_size, args, subset='train', label_map
484486 assert ValueError ("subset should be either 'train' or 'val'" )
485487
486488
487- def generate_label_map ():
489+ def generate_label_map (args ):
488490 print ("Preprocess ek100 action label space" )
489491 vn_list = []
490492 mapping_vn2narration = {}
491- for f in [
492- '/media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv' ,
493- '/media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv' ,
493+ anno_root = Path (args .train_metadata ).parent
494+ for f in [
495+ anno_root / 'EPIC_100_train.csv' ,
496+ anno_root / 'EPIC_100_validation.csv' ,
494497 ]:
495498 csv_reader = csv .reader (open (f ))
496499 _ = next (csv_reader ) # skip the header
@@ -514,7 +517,7 @@ def generate_label_map():
514517 return labels , mapping_vn2act
515518
516519
517- def get_args_parser ():
520+ def get_args_parser (args ):
518521 parser = argparse .ArgumentParser (description = 'AVION finetune ek100 cls' , add_help = False )
519522 parser .add_argument ('--dataset' , default = 'ek100_cls' , type = str , choices = ['ek100_mir' ])
520523 parser .add_argument ('--root' , default = '/data/EK100/EK100_320p_15sec_30fps_libx264' , type = str , help = 'path to train dataset root' )
@@ -600,6 +603,11 @@ def get_args_parser():
600603 parser .add_argument ('--dist-backend' , default = 'nccl' , type = str )
601604 parser .add_argument ('--seed' , default = 0 , type = int )
602605 parser .add_argument ('--gpu' , default = None , type = int , help = 'GPU id to use.' )
606+ # llava related
607+ # llm_size is a string identifying the language-model size, e.g. '7b' or '0.5b'
608+ parser .add_argument ('--llm_size' , default = '7b' , type = str , help = 'llm size' )
609+ parser .add_argument ('--llava_num_frames' , default = 16 , type = int , help = 'number of frames for llava' )
610+
603611 return parser
604612
605613if __name__ == '__main__' :
@@ -615,39 +623,52 @@ def get_args_parser():
615623 val_transform_gpu = torch .nn .Sequential (* gpu_val_transform_ls )
616624 crop_size = 336
617625
618- labels , mapping_vn2act = generate_label_map ()
626+ labels , mapping_vn2act = generate_label_map (args )
619627 val_dataset = get_downstream_dataset (
620628 val_transform_gpu , crop_size , args , subset = 'val' , label_mapping = mapping_vn2act ,
621629 labels = labels
622630 )
623631
624632 val_dataloader = DataLoader (val_dataset , batch_size = 1 , shuffle = False )
625- from action . llava_ov_inference import llava_inference
633+
626634 gts = []
627635 preds = []
628636 running_corrects = 0
629637 total_samples = 0
630638
639+ valid_letters = ['A' , 'B' , 'C' , 'D' , 'E' ]
640+
641+ # Set up logging
642+ logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' , filename = f'llava_ov_{ args .llava_num_frames } f_{ args .llm_size } .log' , filemode = 'w' )
643+ logger = logging .getLogger (__name__ )
644+
631645 for idx , (frames , gt ) in tqdm (enumerate (val_dataloader )):
632- pred = llava_inference (frames , gt )
633- pred = pred [:pred .index ('.' )]
646+ pred = llava_inference (frames , gt , logger , num_frames = args .llava_num_frames , llm_size = args .llm_size )
647+
648+ # if valid letter is found in the prediction, then we will use that as the prediction
649+ found = False
650+ for letter in valid_letters :
651+ if letter in pred :
652+ pred = letter
653+ found = True
654+ break
655+ if not found :
656+ pred = 'N/A'
657+
634658 gts .append (gt ['answer' ][0 ][0 ])
635659 preds .append (pred )
636660
637661 # Update running corrects and total samples
638- print (pred )
639- print (gt ['answer' ][0 ][0 ])
640662 running_corrects += (pred == gt ['answer' ][0 ][0 ])
641663 total_samples += 1
642664
643- # Calculate and print running mean accuracy
665+ # Calculate and log running mean accuracy
644666 running_accuracy = running_corrects / total_samples
645- print (f'Running accuracy after { total_samples } samples: { running_accuracy :.4f} ' )
667+ logger . info (f'Running accuracy after { total_samples } samples: { running_accuracy :.4f} ' )
646668
647669 gts = np .array (gts )
648670 preds = np .array (preds )
649671 # get final accuracy
650672 accuracy = np .mean (gts == preds )
651- print ('Final accuracy' , accuracy )
652- with open ('llava_ov_16f_7b_result.txt' , 'w' ) as f :
653- f .write (f'Final accuracy: { accuracy :.4f} \n ' )
673+ logger .info (f'Final accuracy: { accuracy :.4f} ' )
674+