Skip to content

Commit dc9f524

Browse files
author
Ye Shaokai
committed
updates
1 parent fabd803 commit dc9f524

File tree

5 files changed

+16
-11
lines changed

5 files changed

+16
-11
lines changed

action/ek_eval.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -390,7 +390,10 @@ def get_topk_predictions(data, idx, k):
390390

391391
return mc_data, predictions[0]
392392

393-
def evaluate_on_EK100(eval_args, model= None, tokenizer= None, max_length= None, image_processor= None):
393+
def evaluate_on_EK100(eval_args,
394+
model= None,
395+
tokenizer= None,
396+
image_processor= None):
394397

395398
if image_processor is None:
396399
image_processor = model.get_vision_tower().image_processor
@@ -474,7 +477,7 @@ def evaluate_on_EK100(eval_args, model= None, tokenizer= None, max_length= None,
474477
if finish_early and idx>999:
475478
break
476479

477-
pred = llava_inference(frames, tokenizer, model, image_processor, max_length, mc_data, clip_length = eval_args.clip_length, num_frames=eval_args.llava_num_frames)
480+
pred = llava_inference(frames, tokenizer, model, image_processor, mc_data, clip_length = eval_args.clip_length, num_frames=eval_args.llava_num_frames)
478481

479482
# if valid letter is found in the prediction, then we will use that as the prediction
480483
rank0_print ('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)

action/generate_description.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ def datetime2sec(str):
1212
hh, mm, ss = str.split(':')
1313
return int(hh) * 3600 + int(mm) * 60 + float(ss)
1414

15-
def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive', avion_prediction_path = ''):
15+
def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive', avion_prediction_path = '', n_options = 5):
1616
assert gen_type in GEN_TYPES
1717
# epic kitchen uses csv
1818
csv_reader = csv.reader(open(ann_file))
@@ -39,15 +39,15 @@ def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive', avion_p
3939
elif gen_type == "random_mc":
4040
# here we use the index
4141
vn_str = f'{row[10]}:{row[12]}'
42-
mc_data = mc_generator.generate_multi_choice(vn_str, 5)
42+
mc_data = mc_generator.generate_multi_choice(vn_str, n_options)
4343
options = mc_data['option'][0]
4444
gt_answer_letter = mc_data['gt_answer_letter'][0]
4545
gt_answer_name = mc_data['gt_answer_name'][0]
4646
conversation = generate_random_mc_conversation(options, gt_answer_letter, gt_answer_name )
4747
elif gen_type == "avion_mc":
4848
vn_str = f'{row[10]}:{row[12]}'
4949
avion_preds = avion_train_predictions[str(idx)]['predictions']
50-
mc_data = mc_generator.generate_multi_choice(vn_str, avion_preds, 5)
50+
mc_data = mc_generator.generate_multi_choice(vn_str, avion_preds, n_options)
5151
options = mc_data['option'][0]
5252
gt_answer_letter = mc_data['gt_answer_letter'][0]
5353
gt_answer_name = mc_data['gt_answer_name'][0]
@@ -86,27 +86,30 @@ def get_args():
8686
parser.add_argument('--out_folder', default = '/data/shaokai/EK100_in_LLAVA/', type = str)
8787
parser.add_argument('--avion_train_predictions', default = '/data/shaokai/avion_predictions_train.json', type = str)
8888
parser.add_argument('--gen_type', default = 'avion_mc', type = str, choices = GEN_TYPES)
89+
parser.add_argument('--n_options', default = 5, type = int)
8990
return parser.parse_args()
9091

9192
def main():
9293
args = get_args()
9394
ann_file = args.train_metadata
94-
inst_train_folder = os.path.join(args.out_folder, args.gen_type)
95+
inst_train_folder = os.path.join(args.out_folder, f'{args.gen_type}_top{args.n_options}')
9596

9697
print ('train_metadata', args.train_metadata)
9798
print ('out_folder', args.out_folder)
9899
print ('loading predictions from ', args.avion_train_predictions)
99100
print ('gen_type is ', args.gen_type)
101+
print ('n_options', args.n_options)
100102

101-
os.makedirs(inst_train_folder, exist_ok=True)
103+
os.makedirs(inst_train_folder, exist_ok=True)
102104

103105
anno_path = Path(ann_file).parent
104106
_, _, verb_ids, noun_ids = generate_label_map(anno_path)
105107
conv_lst = generate_train_ann(ann_file,
106108
verb_ids,
107109
noun_ids,
108110
gen_type = args.gen_type,
109-
avion_prediction_path = args.avion_train_predictions)
111+
avion_prediction_path = args.avion_train_predictions,
112+
n_options = args.n_options)
110113

111114
# save it to a jsonl
112115
with open(os.path.join(inst_train_folder,'train_convs_narration.jsonl'), 'w') as f:

action/llava_ov_inference.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@ def llava_inference(video_frames,
1818
tokenizer,
1919
model,
2020
image_processor,
21-
max_length,
2221
mc_data,
2322
clip_length = 16,
2423
num_frames=16):
@@ -29,7 +28,6 @@ def llava_inference(video_frames,
2928
temporal_stride = clip_length // num_frames
3029
video_frames = video_frames[::temporal_stride]
3130
image_tensors = []
32-
#frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().cuda()
3331
frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].cuda().to(torch.bfloat16)
3432
image_tensors.append(frames)
3533

llava/train/llava_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -251,7 +251,7 @@ def __init__(self, *args, tokenizer = None, eval_args = None, model_max_length =
251251
def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
252252
from action.ek_eval import evaluate_on_EK100
253253

254-
accuracy = evaluate_on_EK100(self.eval_args, self.model, self.tokenizer, self.model_max_length)
254+
accuracy = evaluate_on_EK100(self.eval_args, self.model, self.tokenizer)
255255

256256
metrics = {f"{metric_key_prefix}_EK100_accuracy": accuracy}
257257

shaokai_generate_train.sh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,4 +8,5 @@ python3 action/generate_description.py \
88
--out_folder /storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train \
99
--avion_train_predictions /storage-rcp-pure/upmwmathis_scratch/shaokai/avion_predictions_train.json \
1010
--gen_type avion_mc \
11+
--n_options 10
1112

0 commit comments

Comments
 (0)