Skip to content

Commit 3e4bba9

Browse files
author
Ye Shaokai
committed
updates
1 parent 1934163 commit 3e4bba9

File tree

5 files changed

+256
-24
lines changed

5 files changed

+256
-24
lines changed

action/ek_eval.py

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import logging
2020
from llava.utils import rank0_print
2121
from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions
22+
from action.prediction_analysis import PredictionAnalysis
2223
import copy
2324
from collections import Counter
2425
import torch.distributed as dist
@@ -224,6 +225,10 @@ def get_raw_item(
224225
fast_rcc=fast_rcc,
225226
rcc_params=rcc_params,
226227
jitter=is_training)
228+
time_meta['start_second'] = start_second
229+
time_meta['end_second'] = end_second
230+
time_meta['fps'] = fps
231+
time_meta['vid_path'] = vid_path
227232
return frames, '{}:{}'.format(verb, noun), time_meta
228233
else:
229234
raise NotImplementedError
@@ -271,10 +276,13 @@ def __init__(
271276
self.verb_maps = verb_maps
272277
self.noun_maps = noun_maps
273278
self.vn_list = list(self.label_mapping.keys())
279+
274280
self.labels = labels
275281
self.topk_predictions = topk_predictions
276282
self.ann_root = Path(metadata).parent
277283
self.mc_generator = MultiChoiceGenerator(self.ann_root)
284+
self.rank = dist.get_rank()
285+
self.prediction_analysis = PredictionAnalysis(f'prediction_analysis_buf_rank{self.rank}.json')
278286

279287
def __getitem__(self, i):
280288
frames, label, time_meta = self.get_raw_item(
@@ -299,8 +307,6 @@ def __getitem__(self, i):
299307

300308
data = self.mc_generator.generate_multi_choice(label, self.topk_predictions)
301309

302-
dataset_size = len(self.samples)
303-
304310
return frames, data, time_meta, i
305311

306312

@@ -330,10 +336,11 @@ def get_args_parser():
330336
# llm size is type of string and can only be '7b' or '5b' etc.
331337
parser.add_argument('--pretrained_name', default = '', type = str, help ='the name in huggingface')
332338
parser.add_argument('--llava_num_frames', default=16, type=int, help='number of frames for llava')
333-
## avaion refinement
339+
## avion refinement
334340
parser.add_argument('--action_predictions', default=None, type=str, help='path to action predictions')
335341
parser.add_argument('--topk_predictions', default = 5, type =int)
336342
parser.add_argument('--llava_checkpoint', default = None, type = str)
343+
parser.add_argument('--early_stop', default = None, type = int)
337344

338345
return parser
339346

@@ -438,7 +445,7 @@ def ensemble_llava_evaluation(
438445
rank0_print ('inspecting the counter', counter)
439446
rank0_print ('most common', counter.most_common(1)[0][0])
440447

441-
return match_answer(counter.most_common(1)[0][0], gt_name)
448+
return match_answer(counter.most_common(1)[0][0], gt_name), counter.most_common(1)[0][0]
442449

443450

444451

@@ -497,7 +504,7 @@ def evaluate_on_EK100(eval_args,
497504
print ('pretrained', pretrained)
498505

499506
# so we know it's evaluation during training
500-
finish_early = model is not None
507+
finish_early = False #model is not None
501508

502509
if model is None:
503510
if args.llava_checkpoint is not None:
@@ -508,26 +515,40 @@ def evaluate_on_EK100(eval_args,
508515
with open(eval_args.action_predictions, 'r') as f:
509516
predictions = json.load(f)
510517

511-
avaion_correct = torch.tensor(0, device='cuda')
512-
running_corrects = torch.tensor(0, device='cuda')
513-
total_samples = torch.tensor(0, device='cuda')
518+
device = torch.device(f'cuda:{rank}')
519+
520+
global_avion_correct = torch.tensor(0.0, device=device)
521+
global_running_corrects = torch.tensor(0.0, device=device)
522+
global_total_samples = torch.tensor(0.0, device=device)
523+
514524

515525
for idx, (frames, mc_data, time_meta, global_index) in tqdm(enumerate(val_dataloader)):
526+
527+
global_index = global_index.item()
528+
516529
gt_name = mc_data['gt_answer_name'][0][0]
530+
local_avion_correct = torch.tensor(0.0, device=device)
531+
local_running_corrects = torch.tensor(0.0, device=device)
532+
local_total_samples = torch.tensor(0.0, device=device)
517533

518534
if eval_args.action_predictions:
519-
mc_data = get_topk_predictions(predictions, global_index.item(), eval_args.topk_predictions)
535+
mc_data = get_topk_predictions(predictions, global_index, eval_args.topk_predictions)
520536
avion_pred = mc_data['avion_pred']
521537
if gt_name == avion_pred:
522-
avaion_correct+=1
538+
local_avion_correct.add_(1)
539+
global_avion_correct.add_(1)
523540

524541
# we don't want to evaluate the whole thing
525542
# let's evaluate 1000 samples to get the complete picture
526543
if finish_early and idx> (1000 / dist.get_world_size()):
527544
break
528545

546+
if eval_args.early_stop and idx > eval_args.early_stop:
547+
break
548+
529549
# Update running corrects and total samples
530-
running_corrects += ensemble_llava_evaluation(
550+
551+
llava_correct, llava_pred = ensemble_llava_evaluation(
531552
eval_args.pretrained_name,
532553
gt_name,
533554
frames,
@@ -541,33 +562,69 @@ def evaluate_on_EK100(eval_args,
541562
ensemble_k = 1,
542563
time_meta = time_meta,
543564
is_test = not finish_early)
565+
566+
# log the predictions into prediciton analysis
567+
568+
# val_dataset.prediction_analysis.log(global_index,
569+
# llava_pred,
570+
# gt_name,
571+
# predictions[str(global_index)],
572+
# time_meta['start_second'].item(),
573+
# time_meta['end_second'].item(),
574+
# time_meta['vid_path'],
575+
# dataset_name = 'EK100')
576+
577+
578+
579+
580+
local_running_corrects.add_(llava_correct)
581+
global_running_corrects.add_(llava_correct)
544582

545-
total_samples += 1
583+
local_total_samples.add_(1)
584+
global_total_samples.add_(1)
585+
586+
logger.info(f'Process {dist.get_rank()} - local_total_samples: {local_total_samples:.4f}')
587+
588+
logger.info(f'Process {dist.get_rank()} - loca_llava_correct: {llava_correct:.4f}')
589+
590+
logger.info(f'Process {dist.get_rank()} - local_running_corrects: {local_running_corrects:.4f}')
591+
546592

547593
# Calculate and log running mean accuracy
548-
running_accuracy = running_corrects / total_samples
594+
# dist.barrier()
595+
# dist.all_reduce(local_running_corrects, op=dist.ReduceOp.SUM)
596+
# dist.all_reduce(local_total_samples, op=dist.ReduceOp.SUM)
597+
# if eval_args.action_predictions:
598+
# dist.all_reduce(local_avion_correct, op=dist.ReduceOp.SUM)
599+
# dist.barrier()
600+
# # Calculate global accuracy after reduction
601+
# local_running_accuracy = local_running_corrects.item() / local_total_samples.item()
602+
# local_avion_accuracy = local_avion_correct.item() / local_total_samples.item()
603+
604+
# logger.info(f'Process {dist.get_rank()} - Running accuracy: {local_running_accuracy:.4f}')
605+
# logger.info(f'Process {dist.get_rank()} - AvionRunning accuracy: {local_avion_accuracy:.4f}')
549606

550-
logger.info(f'Process {dist.get_rank()} - Running accuracy: {running_accuracy:.4f}')
551-
if eval_args.action_predictions:
552-
avaion_accuracy = avaion_correct / total_samples
607+
553608

554609
dist.barrier()
555-
dist.all_reduce(running_corrects, op=dist.ReduceOp.SUM)
556-
dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)
610+
dist.all_reduce(global_running_corrects, op=dist.ReduceOp.SUM)
611+
dist.all_reduce(global_total_samples, op=dist.ReduceOp.SUM)
557612
if eval_args.action_predictions:
558-
dist.all_reduce(avaion_correct, op=dist.ReduceOp.SUM)
613+
dist.all_reduce(global_avion_correct, op=dist.ReduceOp.SUM)
559614

560615
# Calculate global accuracy after reduction
561-
global_accuracy = running_corrects.item() / total_samples.item()
616+
global_accuracy = global_running_corrects.item() / global_total_samples.item()
562617
if eval_args.action_predictions:
563-
global_avaion_accuracy = avaion_correct.item() / total_samples.item()
618+
global_avion_accuracy = global_avion_correct.item() / global_total_samples.item()
564619

565620
# Ensure only the main process (rank 0) prints the final result
566621
if dist.get_rank() == 0:
567622
if eval_args.action_predictions:
568-
logger.info(f'Global Avaion Accuracy: {global_avaion_accuracy:.4f}')
623+
logger.info(f'Global Avion Accuracy: {global_avion_accuracy:.4f}')
569624
logger.info(f'Final Global Accuracy: {global_accuracy:.4f}')
570625

626+
#val_dataset.prediction_analysis.save()
627+
571628
return global_accuracy
572629

573630

action/prediction_analysis.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import json
2+
import glob
3+
4+
class PredictionAnalysis:
5+
"""
6+
We save data that can be used for ad-hoc analysis
7+
8+
We want to save the following:
9+
10+
# saving global index to make distributed code work better
11+
{global_index: {
12+
llava_pred: pred_name,
13+
gt_name: pred_name,
14+
avion_preds: avion_predictions,
15+
# to locate the video clip
16+
dataset_name: '',
17+
start_second: '',
18+
end_second: '',
19+
vid_path: ''
20+
}
21+
"""
22+
def __init__(self, save_path):
23+
self.save_path = save_path
24+
self.data = {}
25+
def log(self,
26+
global_index,
27+
llava_pred,
28+
gt_name,
29+
avion_preds,
30+
start_second,
31+
end_second,
32+
vid_path,
33+
dataset_name = 'EK100',
34+
):
35+
self.data[global_index] = {
36+
'llava_pred': llava_pred,
37+
'gt_name': gt_name,
38+
'avion_preds': avion_preds,
39+
'dataset_name' : dataset_name,
40+
'start_second' : start_second,
41+
'end_second': end_second,
42+
'vid_path': vid_path
43+
}
44+
45+
# print ('check what is here')
46+
# print (self.data[global_index])
47+
48+
def save(self):
49+
with open(self.save_path, 'w') as f:
50+
json.dump(self.data, f, indent = 4)
51+
52+
53+
class Analysis:
54+
"""
55+
56+
This same code should be applied to the training too.
57+
58+
collect all the wrong top-1 prediction from avion
59+
collect all the wrong top-1 prediction from llava
60+
61+
Determine percentage of wrong llava prediction that has wrong verb only
62+
Determine percentage of wrong llava prediction that has wrong noun only
63+
Determine percentage of wrong llava prediciton that has both verb and noun wrong
64+
Determine percentage of wrong llava prediction that was wrong because the answer not in the top k
65+
"""
66+
pass
67+
68+
def __init__(self, prefix):
69+
70+
files = glob.glob(prefix + '*')
71+
72+
self.data = {}
73+
74+
for file in files:
75+
print ('loading pred checkpoint from: ', file)
76+
with open(file, 'r') as f:
77+
_data = json.load(f)
78+
self.data.update(_data)
79+
80+
# add some assertion for number of keys in the data
81+
82+
def wrong_verb(self):
83+
84+
N = len(self.data)
85+
86+
wrong_verb_collections = []
87+
wrong_noun_collections = []
88+
wrong_verb_noun_collections = []
89+
90+
wrong_llava_collections = []
91+
wrong_avion_collections = []
92+
93+
indices = sorted(self.data.keys())
94+
95+
for index in indices:
96+
items = self.data[index]
97+
98+
99+
100+
if __name__ == '__main__':
101+
pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
datasets:
2+
- json_path: /data/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl
3+
sampling_strategy: all

shaokai_generate_train.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# python3 action/generate_description.py \
22
# --train_metadata /data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv \
33
# --out_folder /data/shaokai/EK100_avion_mc/ \
4-
# > train_gen.out 2>&1
4+
# --gen_type avion_mc \
5+
# --n_options 10 \
6+
# > train_gen.out 2>&1
57

68
python3 action/generate_description.py \
79
--train_metadata /storage-rcp-pure/upmwmathis_scratch/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv \
810
--out_folder /storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train \
911
--avion_train_predictions /storage-rcp-pure/upmwmathis_scratch/shaokai/avion_predictions_train.json \
1012
--gen_type avion_mc \
11-
--n_options 3
13+
--n_options 10
1214

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/bin/bash
2+
3+
# Export environment variables
4+
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
5+
export OMP_NUM_THREADS="8"
6+
export NCCL_IB_DISABLE="0"
7+
export NCCL_IB_GID_INDEX="3"
8+
export NCCL_SOCKET_IFNAME="eth0"
9+
export NCCL_DEBUG="INFO"
10+
export ACCELERATE_CPU_AFFINITY="1"
11+
export WANDB_API_KEY="4474ec79de023b0c3ffb43588ab6163264f875db"
12+
export HF_HOME=/data/shaokai
13+
14+
15+
# Run the command using torchrun
16+
torchrun --nproc_per_node=8 \
17+
--nnodes=1 \
18+
--node_rank=0 \
19+
--master_addr=127.0.0.1 \
20+
--master_port=29500 \
21+
llava/train/train_mem.py \
22+
--deepspeed scripts/zero3.json \
23+
--model_name_or_path lmms-lab/LLaVA-Video-7B-Qwen2 \
24+
--version qwen_1_5 \
25+
--data_path scripts/train/EK100_avion_mc_top10.yaml \
26+
--video_folder /data/shaokai/\
27+
--mm_tunable_parts mm_vision_tower,mm_mlp_adapter,mm_language_model \
28+
--mm_vision_tower_lr 2e-6 \
29+
--vision_tower google/siglip-so400m-patch14-384 \
30+
--mm_projector_type mlp2x_gelu \
31+
--mm_vision_select_layer -2 \
32+
--mm_use_im_start_end False \
33+
--mm_use_im_patch_token False \
34+
--group_by_modality_length True \
35+
--image_aspect_ratio anyres_max_9 \
36+
--image_grid_pinpoints "(1x1),...,(6x6)" \
37+
--mm_patch_merge_type spatial_unpad \
38+
--bf16 True \
39+
--run_name shaokai_llava_video_7b_avion_mc_top10_5epochs \
40+
--output_dir experiments/shaokai_llava_video_7b_avion_mc_top10_5epochs \
41+
--num_train_epochs 5 \
42+
--per_device_train_batch_size 2 \
43+
--per_device_eval_batch_size 4 \
44+
--gradient_accumulation_steps 2 \
45+
--evaluation_strategy steps \
46+
--eval_steps 2000\
47+
--save_strategy steps \
48+
--save_steps 1000 \
49+
--learning_rate 1e-5 \
50+
--weight_decay 0. \
51+
--warmup_ratio 0.03 \
52+
--lr_scheduler_type cosine \
53+
--logging_steps 1 \
54+
--tf32 True \
55+
--model_max_length 32768 \
56+
--gradient_checkpointing True \
57+
--dataloader_num_workers 4 \
58+
--lazy_preprocess True \
59+
--report_to wandb \
60+
--torch_compile True \
61+
--torch_compile_backend inductor \
62+
--dataloader_drop_last True \
63+
--frames_upbound 32 \
64+
--root /data/shaokai/EK100 \
65+
--action_predictions /data/shaokai/avion_predictions_test.json \
66+
--val_metadata /data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv \
67+
--llava_num_frames 32 \
68+
--clip_length 32 \
69+
--topk_predictions 10 > train_llavavideo_kitchen_7b_avion_mc_32f_top10_5epochs.out 2>&1

0 commit comments

Comments
 (0)