Skip to content

Commit 46ea7d6

Browse files
author
Ye Shaokai
committed
Merge branch 'shaokai/dev' of github.com:yeshaokai/LLaVA-NeXT into shaokai/dev
2 parents 153434c + 79829db commit 46ea7d6

File tree

6 files changed

+48
-13
lines changed

6 files changed

+48
-13
lines changed

llava/action/benchmark.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
n_frames = 4
1111
topk = 5
1212
action_representation = 'GT_random_narration'
13-
gpt_model = 'gpt-4o-mini-2024-07-18'
14-
# gpt_model = 'gpt-4o-2024-08-06'
15-
perspective = 'third_person'
13+
#gpt_model = 'gpt-4o-mini-2024-07-18'
14+
gpt_model = 'gpt-4o-2024-08-06'
15+
perspective = 'first_person'
16+
benchmark_testing = True
1617

1718

1819
def benchmark_avion_mcq(n_samples):
@@ -26,6 +27,7 @@ def benchmark_avion_mcq(n_samples):
2627
question_type = 'mc_',
2728
action_representation=action_representation,
2829
perspective = perspective,
30+
benchmark_testing = benchmark_testing,
2931
topk = topk)
3032
inferencer.multi_process_run(n_samples)
3133

@@ -40,6 +42,7 @@ def benchmark_tim_mcq(n_samples):
4042
question_type = 'mc_',
4143
action_representation=action_representation,
4244
perspective = perspective,
45+
benchmark_testing = benchmark_testing,
4346
topk = topk)
4447
inferencer.multi_process_run(n_samples)
4548

@@ -53,6 +56,7 @@ def benchmark_random_mcq(n_samples):
5356
question_type = 'mc_',
5457
action_representation=action_representation,
5558
perspective = perspective,
59+
benchmark_testing = benchmark_testing,
5660
topk = topk)
5761

5862
inferencer.multi_process_run(n_samples)
@@ -61,4 +65,4 @@ def benchmark_random_mcq(n_samples):
6165
if __name__ == '__main__':
6266
benchmark_avion_mcq(100)
6367
benchmark_tim_mcq(100)
64-
benchmark_random_mcq(100)
68+
#benchmark_random_mcq(100)

llava/action/chatgpt_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ def __init__(self,
347347
question_type = 'cot_mc',
348348
debug = False,
349349
topk = 10,
350-
perspective = 'first_person'
350+
perspective = 'first_person',
351+
benchmark_testing = False
351352
):
352353
"""
353354
Parameters
@@ -370,6 +371,7 @@ def __init__(self,
370371

371372
self.gen_type = gen_type
372373
self.perspective = perspective
374+
self.benchmark_testing = benchmark_testing
373375
assert gen_type in ['avion', 'tim', 'random']
374376

375377
if gen_type == 'avion' or gen_type == 'tim':
@@ -409,6 +411,7 @@ def init_data(self):
409411
self.mapping_vn2narration,
410412
self.verb_maps,
411413
self.noun_maps,
414+
benchmark_testing = self.benchmark_testing,
412415
is_train = False)
413416
else:
414417
mc_data = self.mc_generator.generate_multi_choice(gt_vn,
@@ -420,6 +423,7 @@ def init_data(self):
420423
self.mapping_vn2narration,
421424
self.verb_maps,
422425
self.noun_maps,
426+
benchmark_testing = self.benchmark_testing,
423427
is_train = False)
424428

425429
options = mc_data['options'][0]

llava/action/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ def __getitem__(self, i):
313313
self.mapping_vn2narration,
314314
self.verb_maps,
315315
self.noun_maps,
316-
is_train = False) # note we only use this dataset for evaluation for now.
316+
is_train = False,
317+
benchmark_testing = eval_args.benchmark_testing) # note we only use this dataset for evaluation for now.
317318

318319

319320
return frames, data, time_meta, i

llava/action/ek_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def get_args_parser():
129129
parser.add_argument('--pseudo_folder', default = None, type = str)
130130
parser.add_argument('--output_dir', default = None, type = str)
131131
parser.add_argument("--perspective", default = "first_person", type = str)
132+
parser.add_argument('--benchmark_testing', action='store_true', default = False)
132133
return parser
133134

134135
def prepare_llava(pretrained):

llava/action/utils.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ def generate_multi_choice(self,
411411
mapping_vn2narration,
412412
verb_maps,
413413
noun_maps,
414-
is_train = True
414+
is_train = True,
415+
benchmark_testing = False
415416
):
416417

417418
"""
@@ -425,7 +426,7 @@ def generate_multi_choice(self,
425426
if is_train:
426427
return self.train_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
427428
else:
428-
return self.test_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
429+
return self.test_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = benchmark_testing)
429430

430431
def train_generate(self, gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps):
431432
# letters as A, B, C, D, .. Note we maximally support 26 letters
@@ -522,12 +523,34 @@ def train_generate(self, gt_vn, avion_predictions, narration, k, action_represen
522523
}
523524
return mc_data
524525

525-
def test_generate(self, gt_vn, avion_predictions, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps):
526+
def test_generate(self,
527+
gt_vn,
528+
action_model_predictions,
529+
narration,
530+
k,
531+
action_representation,
532+
n_narrations,
533+
labels,
534+
mapping_vn2narration,
535+
verb_maps,
536+
noun_maps,
537+
benchmark_testing = False
538+
):
526539
"""
527540
During testing, we use the top k predictions from the action model (e.g. avion or tim). We do not mix the gt_vn into the predictions unless benchmark_testing is set, in which case the gt_vn is forced into the candidate list and the list is shuffled.
528541
"""
529-
530-
answer_ids = avion_predictions[:k]
542+
answer_ids = action_model_predictions[:k]
543+
544+
if benchmark_testing:
545+
# if we are testing on the benchmark, we need to ensure that the gt_vn is among the top k predictions
546+
# if it is not, we drop the lowest-ranked (last) prediction and append the gt_vn in its place
547+
if gt_vn not in answer_ids:
548+
answer_ids.pop()
549+
answer_ids.append(gt_vn)
550+
551+
# let's shuffle answer_ids so that the gt_vn is not always at the end
552+
random.shuffle(answer_ids)
553+
531554
answers = []
532555
for answer_id in answer_ids:
533556
answer = parse_vn_ids(answer_id, gt_vn, narration, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
@@ -566,7 +589,8 @@ def generate_multi_choice(self,
566589
mapping_vn2narration,
567590
verb_maps,
568591
noun_maps,
569-
is_train = True
592+
is_train = True,
593+
benchmark_testing = False
570594
):
571595
"""
572596
Generate k multiple choices from gt_vn pairs
@@ -578,7 +602,7 @@ def generate_multi_choice(self,
578602
if is_train:
579603
return self.train_generate(gt_vn, avion_predictions, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
580604
else:
581-
return self.test_generate(gt_vn, avion_predictions, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
605+
return self.test_generate(gt_vn, avion_predictions, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = benchmark_testing)
582606

583607
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
584608
frame_ids = np.convolve(np.linspace(start_frame, end_frame, num_segments + 1), [0.5, 0.5], mode='valid')

llava/train/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class EK100EvalArguments:
204204
learn_neighbor_actions: bool = False
205205
perspective: str = "first_person"
206206
pseudo_folder: str = ""
207+
benchmark_testing: bool = False
207208

208209
def maybe_zero_3(param, ignore_status=False, name=None):
209210
from deepspeed import zero

0 commit comments

Comments
 (0)