Commit 21742a4

Author: Ye Shaokai
Commit message: able to do ensemble evaluation

1 parent dc9f524

File tree

6 files changed: +100 −26 lines

action/ek_eval.py

Lines changed: 82 additions & 14 deletions

@@ -19,6 +19,8 @@
 import logging
 from llava.utils import rank0_print
 from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions
+import copy
+from collections import Counter

 def datetime2sec(str):
     hh, mm, ss = str.split(':')
@@ -370,40 +372,93 @@ def prepare_llava(pretrained):

     return tokenizer, model, image_processor, max_length

-
-def get_topk_predictions(data, idx, k):
+def get_topk_predictions(data, idx, k):

     letters = [chr(65+i) for i in range(26)][:k]
     options = list(range(26))[:k]

     predictions = data[str(idx)]['predictions'][:k]
-
     predictions = parse_avion_predictions(predictions)

     for i in range(len(options)):
         options[i] = f'{letters[i]}. {predictions[i]}'

     mc_data = {
         'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer.'},
-        'option': {0: options}
+        'options': {0: options},
+        'valid_letters': letters,
+        'avion_pred': predictions[0]
     }
+
+    return mc_data
+
+def ensemble_llava_evaluation(gt_name,
+                              frames,
+                              tokenizer,
+                              model,
+                              image_processor,
+                              mc_data,
+                              clip_length,
+                              num_frames,
+                              temperature=0,
+                              ensemble_k=1,
+                              is_test=False
+                              ):
+    """
+    Tests how consistent the model is when we shuffle the positions of the answers.
+    It should also use a higher temperature, so ensembling may yield better performance.
+    """
+
+    # shuffle the options
+    options = mc_data['options'][0]
+    letters = mc_data['valid_letters']
+    avion_pred = mc_data['avion_pred']
+    # each option is in the format "{letter}. {answer}"
+    preds = []
+    for _ in range(ensemble_k):
+        # shuffle the answers, then reattach the letters in fixed order
+        random.shuffle(options)
+        for idx, (option, letter) in enumerate(zip(options, letters)):
+            sep = option.index('.')
+            options[idx] = f'{letter}.{option[sep+1:]}'
+        rank0_print('generated new option sequence')
+        rank0_print(options)
+
+        pred = llava_inference(frames,
+                               tokenizer,
+                               model,
+                               image_processor,
+                               mc_data,
+                               clip_length=clip_length,
+                               num_frames=num_frames,
+                               temperature=temperature,
+                               is_test=is_test
+                               )
+
+        rank0_print('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
+        sep = pred.index('.')
+        pred = pred[sep+1:].strip()
+        preds.append(pred)
+
+    counter = Counter(preds)
+    rank0_print('inspecting the counter', counter)
+    rank0_print('most common', counter.most_common(1)[0][0])
+
+    return match_answer(counter.most_common(1)[0][0], gt_name)

-    return mc_data, predictions[0]

 def evaluate_on_EK100(eval_args,
                       model=None,
                       tokenizer=None,
                       image_processor=None):

-    if image_processor is None:
+    if model is not None:
         image_processor = model.get_vision_tower().image_processor

     gpu_val_transform_ls = []
-
     val_transform_gpu = torch.nn.Sequential(*gpu_val_transform_ls)
-
     crop_size = 336
-
     labels, mapping_vn2act, verb_maps, noun_maps = generate_label_map(Path(eval_args.val_metadata).parent)

     val_dataset = VideoMultiChoiceDataset(
@@ -468,7 +523,8 @@ def evaluate_on_EK100(eval_args,
         gt_name = mc_data['gt_answer_name'][0][0]

         if eval_args.action_predictions:
-            mc_data, avion_pred = get_topk_predictions(predictions, idx, eval_args.topk_predictions)
+            mc_data = get_topk_predictions(predictions, idx, eval_args.topk_predictions)
+            avion_pred = mc_data['avion_pred']
             if gt_name == avion_pred:
                 avaion_correct += 1

@@ -477,18 +533,30 @@ def evaluate_on_EK100(eval_args,
         if finish_early and idx > 999:
             break

-        pred = llava_inference(frames, tokenizer, model, image_processor, mc_data, clip_length=eval_args.clip_length, num_frames=eval_args.llava_num_frames)
+        # pred = llava_inference(frames, tokenizer, model, image_processor, mc_data, clip_length=eval_args.clip_length, num_frames=eval_args.llava_num_frames)

-        # if a valid letter is found in the prediction, we will use that as the prediction
-        rank0_print('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
+        # # if a valid letter is found in the prediction, we will use that as the prediction
+        # rank0_print('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)

         # Update running corrects and total samples
-        running_corrects += match_answer(pred, gt_name)
+        running_corrects += ensemble_llava_evaluation(gt_name,
+                                                      frames,
+                                                      tokenizer,
+                                                      model,
+                                                      image_processor,
+                                                      mc_data,
+                                                      eval_args.clip_length,
+                                                      eval_args.llava_num_frames,
+                                                      temperature=2.0,
+                                                      ensemble_k=5,
+                                                      is_test=not finish_early)
+
         total_samples += 1

         # Calculate and log running mean accuracy
         running_accuracy = running_corrects / total_samples

+        logger.info(f'running accuracy: {running_accuracy:.4f}')
         if eval_args.action_predictions:
             avaion_accuracy = avaion_correct / total_samples

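The heart of ensemble_llava_evaluation is a shuffle-relabel-vote loop: the answer texts are shuffled, the letters A, B, C, ... are reattached in fixed order, the model is queried, and a Counter majority vote picks the final answer. Shuffling across ensemble_k runs probes whether the model's choice depends on where the correct answer sits, and the vote smooths out sampling noise from the higher temperature. Below is a minimal, self-contained sketch of that loop; mock_inference and the option strings are hypothetical stand-ins for llava_inference and real predictions, while the shuffling, relabeling, and voting mirror the committed code.

import random
from collections import Counter

def mock_inference(options):
    # hypothetical stand-in for llava_inference: always answers with the
    # first option, mimicking a positional bias
    return options[0]

def shuffle_and_vote(options, letters, ensemble_k=5):
    preds = []
    for _ in range(ensemble_k):
        # shuffle the answers, then reattach the letters in fixed order,
        # exactly as ensemble_llava_evaluation does
        random.shuffle(options)
        for i, (option, letter) in enumerate(zip(options, letters)):
            sep = option.index('.')
            options[i] = f'{letter}.{option[sep+1:]}'
        pred = mock_inference(options)
        # strip the letter prefix so votes are counted on the answer text
        preds.append(pred[pred.index('.')+1:].strip())
    return Counter(preds).most_common(1)[0][0]

print(shuffle_and_vote(['A. open door', 'B. close door', 'C. wash hands'],
                       ['A', 'B', 'C']))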
action/generate_description.py

Lines changed: 2 additions & 2 deletions

@@ -40,15 +40,15 @@ def generate_train_ann(ann_file, verb_ids, noun_ids, gen_type = 'naive', avion_p
             # here we use the index
             vn_str = f'{row[10]}:{row[12]}'
             mc_data = mc_generator.generate_multi_choice(vn_str, n_options)
-            options = mc_data['option'][0]
+            options = mc_data['options'][0]
             gt_answer_letter = mc_data['gt_answer_letter'][0]
             gt_answer_name = mc_data['gt_answer_name'][0]
             conversation = generate_random_mc_conversation(options, gt_answer_letter, gt_answer_name)
         elif gen_type == "avion_mc":
             vn_str = f'{row[10]}:{row[12]}'
             avion_preds = avion_train_predictions[str(idx)]['predictions']
             mc_data = mc_generator.generate_multi_choice(vn_str, avion_preds, n_options)
-            options = mc_data['option'][0]
+            options = mc_data['options'][0]
             gt_answer_letter = mc_data['gt_answer_letter'][0]
             gt_answer_name = mc_data['gt_answer_name'][0]
             conversation = generate_random_mc_conversation(options, gt_answer_letter, gt_answer_name)

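generate_random_mc_conversation is defined elsewhere in the repo, so as a rough illustration only, here is a sketch of the kind of conversation record such a helper typically emits in the LLaVA training format; the 'from'/'value' keys and the exact layout are assumptions, not the repository's verified output.

def sketch_mc_conversation(options, gt_answer_letter, gt_answer_name):
    # hypothetical reconstruction of a single human/gpt turn pair;
    # the prompt string is copied verbatim from action/utils.py, and
    # '<image>' stands in for DEFAULT_IMAGE_TOKEN
    question = ('the video is an egocentric view of a person. What is the '
                'person doing? Pick the the letter that has the correct answer')
    return [
        {'from': 'human', 'value': f'<image>\n{question}: {options}'},
        {'from': 'gpt', 'value': f'{gt_answer_letter}. {gt_answer_name}'},
    ]

print(sketch_mc_conversation(['A. open door', 'B. wash hands'], 'A', 'open door'))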
action/llava_ov_inference.py

Lines changed: 11 additions & 5 deletions

@@ -20,23 +20,29 @@ def llava_inference(video_frames,
                     image_processor,
                     mc_data,
                     clip_length=16,
-                    num_frames=16):
+                    num_frames=16,
+                    temperature=0,
+                    is_test=False
+                    ):

     model.eval()
     device = "cuda"
     video_frames = video_frames[0]
     temporal_stride = clip_length // num_frames
     video_frames = video_frames[::temporal_stride]
     image_tensors = []
-    frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].cuda().to(torch.bfloat16)
+    if is_test:
+        frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().cuda()
+    else:
+        frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].cuda().to(torch.bfloat16)
     image_tensors.append(frames)

     conv_template = "qwen_1_5"

     question = mc_data['question'][0]
-    option = mc_data['option'][0]
+    options = mc_data['options'][0]

-    question = f"{DEFAULT_IMAGE_TOKEN}\n{question}:{option}"
+    question = f"{DEFAULT_IMAGE_TOKEN}\n{question}:{options}"

     conv = copy.deepcopy(conv_templates[conv_template])
     conv.append_message(conv.roles[0], question)
@@ -52,7 +58,7 @@ def llava_inference(video_frames,
         images=image_tensors,
         image_sizes=image_sizes,
         do_sample=False,
-        temperature=0,
+        temperature=temperature,
         max_new_tokens=4096,
         modalities=["video"],
     )

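The frame subsampling near the top of llava_inference is plain strided slicing: clip_length // num_frames gives the stride, and video_frames[::temporal_stride] keeps every stride-th frame. A quick sketch of the arithmetic, with illustrative frame counts:

# illustrative: a 16-frame clip downsampled to 8 frames
clip_length = 16
num_frames = 8
video_frames = list(range(clip_length))      # stand-in for decoded frames
temporal_stride = clip_length // num_frames  # 16 // 8 = 2
sampled = video_frames[::temporal_stride]    # [0, 2, 4, ..., 14]
assert len(sampled) == num_frames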
action/utils.py

Lines changed: 2 additions & 2 deletions

@@ -100,7 +100,7 @@ def generate_multi_choice(self, gt_vn, k):
         gt_letter = letters[answers.index(gt_answer)]
         data = {
             'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
-            'option': {0: options},
+            'options': {0: options},
             # the correct letter in mc
             # for inspecting
             'gt_answer_letter': {0: gt_letter},
@@ -153,7 +153,7 @@ def generate_multi_choice(self, gt_vn, avion_predictions, k):

         data = {
             'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer'},
-            'option': {0: options},
+            'options': {0: options},
             # the correct letter in mc
             # for inspecting
             'gt_answer_letter': {0: gt_letter},

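After the rename, every dict the generators produce exposes its choices under 'options'. For reference, a sketch of the shape consumers can rely on; the values are illustrative, and on the evaluation path get_topk_predictions additionally adds 'valid_letters' and 'avion_pred':

# illustrative mc_data entry; the {0: ...} nesting mirrors the
# dict-of-dicts layout used throughout these generators
mc_data = {
    'question': {0: 'the video is an egocentric view of a person. What is the '
                    'person doing? Pick the the letter that has the correct answer'},
    'options': {0: ['A. open door', 'B. close door', 'C. wash hands']},
    'gt_answer_letter': {0: 'A'},
    'gt_answer_name': {0: 'open door'},
}
options = mc_data['options'][0]  # the key this commit standardizes on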
llava/model/builder.py

Lines changed: 1 addition & 2 deletions

@@ -234,8 +234,7 @@ def load_from_hf(repo_id, filename, subfolder=None):
     else:
         from llava.model.language_model.llava_qwen import LlavaQwenConfig

-        #if overwrite_config is not None:
-        if True:
+        if overwrite_config is not None:
             llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
             for k, v in overwrite_config.items():
                 setattr(llava_cfg, k, v)

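Restoring the guard means the config is only re-read and patched when the caller actually passes overrides, instead of unconditionally. A small sketch of the pattern; DummyConfig and the override key are made-up stand-ins for LlavaQwenConfig and a real config attribute:

class DummyConfig:
    # stand-in for LlavaQwenConfig.from_pretrained(model_path)
    pass

overwrite_config = {'mm_newline_position': 'grid'}  # hypothetical override

if overwrite_config is not None:
    llava_cfg = DummyConfig()
    for k, v in overwrite_config.items():
        setattr(llava_cfg, k, v)  # patch each override onto the config
    assert llava_cfg.mm_newline_position == 'grid'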
llava/train/train.py

Lines changed: 2 additions & 1 deletion

@@ -1725,6 +1725,8 @@ def make_inputs_require_grad(module, input, output):

     data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

+
+
     trainer = LLaVATrainer(model=model,
                            tokenizer=tokenizer,
                            eval_args=eval_args,
@@ -1734,7 +1736,6 @@ def make_inputs_require_grad(module, input, output):

     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
         trainer.train(resume_from_checkpoint=True)
-        #trainer.train()
     else:
         trainer.train()
     trainer.save_state()

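The cleaned-up branch in train.py resumes whenever the output directory already holds a checkpoint. A self-contained sketch of that detection logic, with an illustrative temporary directory standing in for training_args.output_dir:

import pathlib
import tempfile

output_dir = pathlib.Path(tempfile.mkdtemp())  # stand-in for training_args.output_dir
(output_dir / 'checkpoint-500').mkdir()        # simulate a previous run

if list(output_dir.glob('checkpoint-*')):
    print('would call trainer.train(resume_from_checkpoint=True)')
else:
    print('would call trainer.train()')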