Skip to content

Commit 616a056

Browse files
author
Ye Shaokai
committed
even cleaner code
1 parent a358be2 commit 616a056

File tree

9 files changed

+197
-33
lines changed

9 files changed

+197
-33
lines changed

action/ek_eval.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import json
1919
import logging
2020
from llava.utils import rank0_print
21-
from action.utils import generate_label_map, MultiChoiceGenerator, match_answer
21+
from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions
2222

2323
def datetime2sec(str):
2424
hh, mm, ss = str.split(':')
@@ -377,17 +377,10 @@ def get_topk_predictions(data, idx, k):
377377
options = list(range(26))[:k]
378378

379379
predictions = data[str(idx)]['predictions'][:k]
380-
new_predictions = []
381-
for pred in predictions:
382-
# the prediction looks like verb:noun1:noun2..
383-
# we want to it look like verb noun1:noun2
384-
first_sep = pred.index(':')
385-
prediction = pred[:first_sep] + ' ' + pred[first_sep+1:]
386-
new_predictions.append(prediction)
387-
388-
predictions = new_predictions
389-
for i in range(len(options)):
390-
380+
381+
predictions = parse_avion_predictions(predictions)
382+
383+
for i in range(len(options)):
391384
options[i] = f'{letters[i]}. {predictions[i]}'
392385

