Skip to content

Commit 0f32693

Browse files
author
Ye Shaokai
committed
WIP
1 parent c5ee60b commit 0f32693

File tree

5 files changed

+37
-7
lines changed

5 files changed

+37
-7
lines changed

llava/action/chatgpt_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ def __init__(self,
347347
question_type = 'cot_mc',
348348
debug = False,
349349
topk = 10,
350-
perspective = 'first_person'
350+
perspective = 'first_person',
351+
benchmark_testing = False
351352
):
352353
"""
353354
Parameters
@@ -370,6 +371,7 @@ def __init__(self,
370371

371372
self.gen_type = gen_type
372373
self.perspective = perspective
374+
self.benchmark_testing = benchmark_testing
373375
assert gen_type in ['avion', 'tim', 'random']
374376

375377
if gen_type == 'avion' or gen_type == 'tim':
@@ -409,6 +411,7 @@ def init_data(self):
409411
self.mapping_vn2narration,
410412
self.verb_maps,
411413
self.noun_maps,
414+
benchmark_tesitng = self.benchmark_testing,
412415
is_train = False)
413416
else:
414417
mc_data = self.mc_generator.generate_multi_choice(gt_vn,
@@ -420,6 +423,7 @@ def init_data(self):
420423
self.mapping_vn2narration,
421424
self.verb_maps,
422425
self.noun_maps,
426+
benchmark_testing = self.benchmark_testing,
423427
is_train = False)
424428

425429
options = mc_data['options'][0]

llava/action/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ def __getitem__(self, i):
313313
self.mapping_vn2narration,
314314
self.verb_maps,
315315
self.noun_maps,
316-
is_train = False) # note we only use this dataset for evaluation for now.
316+
is_train = False,
317+
benchmark_testing = eval_args.benchmark_testing) # note we only use this dataset for evaluation for now.
317318

318319

319320
return frames, data, time_meta, i

llava/action/ek_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def get_args_parser():
129129
parser.add_argument('--pseudo_folder', default = None, type = str)
130130
parser.add_argument('--output_dir', default = None, type = str)
131131
parser.add_argument("--perspective", default = "first_person", type = str)
132+
parser.add_argument('--benchmark_testing', action='store_true', default = False)
132133
return parser
133134

134135
def prepare_llava(pretrained):

llava/action/utils.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -522,12 +522,34 @@ def train_generate(self, gt_vn, avion_predictions, narration, k, action_represen
522522
}
523523
return mc_data
524524

525-
def test_generate(self, gt_vn, avion_predictions, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps):
525+
def test_generate(self,
526+
gt_vn,
527+
action_model_predictions,
528+
narration,
529+
k,
530+
action_representation,
531+
n_narrations,
532+
labels,
533+
mapping_vn2narration,
534+
verb_maps,
535+
noun_maps,
536+
benchmark_testing = False
537+
):
526538
"""
527539
During testing, we use the top k predictions from avion. No randomness. We do not mix the gt_vn with the avion predictions
528540
"""
529-
530-
answer_ids = avion_predictions[:k]
541+
answer_ids = action_model_predictions[:k]
542+
543+
if benchmark_testing:
544+
# if we are testing on benchmark, we need to ensure that the gt_vn is in the top k predictions
545+
# if not, we remove the last prediction and add the gt_vn
546+
if gt_vn not in answer_ids:
547+
answer_ids.pop()
548+
answer_ids.append(gt_vn)
549+
550+
# let's shuffle answer_ids so that the gt_vn is not always at the end
551+
random.shuffle(answer_ids)
552+
531553
answers = []
532554
for answer_id in answer_ids:
533555
answer = parse_vn_ids(answer_id, gt_vn, narration, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
@@ -566,7 +588,8 @@ def generate_multi_choice(self,
566588
mapping_vn2narration,
567589
verb_maps,
568590
noun_maps,
569-
is_train = True
591+
is_train = True,
592+
benchmark_testing = False
570593
):
571594
"""
572595
Generate k multiple choices from gt_vn pairs
@@ -578,7 +601,7 @@ def generate_multi_choice(self,
578601
if is_train:
579602
return self.train_generate(gt_vn, avion_predictions, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
580603
else:
581-
return self.test_generate(gt_vn, avion_predictions, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
604+
return self.test_generate(gt_vn, avion_predictions, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = benchmark_testing)
582605

583606
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
584607
frame_ids = np.convolve(np.linspace(start_frame, end_frame, num_segments + 1), [0.5, 0.5], mode='valid')

llava/train/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class EK100EvalArguments:
204204
learn_neighbor_actions: bool = False
205205
perspective: str = "first_person"
206206
pseudo_folder: str = ""
207+
benchmark_testing: bool = False
207208

208209
def maybe_zero_3(param, ignore_status=False, name=None):
209210
from deepspeed import zero

0 commit comments

Comments
 (0)