
Commit 0cf8e6d

Haozhe Qi authored and committed
supports random generation for multichoice dataset
1 parent ab28d4c · commit 0cf8e6d

File tree: 4 files changed (+41 lines, -22 lines)

llava/action/dataset.py

Lines changed: 34 additions & 16 deletions
@@ -6,7 +6,7 @@
 import torch
 import decord
 from pathlib import Path
-from llava.action.utils import AvionMultiChoiceGenerator, avion_video_loader, EK100_frame_loader
+from llava.action.utils import AvionMultiChoiceGenerator, RandomMultiChoiceGenerator, avion_video_loader, EK100_frame_loader
 from llava.action.prediction_analysis import PredictionAnalysis
 import torch.distributed as dist
 
@@ -147,6 +147,7 @@ def __init__(
                  verb_maps = None,
                  noun_maps = None,
                  eval_result_folder = None,
+                 gen_type = 'action_model',
                  action_representation = 'GT_random_narration_cut',
                  mapping_vn2narration = None,
                  avion_predictions = None,
@@ -175,14 +176,19 @@ def __init__(
         self.labels = labels
         self.topk_predictions = topk_predictions
         self.ann_root = Path(metadata).parent
-        self.mc_generator = AvionMultiChoiceGenerator(self.ann_root)
+        self.gen_type = gen_type
+        if gen_type == 'action_model':
+            self.mc_generator = AvionMultiChoiceGenerator(self.ann_root)
+        elif gen_type == 'random':
+            self.mc_generator = RandomMultiChoiceGenerator(self.ann_root)
         self.rank = dist.get_rank()
         self.prediction_analysis = PredictionAnalysis(rank = self.rank, save_folder = eval_result_folder)
         self.action_representation = action_representation
         self.n_narrations = n_narrations
         self.mapping_vn2narration = mapping_vn2narration
         self.avion_predictions = avion_predictions
 
+
     def __getitem__(self, i):
         frames, label, time_meta = self.get_raw_item(
             i, is_training=self.is_training,
@@ -205,19 +211,31 @@ def __getitem__(self, i):
         frames = self.transform(frames)
         narration = self.samples[i][4]
         avion_preds = self.avion_predictions[str(i)]['predictions']
-
-        data = self.mc_generator.generate_multi_choice(label,
-            avion_preds,
-            narration,
-            self.topk_predictions,
-            self.action_representation,
-            self.n_narrations,
-            self.labels,
-            self.mapping_vn2narration,
-            self.verb_maps,
-            self.noun_maps,
-            benchmark_testing = self.eval_args.benchmark_testing,
-            is_train = False) # note we only use this dataset for evaluation for now.
-
+        if self.gen_type =='action_model':
+            data = self.mc_generator.generate_multi_choice(label,
+                avion_preds,
+                narration,
+                self.topk_predictions,
+                self.action_representation,
+                self.n_narrations,
+                self.labels,
+                self.mapping_vn2narration,
+                self.verb_maps,
+                self.noun_maps,
+                benchmark_testing = self.eval_args.benchmark_testing,
+                is_train = False) # note we only use this dataset for evaluation for now.
+        else:
+            data = self.mc_generator.generate_multi_choice(label,
+                narration,
+                self.topk_predictions,
+                self.action_representation,
+                self.n_narrations,
+                self.labels,
+                self.mapping_vn2narration,
+                self.verb_maps,
+                self.noun_maps,
+                benchmark_testing = self.eval_args.benchmark_testing,
+                is_train = False) # no
+
 
         return frames, data, time_meta, i
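Note: this commit only wires the new generator into the dataset; the RandomMultiChoiceGenerator it imports from llava.action.utils is not part of the diff shown here. Below is a minimal, illustrative sketch of the interface the call sites above assume — the class name comes from the import, the argument order mirrors the non-avion branch of __getitem__, and the body (random distractor sampling, letter formatting) is an assumption, not the repository's actual implementation.

import random

class RandomMultiChoiceGenerator:
    """Illustrative stand-in only: build a k-way multiple choice by keeping the
    ground-truth label and sampling random distractors, instead of taking the
    top-k predictions of an action model."""

    def __init__(self, ann_root):
        self.ann_root = ann_root

    def generate_multi_choice(self, gt_vn, narration, k, action_representation,
                              n_narrations, labels, mapping_vn2narration,
                              verb_maps, noun_maps, benchmark_testing=False,
                              is_train=False):
        # Most parameters are accepted only for interface compatibility with the
        # call site above and are unused in this sketch.
        # Keep the ground truth and draw k-1 distinct random distractors from the
        # label space (a simplifying assumption about what `labels` holds).
        distractors = random.sample([vn for vn in labels if vn != gt_vn], k - 1)
        options = [gt_vn] + distractors
        random.shuffle(options)
        # Format options as "{letter}. {answer}", matching the comment in
        # ensemble_llava_evaluation; the real class likely returns more fields.
        letters = [chr(ord('A') + i) for i in range(k)]
        return {'options': [[f'{letter}. {opt}' for letter, opt in zip(letters, options)]],
                'valid_letters': letters}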

llava/action/ek_eval.py

Lines changed: 5 additions & 3 deletions
@@ -131,6 +131,7 @@ def get_args_parser():
     parser.add_argument("--perspective", default = "first_person", type = str)
     parser.add_argument('--benchmark_testing', action='store_true', default = False)
     parser.add_argument('--include_time_instruction', action='store_true', default = False)
+    parser.add_argument('--gen_type', type = str, default = 'action_model') # action_model, random
     return parser
 
 def prepare_llava(pretrained):
@@ -191,7 +192,7 @@ def ensemble_llava_evaluation(
     # shuffle the options
     options = mc_data['options'][0]
     letters = mc_data['valid_letters']
-    avion_pred = mc_data['avion_pred']
+    avion_pred = mc_data.get('avion_pred', None)
     # each option was in the format of {letter}. {answer}
     preds = []
     for _ in range(ensemble_k):
@@ -283,6 +284,7 @@ def evaluate_on_EK100(eval_args,
         mapping_vn2narration = mapping_vn2narration,
         avion_predictions = predictions if eval_args.action_predictions else None,
         n_narrations = eval_args.n_narrations,
+        gen_type = eval_args.gen_type
         )
 
     def collate_fn(batch):
@@ -371,7 +373,7 @@ def collate_fn(batch):
         local_running_corrects = torch.tensor(0.0, device=device)
         local_total_samples = torch.tensor(0.0, device=device)
 
-        if eval_args.action_predictions:
+        if eval_args.action_predictions and eval_args.gen_type == 'action_model':
            avion_pred = mc_data['avion_pred']
            if gt_name == avion_pred:
                local_avion_correct.add_(1)
@@ -420,7 +422,7 @@ def collate_fn(batch):
        val_dataset.prediction_analysis.log(global_index,
            llava_pred,
            gt_name,
-           mc_data['all_avion_preds'],
+           mc_data.get('all_avion_preds', None),
            time_meta['start_second'],
            time_meta['end_second'],
            time_meta['vid_path'],
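With these changes the generator is selected at evaluation time through the new --gen_type string (default 'action_model'); when it is 'random', the avion-specific accuracy bookkeeping and logging are skipped via the gen_type guard and the .get(...) lookups above. A small self-contained sketch of that pattern, using plain argparse rather than the real get_args_parser() (which has many more options), purely to illustrate the flag and the downstream gate:

import argparse

# Mirror of the option added in get_args_parser(); standalone for illustration.
parser = argparse.ArgumentParser()
parser.add_argument('--gen_type', type=str, default='action_model')  # action_model, random
eval_args = parser.parse_args(['--gen_type', 'random'])

# Downstream, avion-specific bookkeeping is gated exactly as in the diff above.
if eval_args.gen_type == 'action_model':
    print('would track agreement with the action model top-k predictions')
else:
    print('random multiple-choice generation: no action-model predictions needed')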

llava/action/utils.py

Lines changed: 1 addition & 3 deletions
@@ -425,7 +425,6 @@ def generate_multi_choice(self,
         randomly pick k-1 letters from vn_list
 
         """
-
         if is_train:
             return self.train_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
         else:
@@ -547,14 +546,13 @@ def test_generate(self,
         answer_ids = action_model_predictions[:k]
 
         if benchmark_testing:
-            print ("am i here")
             # if we are testing on benchmark, we need to ensure that the gt_vn is in the top k predictions
             # if not, we remove the last prediction and add the gt_vn
             if gt_vn not in answer_ids:
                 answer_ids.pop()
                 answer_ids.append(gt_vn)
         else:
-            print ("am i not here")
+            pass
 
         answers = []
         for answer_id in answer_ids:

llava/train/train.py

Lines changed: 1 addition & 0 deletions
@@ -206,6 +206,7 @@ class EK100EvalArguments:
     pseudo_folder: str = ""
     benchmark_testing: bool = False
     include_time_instruction: bool = False
+    gen_type: str = 'action_model'
 
 def maybe_zero_3(param, ignore_status=False, name=None):
     from deepspeed import zero