393386
mc_data = {

action/generate_description.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import csv
33
import os
44
import argparse
5-
from action.utils import generate_label_map, MultiChoiceGenerator
5+
from action.utils import generate_label_map, MultiChoiceGenerator, AvionMultiChoiceGenerator
66
from pathlib import Path
77

88

@@ -12,7 +12,7 @@ def datetime2sec(str):
1212
hh, mm, ss = str.split(':')
1313
return int(hh) * 3600 + int(mm) * 60 + float(ss)
1414

15-
def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive'):
15+
def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive', avion_prediction_path = ''):
1616
assert gen_type in GEN_TYPES
1717
# epic kitchen uses csv
1818
csv_reader = csv.reader(open(ann_file))
@@ -21,8 +21,12 @@ def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive'):
2121
ann_root = Path(ann_file).parent
2222
if gen_type == "random_mc":
2323
mc_generator = MultiChoiceGenerator(ann_root)
24+
elif gen_type == 'avion_mc':
25+
mc_generator = AvionMultiChoiceGenerator(ann_root)
26+
with open(avion_prediction_path, 'r') as f:
27+
avion_train_predictions = json.load(f)
2428

25-
for row in csv_reader:
29+
for idx, row in enumerate(csv_reader):
2630
start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
2731

2832
pid, vid = row[1:3]
@@ -40,6 +44,14 @@ def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive'):
4044
gt_answer_letter = mc_data['gt_answer_letter'][0]
4145
gt_answer_name = mc_data['gt_answer_name'][0]
4246
conversation = generate_random_mc_conversation(options, gt_answer_letter, gt_answer_name )
47+
elif gen_type == "avion_mc":
48+
vn_str = f'{row[10]}:{row[12]}'
49+
avion_preds = avion_train_predictions[str(idx)]['predictions']
50+
mc_data = mc_generator.generate_multi_choice(vn_str, avion_preds, 5)
51+
options = mc_data['option'][0]
52+
gt_answer_letter = mc_data['gt_answer_letter'][0]
53+
gt_answer_name = mc_data['gt_answer_name'][0]
54+
conversation = generate_random_mc_conversation(options, gt_answer_letter, gt_answer_name )
4355

4456
data = {'video': vid_path,
4557
'conversations': conversation,
@@ -67,24 +79,38 @@ def generate_random_mc_conversation(options:list[str], gt_answer_letter, gt_answ
6779
{"from": "gpt", "value": f"{gt_answer_letter}. {gt_answer_name}"}
6880
]
6981

82+
def generate_avion_mc_conversation():
83+
pass
84+
7085

7186
def get_args():
7287
parser = argparse.ArgumentParser(description="For generating VQA for EPIC-KITCHEN")
7388
parser.add_argument('--train_metadata', default='/data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv', type=str)
7489
parser.add_argument('--out_folder', default = '/data/shaokai/EK100_in_LLAVA/', type = str)
90+
parser.add_argument('--avion_train_predictions', default = '/data/shaokai/avion_predictions_train.json', type = str)
91+
parser.add_argument('--gen_type', default = 'avion_mc', type = str, choices = GEN_TYPES)
7592
return parser.parse_args()
7693

77-
def main():
94+
def main():
7895
args = get_args()
7996
ann_file = args.train_metadata
80-
inst_train_folder = args.out_folder
81-
print (ann_file)
82-
anno_path = Path(ann_file).parent
83-
labels, mapping_vn2act, verb_ids, noun_ids = generate_label_map(anno_path)
84-
conv_lst = generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'random_mc')
85-
97+
inst_train_folder = os.path.join(args.out_folder, args.gen_type)
98+
99+
print ('train_metadata', args.train_metadata)
100+
print ('out_folder', args.out_folder)
101+
print ('loading predictions from ', args.avion_train_predictions)
102+
print ('gen_type is ', args.gen_type)
103+
86104
os.makedirs(inst_train_folder, exist_ok=True)
87105

106+
anno_path = Path(ann_file).parent
107+
_, _, verb_ids, noun_ids = generate_label_map(anno_path)
108+
conv_lst = generate_train_ann(ann_file,
109+
verb_ids,
110+
noun_ids,
111+
gen_type = args.gen_type,
112+
avion_prediction_path = args.avion_train_predictions)
113+
88114
# save it to a jsonl
89115
with open(os.path.join(inst_train_folder,'train_convs_narration.jsonl'), 'w') as f:
90116
for conv in conv_lst:

action/utils.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,25 @@ def match_answer(pred, gt):
4848
return pred.intersection(gt) == gt
4949

5050

51+
def parse_avion_predictions(predictions):
52+
new_predictions = []
53+
for pred in predictions:
54+
# the prediction looks like verb:noun1:noun2..
55+
# we want to it look like verb noun1:noun2
56+
first_sep = pred.index(':')
57+
prediction = pred[:first_sep] + ' ' + pred[first_sep+1:]
58+
new_predictions.append(prediction)
59+
return new_predictions
60+
61+
5162
class MultiChoiceGenerator:
5263
"""
5364
Generating multi choice
5465
"""
5566
def __init__(self, ann_root):
5667
self.ann_root = ann_root
5768
_, self.mapping_vn2act, self.verb_maps, self.noun_maps = generate_label_map(ann_root)
69+
5870

5971
def generate_multi_choice(self, gt_vn, k):
6072
"""
@@ -98,12 +110,73 @@ def generate_multi_choice(self, gt_vn, k):
98110

99111
return data
100112

113+
class AvionMultiChoiceGenerator(MultiChoiceGenerator):
114+
"""
115+
Generate multichoice using avion predictions
116+
"""
117+
def __init__(self, ann_root):
118+
super().__init__(ann_root)
119+
120+
def generate_multi_choice(self, gt_vn, avion_predictions, k):
121+
"""
122+
Generate k multiple choices from gt_vn pairs
123+
124+
randomly pick 1 letter for gt_vn
125+
randomly pick k-1 letters from vn_list that is not gt_vn (this is important as avion_predictions can contain correct prediction)
126+
127+
"""
128+
gt_v_id, gt_n_id = gt_vn.split(':')
129+
gt_v_name, gt_n_name = self.verb_maps[gt_v_id], self.noun_maps[gt_n_id]
130+
gt_answer = f'{gt_v_name} {gt_n_name}'
131+
132+
letters = [chr(65+i) for i in range(26)][:k]
133+
options = list(range(26))[:k]
134+
135+
# we should have plenty of predictions to select, so let's not always pick the hardest
136+
assert len(avion_predictions) > 2*k
137+
avion_predictions = avion_predictions[:k*2]
138+
avion_predictions = parse_avion_predictions(avion_predictions)
139+
if gt_answer in avion_predictions:
140+
avion_predictions.remove(gt_answer)
141+
# just so that it's not strictly desending with confidence
142+
random.shuffle(avion_predictions)
143+
avion_predictions = avion_predictions[:k-1]
144+
145+
answers = [gt_answer] + avion_predictions
146+
random.shuffle(answers)
147+
148+
options = []
149+
for answer, letter in zip(answers, letters):
150+
options.append(f'{letter}. {answer}')
151+
152+
gt_letter = letters[answers.index(gt_answer)]
153+
154+
data = {
155+
'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
156+
'option': {0: options},
157+
# the correct letter in mc
158+
# for inspecting
159+
'gt_answer_letter': {0: gt_letter},
160+
'gt_answer_name': {0: gt_answer},
161+
'valid_letters': letters
162+
}
163+
return data
164+
101165

102166
if __name__ == '__main__':
103167

104168
anno_root = "/storage-rcp-pure/upmwmathis_scratch/shaokai/epic-kitchens-100-annotations/"
105-
generator = MultiChoiceGenerator(anno_root)
169+
#generator = MultiChoiceGenerator(anno_root)
170+
generator = AvionMultiChoiceGenerator(anno_root)
171+
import json
172+
173+
with open('/storage-rcp-pure/upmwmathis_scratch/shaokai/avion_predictions_train.json') as f:
174+
predictions = json.load(f)
175+
176+
print (len(predictions))
177+
print (predictions['0'])
178+
print (len(predictions['0']['predictions']))
106179

107-
print (generator.generate_multi_choice('3:3',5))
180+
print (generator.generate_multi_choice('3:3', predictions['0']['predictions'], 5))
108181

109182
pass

scripts/train/EK100.yaml

Lines changed: 0 additions & 3 deletions
This file was deleted.

scripts/train/EK100_avion_mc.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
datasets:
2+
- json_path: /data/shaokai/EK100_inst_train/avion_mc/train_convs_narration.jsonl
3+
sampling_strategy: all

scripts/train/EK100_random_mc.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
datasets:
2+
- json_path: /data/shaokai/EK100_inst_train/random_mc/train_convs_narration.jsonl
3+
sampling_strategy: all

shaokai_generate_train.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,7 @@
55

66
python3 action/generate_description.py \
77
--train_metadata /storage-rcp-pure/upmwmathis_scratch/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv \
8-
--out_folder /storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_avion_mc # > train_gen.out 2>&1
8+
--out_folder /storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train \
9+
--avion_train_predictions /storage-rcp-pure/upmwmathis_scratch/shaokai/avion_predictions_train.json \
10+
--gen_type avion_mc \
11+

shaokai_train_avion_mc.sh

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#!/bin/bash
2+
3+
# Export environment variables
4+
export CUDA_VISIBLE_DEVICES="0,1,2,3"
5+
export OMP_NUM_THREADS="8"
6+
export NCCL_IB_DISABLE="0"
7+
export NCCL_IB_GID_INDEX="3"
8+
export NCCL_SOCKET_IFNAME="eth0"
9+
export NCCL_DEBUG="INFO"
10+
export ACCELERATE_CPU_AFFINITY="1"
11+
export WANDB_API_KEY="4474ec79de023b0c3ffb43588ab6163264f875db"
12+
export HF_HOME=/data/shaokai
13+
14+
15+
# Run the command using torchrun
16+
torchrun --nproc_per_node=4 \
17+
--nnodes=1 \
18+
--node_rank=0 \
19+
--master_addr=127.0.0.1 \
20+
--master_port=29500 \
21+
llava/train/train_mem.py \
22+
--deepspeed scripts/zero3.json \
23+
--model_name_or_path lmms-lab/llava-onevision-qwen2-0.5b-ov \
24+
--version qwen_1_5 \
25+
--data_path scripts/train/EK100_avion_mc.yaml \
26+
--video_folder /data/shaokai/\
27+
--mm_tunable_parts mm_vision_tower,mm_mlp_adapter,mm_language_model \
28+
--mm_vision_tower_lr 2e-6 \
29+
--vision_tower google/siglip-so400m-patch14-384 \
30+
--mm_projector_type mlp2x_gelu \
31+
--mm_vision_select_layer -2 \
32+
--mm_use_im_start_end False \
33+
--mm_use_im_patch_token False \
34+
--group_by_modality_length True \
35+
--image_aspect_ratio anyres_max_9 \
36+
--image_grid_pinpoints "(1x1),...,(6x6)" \
37+
--mm_patch_merge_type spatial_unpad \
38+
--bf16 True \
39+
--run_name shaokai_llama_ov_0.5b_avion_mc \
40+
--output_dir experiments/shaokai_llama_ov_0.5b_avion_mc \
41+
--num_train_epochs 1 \
42+
--per_device_train_batch_size 1 \
43+
--per_device_eval_batch_size 4 \
44+
--gradient_accumulation_steps 2 \
45+
--evaluation_strategy steps \
46+
--eval_steps 500\
47+
--save_strategy steps \
48+
--save_steps 1000 \
49+
--learning_rate 1e-5 \
50+
--weight_decay 0. \
51+
--warmup_ratio 0.03 \
52+
--lr_scheduler_type cosine \
53+
--logging_steps 1 \
54+
--tf32 True \
55+
--model_max_length 32768 \
56+
--gradient_checkpointing True \
57+
--dataloader_num_workers 4 \
58+
--lazy_preprocess True \
59+
--report_to wandb \
60+
--torch_compile True \
61+
--torch_compile_backend inductor \
62+
--dataloader_drop_last True \
63+
--frames_upbound 32 \
64+
--root /data/shaokai/EK100 \
65+
--action_predictions /data/shaokai/avaion_predictions_test.json \
66+
--val_metadata /data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv \
67+
--llava_num_frames 16 \
68+
--topk_predictions 5 > train_kitchen_0.5b_avion_mc.out 2>&1
Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ export NCCL_IB_GID_INDEX="3"
88
export NCCL_SOCKET_IFNAME="eth0"
99
export NCCL_DEBUG="INFO"
1010
export ACCELERATE_CPU_AFFINITY="1"
11-
# export LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libffi.so.7"
1211
export WANDB_API_KEY="4474ec79de023b0c3ffb43588ab6163264f875db"
13-
experiment_name="shaokai_llama_ov_0.5b_debug"
1412
export HF_HOME=/data/shaokai
1513

1614

@@ -24,7 +22,7 @@ torchrun --nproc_per_node=4 \
2422
--deepspeed scripts/zero3.json \
2523
--model_name_or_path lmms-lab/llava-onevision-qwen2-0.5b-ov \
2624
--version qwen_1_5 \
27-
--data_path scripts/train/EK100.yaml \
25+
--data_path scripts/train/EK100_random_mc.yaml \
2826
--video_folder /data/shaokai/\
2927
--mm_tunable_parts mm_vision_tower,mm_mlp_adapter,mm_language_model \
3028
--mm_vision_tower_lr 2e-6 \
@@ -64,7 +62,7 @@ torchrun --nproc_per_node=4 \
6462
--dataloader_drop_last True \
6563
--frames_upbound 32 \
6664
--root /data/shaokai/EK100 \
67-
--action_predictions /data/shaokai/avaion_predictions.json \
65+
--action_predictions /data/shaokai/avaion_predictions_test.json \
6866
--val_metadata /data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv \
6967
--llava_num_frames 16 \
70-
--topk_predictions 5 > train_kitchen_0.5b.out 2>&1
68+
--topk_predictions 5 > train_kitchen_0.5b_random_mc.out 2>&1

0 commit comments

Comments
 (0)