
Commit 9eb7e9e

Author: Ye Shaokai
Commit message: multi choice training
1 parent 74c835e, commit 9eb7e9e

File tree: 9 files changed (+173, -121 lines)


action/ek_eval.py

Lines changed: 4 additions & 38 deletions
@@ -18,7 +18,7 @@
 import json
 import logging
 from llava.utils import rank0_print
-
+from action.utils import generate_label_map

 def datetime2sec(str):
     hh, mm, ss = str.split(':')
@@ -453,8 +453,7 @@ def __getitem__(self, i):

         # randomly sample topk actions from valid gts

-        wrong_answer_indices = np.random.choice(len(self.valid_gts), size = self.eval_args.topk_predictions, replace = False)
-
+        wrong_answer_indices = np.random.choice(len(self.valid_gts), size = self.eval_args.topk_predictions, replace = False)
         wrong_answers = [self.valid_gts[index] for index in wrong_answer_indices]

         for i in range(len(wrong_answers)):
@@ -501,37 +500,6 @@ def get_downstream_dataset(transform, crop_size, eval_args, subset='train', labe
         assert ValueError("subset should be either 'train' or 'val'")


-def generate_label_map(eval_args):
-    print("Preprocess ek100 action label space")
-    vn_list = []
-    mapping_vn2narration = {}
-    anno_root = Path(eval_args.val_metadata).parent
-    for f in [
-        anno_root / 'EPIC_100_train.csv',
-        anno_root / 'EPIC_100_validation.csv',
-    ]:
-        csv_reader = csv.reader(open(f))
-        _ = next(csv_reader) # skip the header
-        for row in csv_reader:
-
-            vn = '{}:{}'.format(int(row[10]), int(row[12]))
-            narration = row[8]
-            if vn not in vn_list:
-                vn_list.append(vn)
-            if vn not in mapping_vn2narration:
-                mapping_vn2narration[vn] = [narration]
-            else:
-                mapping_vn2narration[vn].append(narration)
-            # mapping_vn2narration[vn] = [narration]
-    vn_list = sorted(vn_list)
-    print('# of action= {}'.format(len(vn_list)))
-    mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
-
-    labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
-    print(labels[:5])
-    return labels, mapping_vn2act
-
-
 def get_args_parser():
     parser = argparse.ArgumentParser(description='AVION finetune ek100 cls', add_help=False)
     parser.add_argument('--dataset', default='ek100_cls', type=str, choices=['ek100_mir'])
@@ -596,9 +564,7 @@ def get_topk_predictions(data, idx, k):
         'option': {0: options}
     }

-    return mc_data, predictions, target
-
-
+    return mc_data, predictions, target

 def evaluate_on_EK100(eval_args, model= None, tokenizer= None, max_length= None, image_processor= None):

@@ -611,7 +577,7 @@ def evaluate_on_EK100(eval_args, model= None, tokenizer= None, max_length= None,

     crop_size = 336

-    labels, mapping_vn2act = generate_label_map(eval_args)
+    labels, mapping_vn2act, _, _ = generate_label_map(Path(eval_args.val_metadata).parent)
     val_dataset = get_downstream_dataset(
         val_transform_gpu, crop_size, eval_args, subset='val', label_mapping=mapping_vn2act,
         labels = labels
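
For context, a minimal self-contained sketch of the distractor-sampling pattern used in `__getitem__` above. The `valid_gts` entries and the `topk_predictions` value are made up for illustration; only the `np.random.choice` call mirrors the diff.

```python
import numpy as np

# Illustrative stand-ins; the real dataset draws from self.valid_gts and
# self.eval_args.topk_predictions.
valid_gts = ["open door", "cut onion", "wash hands", "pour water", "close fridge"]
topk_predictions = 3

# Sample distinct indices (replace=False), then look up the distractor actions.
wrong_answer_indices = np.random.choice(len(valid_gts), size=topk_predictions, replace=False)
wrong_answers = [valid_gts[index] for index in wrong_answer_indices]
print(wrong_answers)
```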

action/generate_description.py

Lines changed: 51 additions & 74 deletions
@@ -1,84 +1,46 @@
 import json
 import csv
 import os
+import argparse
+from action.utils import generate_label_map, MultiChoiceGenerator
+from pathlib import Path
+
+
+GEN_TYPES = ['naive', 'random_mc', 'avion_mc']

 def datetime2sec(str):
     hh, mm, ss = str.split(':')
     return int(hh) * 3600 + int(mm) * 60 + float(ss)

-def generate_label_map(dataset):
-    if dataset == 'ek100_cls':
-        print("Preprocess ek100 action label space")
-        vn_list = []
-        mapping_vn2narration = {}
-        verb_ids = {}
-        noun_ids = {}
-        for f in [
-            '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv',
-            '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv',
-        ]:
-            csv_reader = csv.reader(open(f))
-            _ = next(csv_reader) # skip the header
-            for row in csv_reader:
-                vn = '{}:{}'.format(int(row[10]), int(row[12]))
-                narration = row[8]
-                if row[10] not in verb_ids.keys():
-                    verb_ids[row[10]] = row[9]
-                if row[12] not in noun_ids.keys():
-                    noun_ids[row[12]] = row[11]
-                if vn not in vn_list:
-                    vn_list.append(vn)
-                if vn not in mapping_vn2narration:
-                    mapping_vn2narration[vn] = [narration]
-                else:
-                    mapping_vn2narration[vn].append(narration)
-                # mapping_vn2narration[vn] = [narration]
-        vn_list = sorted(vn_list)
-        print('# of action= {}'.format(len(vn_list)))
-        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
-        labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
-        print(labels[:5])
-    elif dataset == 'charades_ego':
-        print("=> preprocessing charades_ego action label space")
-        vn_list = []
-        labels = []
-        with open('datasets/CharadesEgo/CharadesEgo/Charades_v1_classes.txt') as f:
-            csv_reader = csv.reader(f)
-            for row in csv_reader:
-                vn = row[0][:4]
-                vn_list.append(vn)
-                narration = row[0][5:]
-                labels.append(narration)
-        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
-        print(labels[:5])
-    elif dataset == 'egtea':
-        print("=> preprocessing egtea action label space")
-        labels = []
-        with open('datasets/EGTEA/action_idx.txt') as f:
-            for row in f:
-                row = row.strip()
-                narration = ' '.join(row.split(' ')[:-1])
-                labels.append(narration.replace('_', ' ').lower())
-                # labels.append(narration)
-        mapping_vn2act = {label: i for i, label in enumerate(labels)}
-        print(len(labels), labels[:5])
-    else:
-        raise NotImplementedError
-    return labels, mapping_vn2act, verb_ids, noun_ids
-
-
-def parse_train_ann(ann_file, verb_ids, noun_ids):
+def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive'):
+    assert gen_type in GEN_TYPES
     # epic kitchen uses csv
     csv_reader = csv.reader(open(ann_file))
     _ = next(csv_reader)
     ret = []
+    ann_root = Path(ann_file).parent
+    if gen_type == "random_mc":
+        mc_generator = MultiChoiceGenerator(ann_root)
+
     for row in csv_reader:
-        # start_frame, end_frame = row[6], row[7]
         start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
-        narration = f'{verb_ids[row[10]]} {noun_ids[row[12]]}'
+
         pid, vid = row[1:3]
-        vid_path = '{}-{}'.format(pid, vid)
-        conversation = generate_naive_conversation(narration)
+        vid_path = '{}-{}'.format(pid, vid)
+
+        if gen_type == 'naive':
+            # here we directly use the names
+            verb_noun = f'{verb_ids[row[10]]} {noun_ids[row[12]]}'
+            conversation = generate_naive_conversation(verb_noun)
+        elif gen_type == "random_mc":
+            # here we use the index
+            vn_str = f'{row[10]}:{row[12]}'
+            mc_data = mc_generator.generate_multi_choice(vn_str, 5)
+            options = mc_data['option'][0]
+            gt_answer_letter = mc_data['gt_answer_letter'][0]
+            gt_answer_name = mc_data['gt_answer_name'][0]
+            conversation = generate_random_mc_conversation(options, gt_answer_letter, gt_answer_name )
+
         data = {'video': vid_path,
                 'conversations': conversation,
                 'id': vid_path,
@@ -92,19 +54,35 @@ def parse_train_ann(ann_file, verb_ids, noun_ids):
         ret.append(data)
     return ret

-def generate_naive_conversation(narration):
+def generate_naive_conversation(vn_str:str):
     # in this version, we do not care about diversifying the questions
     return [
         {"from": "human", "value": "<image>\n the video is taken from egocentric view. What action is the person performing? Hint: provide your answer in verb-noun pair. "},
-        {"from": "gpt", "value": f"{narration}"}
+        {"from": "gpt", "value": f"{vn_str}"}
     ]

-def main():
+def generate_random_mc_conversation(options:list[str], gt_answer_letter, gt_answer_name):
+    return [
+        {"from": "human", "value": f"<image>\n the video is taken from egocentric view. What action is the person performing? Please select the letter for the right answer {options}"},
+        {"from": "gpt", "value": f"{gt_answer_letter}. {gt_answer_name}"}
+    ]
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="For generating VQA for EPIC-KITCHEN")
+    parser.add_argument('--train_metadata', default='/data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv', type=str)
+    parser.add_argument('--out_folder', default = '/data/shaokai/EK100_in_LLAVA/', type = str)
+    return parser.parse_args()
+
+def main():
+    args = get_args()
+    ann_file = args.train_metadata
+    inst_train_folder = args.out_folder
+    print (ann_file)
+    anno_path = Path(ann_file).parent
+    labels, mapping_vn2act, verb_ids, noun_ids = generate_label_map(anno_path)
+    conv_lst = generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'random_mc')

-    ann_file = "/data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv"
-    labels, mapping_vn2act, verb_ids, noun_ids = generate_label_map('ek100_cls')
-    conv_lst = parse_train_ann(ann_file, verb_ids, noun_ids)
-    inst_train_folder = '/data/shaokai/EK100_in_LLAVA/'
     os.makedirs(inst_train_folder, exist_ok=True)

     # save it to a jsonl
@@ -113,6 +91,5 @@ def main():
         f.write(json.dumps(conv) + '\n')


-
 if __name__ == "__main__":
     main()
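
For reference, a minimal sketch of the conversation record this script now emits per clip when `gen_type='random_mc'`. The options, answer, and `video` id below are invented for illustration; the real values come from `MultiChoiceGenerator` and the EPIC-KITCHENS annotations, and the output file is presumably the train_convs_narration.jsonl referenced in scripts/train/EK100.yaml.

```python
import json

# Hypothetical multiple-choice data standing in for MultiChoiceGenerator output.
options = ["A. open drawer", "B. cut onion", "C. wash plate", "D. pour water", "E. close fridge"]
gt_answer_letter, gt_answer_name = "B", "cut onion"

# Mirrors generate_random_mc_conversation: one human turn with the options,
# one gpt turn with "<letter>. <answer>".
conversation = [
    {"from": "human",
     "value": f"<image>\n the video is taken from egocentric view. What action is the person performing? "
              f"Please select the letter for the right answer {options}"},
    {"from": "gpt", "value": f"{gt_answer_letter}. {gt_answer_name}"},
]

# One illustrative jsonl line (id made up; other fields of the real record are omitted here).
print(json.dumps({"video": "P01-P01_01", "conversations": conversation, "id": "P01-P01_01"}))
```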

action/utils.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+import csv
+import numpy as np
+import random
+import os
+
+def generate_label_map(anno_root):
+    print("Preprocess ek100 action label space")
+    vn_list = []
+    mapping_vn2narration = {}
+    # from id to name
+    verb_maps = {}
+    noun_maps = {}
+    for f in [
+        os.path.join(anno_root,'EPIC_100_train.csv'),
+        os.path.join(anno_root, 'EPIC_100_validation.csv'),
+    ]:
+        csv_reader = csv.reader(open(f))
+        _ = next(csv_reader) # skip the header
+        for row in csv_reader:
+
+            vn = '{}:{}'.format(int(row[10]), int(row[12]))
+            narration = row[8]
+            if row[10] not in verb_maps.keys():
+                verb_maps[row[10]] = row[9]
+            if row[12] not in noun_maps.keys():
+                noun_maps[row[12]] = row[11]
+
+            if vn not in vn_list:
+                vn_list.append(vn)
+            if vn not in mapping_vn2narration:
+                mapping_vn2narration[vn] = [narration]
+            else:
+                mapping_vn2narration[vn].append(narration)
+            # mapping_vn2narration[vn] = [narration]
+    vn_list = sorted(vn_list)
+    print('# of action= {}'.format(len(vn_list)))
+    mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
+
+    labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
+    return labels, mapping_vn2act, verb_maps, noun_maps
+
+
+
+class MultiChoiceGenerator:
+    """
+    Generating multi choice
+    """
+    def __init__(self, ann_root):
+        self.ann_root = ann_root
+        _, self.mapping_vn2act, self.verb_maps, self.noun_maps = generate_label_map(ann_root)
+
+    def generate_multi_choice(self, gt_vn, k):
+        """
+        Generate k multiple choices from gt_vn pairs
+
+        randomly pick 1 letter for gt_vn
+        randomly pick k-1 letters from vn_list
+
+        """
+
+        # let v_id and n_id be string type
+        gt_v_id, gt_n_id = gt_vn.split(':')
+        assert isinstance(gt_v_id, str) and isinstance(gt_n_id, str)
+        gt_v_name, gt_n_name = self.verb_maps[gt_v_id], self.noun_maps[gt_n_id]
+
+        # letters as A, B, C, D, .. Note we maximally support 26 letters
+        letters = [chr(65+i) for i in range(26)][:k]
+        options = list(range(26))[:k]
+        vn_list = list(self.mapping_vn2act.keys())
+        action_list = [f"{self.verb_maps[e.split(':')[0]]} {self.noun_maps[e.split(':')[1]]}" for e in vn_list]
+        wrong_answers = np.random.choice(action_list, size = k-1, replace = False)
+        gt_answer = f'{gt_v_name} {gt_n_name}'
+
+        answers = [gt_answer] + list(wrong_answers)
+        random.shuffle(answers)
+
+        options = []
+        for answer, letter in zip(answers, letters):
+            options.append(f'{letter}. {answer}')
+
+        gt_letter = letters[answers.index(gt_answer)]
+        data = {
+            'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
+            'option': {0: options},
+            # the correct letter in mc
+            # for inspecting
+            'gt_answer_letter': {0: gt_letter},
+            'gt_answer_name': {0: gt_answer}
+        }
+
+        return data
+
+
+if __name__ == '__main__':
+
+    anno_root = "/storage-rcp-pure/upmwmathis_scratch/shaokai/epic-kitchens-100-annotations/"
+    generator = MultiChoiceGenerator(anno_root)
+
+    print (generator.generate_multi_choice('3:3',5))
+
+    pass
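
To make the option-construction scheme in `generate_multi_choice` easier to follow, here is a self-contained sketch of just the shuffle-and-letter step. The action names are invented; the real code draws its distractors from the EPIC-KITCHENS verb/noun maps via `np.random.choice(..., replace=False)`.

```python
import random

k = 5
letters = [chr(65 + i) for i in range(26)][:k]  # ['A', 'B', 'C', 'D', 'E']

# Invented ground truth and distractors standing in for the sampled actions.
gt_answer = "cut onion"
wrong_answers = ["open drawer", "wash plate", "pour water", "close fridge"]

# Mix the ground truth into the distractors, shuffle, and label each answer.
answers = [gt_answer] + wrong_answers
random.shuffle(answers)
options = [f"{letter}. {answer}" for letter, answer in zip(letters, answers)]

# Recover the ground-truth letter by locating the answer after the shuffle.
gt_letter = letters[answers.index(gt_answer)]
print(options)
print("ground truth:", f"{gt_letter}. {gt_answer}")
```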

llava/model/builder.py

Lines changed: 2 additions & 2 deletions
@@ -216,9 +216,9 @@ def load_from_hf(repo_id, filename, subfolder=None):
        else:
            from llava.model.language_model.llava_qwen import LlavaQwenConfig

-            if overwrite_config is not None:
+            #if overwrite_config is not None:
+            if True:
                llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
-                rank0_print(f"Overwriting config with {overwrite_config}")
                for k, v in overwrite_config.items():
                    setattr(llava_cfg, k, v)
                model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)

llava/model/language_model/llava_qwen.py

Lines changed: 1 addition & 3 deletions
@@ -47,9 +47,7 @@ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
     config_class = LlavaQwenConfig

     def __init__(self, config):
-        # super(Qwen2ForCausalLM, self).__init__(config)
-        print ('what does config look like')
-        print (config)
+        # super(Qwen2ForCausalLM, self).__init__(config)
         Qwen2ForCausalLM.__init__(self, config)

         config.model_type = "llava_qwen"

llava/train/train.py

Lines changed: 3 additions & 1 deletion
@@ -1733,7 +1733,9 @@ def make_inputs_require_grad(module, input, output):


     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
-        trainer.train(resume_from_checkpoint=True)
+        #trainer.train(resume_from_checkpoint=True)
+        # for debug purpose, let's not resume
+        trainer.train()
     else:
         trainer.train()
     trainer.save_state()

scripts/train/EK100.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 datasets:
-  - json_path: /data/shaokai/EK100_in_LLAVA/train_convs_narration.jsonl
+  - json_path: /data/shaokai/EK100_avion_mc/train_convs_narration.jsonl
     sampling_strategy: all

shaokai_generate_train.sh

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# python3 action/generate_description.py \
+# --train_metadata /data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv \
+# --out_folder /data/shaokai/EK100_avion_mc/ \
+# > train_gen.out 2>&1
+
+python3 action/generate_description.py \
+    --train_metadata /storage-rcp-pure/upmwmathis_scratch/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv \
+    --out_folder /storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_avion_mc # > train_gen.out 2>&1
