Skip to content

Commit d6191f5

Browse files
committed
Better logger. Better prediction extraction. Better control.
1 parent 0e8073f commit d6191f5

File tree

1 file changed

+37
-16
lines changed

1 file changed

+37
-16
lines changed

action/dataset.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from pathlib import Path
1414
import sys
1515
import os
16+
from action.llava_ov_inference import llava_inference
17+
import logging
1618
sys.path[0] = os.path.dirname(sys.path[0])
1719

1820

@@ -484,13 +486,14 @@ def get_downstream_dataset(transform, crop_size, args, subset='train', label_map
484486
assert ValueError("subset should be either 'train' or 'val'")
485487

486488

487-
def generate_label_map():
489+
def generate_label_map(args):
488490
print("Preprocess ek100 action label space")
489491
vn_list = []
490492
mapping_vn2narration = {}
491-
for f in [
492-
'/media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
493-
'/media/data/haozhe/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv',
493+
anno_root = Path(args.train_metadata).parent
494+
for f in [ ,
495+
anno_root / 'EPIC_100_train.csv',
496+
anno_root / 'EPIC_100_validation.csv',
494497
]:
495498
csv_reader = csv.reader(open(f))
496499
_ = next(csv_reader) # skip the header
@@ -514,7 +517,7 @@ def generate_label_map():
514517
return labels, mapping_vn2act
515518

516519

517-
def get_args_parser():
520+
def get_args_parser(args):
518521
parser = argparse.ArgumentParser(description='AVION finetune ek100 cls', add_help=False)
519522
parser.add_argument('--dataset', default='ek100_cls', type=str, choices=['ek100_mir'])
520523
parser.add_argument('--root', default='/data/EK100/EK100_320p_15sec_30fps_libx264', type=str, help='path to train dataset root')
@@ -600,6 +603,11 @@ def get_args_parser():
600603
parser.add_argument('--dist-backend', default='nccl', type=str)
601604
parser.add_argument('--seed', default=0, type=int)
602605
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
606+
# llava related
607+
# llm size is type of string and can only be '7b' or '5b' etc.
608+
parser.add_argument('--llm_size', default='7b', type=str, help='llm size')
609+
parser.add_argument('--llava_num_frames', default=16, type=int, help='number of frames for llava')
610+
603611
return parser
604612

605613
if __name__ == '__main__':
@@ -615,39 +623,52 @@ def get_args_parser():
615623
val_transform_gpu = torch.nn.Sequential(*gpu_val_transform_ls)
616624
crop_size = 336
617625

618-
labels, mapping_vn2act = generate_label_map()
626+
labels, mapping_vn2act = generate_label_map(args)
619627
val_dataset = get_downstream_dataset(
620628
val_transform_gpu, crop_size, args, subset='val', label_mapping=mapping_vn2act,
621629
labels = labels
622630
)
623631

624632
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
625-
from action.llava_ov_inference import llava_inference
633+
626634
gts = []
627635
preds = []
628636
running_corrects = 0
629637
total_samples = 0
630638

639+
valid_letters = ['A', 'B', 'C', 'D', 'E']
640+
641+
# Set up logging
642+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename=f'llava_ov_{llava_num_frames}f_{args.llm_size}.log', filemode='w')
643+
logger = logging.getLogger(__name__)
644+
631645
for idx, (frames, gt) in tqdm(enumerate(val_dataloader)):
632-
pred = llava_inference(frames, gt)
633-
pred = pred[:pred.index('.')]
646+
pred = llava_inference(frames, gt, logger, num_frames=args.llava_num_frames, llm_size=args.llm_size)
647+
648+
# if valid letter is found in the prediction, then we will use that as the prediction
649+
found = False
650+
for letter in valid_letters:
651+
if letter in pred:
652+
pred = letter
653+
found = True
654+
break
655+
if not found:
656+
pred = 'N/A'
657+
634658
gts.append(gt['answer'][0][0])
635659
preds.append(pred)
636660

637661
# Update running corrects and total samples
638-
print (pred)
639-
print (gt['answer'][0][0])
640662
running_corrects += (pred == gt['answer'][0][0])
641663
total_samples += 1
642664

643-
# Calculate and print running mean accuracy
665+
# Calculate and log running mean accuracy
644666
running_accuracy = running_corrects / total_samples
645-
print(f'Running accuracy after {total_samples} samples: {running_accuracy:.4f}')
667+
logger.info(f'Running accuracy after {total_samples} samples: {running_accuracy:.4f}')
646668

647669
gts = np.array(gts)
648670
preds = np.array(preds)
649671
# get final accuracy
650672
accuracy = np.mean(gts == preds)
651-
print('Final accuracy', accuracy)
652-
with open('llava_ov_16f_7b_result.txt', 'w') as f:
653-
f.write(f'Final accuracy: {accuracy:.4f}\n')
673+
logger.info(f'Final accuracy: {accuracy:.4f}')
674+

0 commit comments

Comments
 (0)