Skip to content

Commit ec1e6e0

Browse files
author
Haozhe Qi
committed
fixed an important bug
1 parent 1407244 commit ec1e6e0

File tree

2 files changed

+4
-99
lines changed

2 files changed

+4
-99
lines changed

llava/action/dataset.py

Lines changed: 1 addition & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -216,105 +216,8 @@ def __getitem__(self, i):
216216
self.mapping_vn2narration,
217217
self.verb_maps,
218218
self.noun_maps,
219+
benchmark_testing = self.eval_args.benchmark_testing,
219220
is_train = False) # note we only use this dataset for evaluation for now.
220221

221222

222223
return frames, data, time_meta, i
223-
224-
225-
226-
227-
class VideoTemporalMultiChoiceDataset(VideoCaptionDatasetBase):
228-
def __init__(
229-
self, dataset, root, metadata, transform=None,
230-
is_training=True, label_mapping=None,
231-
num_clips=1,
232-
chunk_len=300,
233-
clip_length=32, clip_stride=2,
234-
threads=1,
235-
fast_rrc=False,
236-
rrc_params=(224, (0.5, 1.0)),
237-
fast_rcc=False,
238-
rcc_params=(224,),
239-
sparse_sample=False,
240-
labels = None,
241-
is_trimmed=True,
242-
eval_args = None,
243-
topk_predictions = 5,
244-
verb_maps = None,
245-
noun_maps = None,
246-
eval_result_folder = None,
247-
action_representation = 'GT_random_narration',
248-
mapping_vn2narration = None,
249-
avion_predictions = None,
250-
n_narrations = -1,
251-
):
252-
super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)
253-
254-
self.transform = transform
255-
self.is_training = is_training
256-
self.label_mapping = label_mapping
257-
self.num_clips = num_clips
258-
self.chunk_len = chunk_len
259-
self.clip_length = clip_length
260-
self.clip_stride = clip_stride
261-
self.threads = threads
262-
self.fast_rrc = fast_rrc
263-
self.rrc_params = rrc_params
264-
self.fast_rcc = fast_rcc
265-
self.rcc_params = rcc_params
266-
self.sparse_sample = sparse_sample
267-
self.eval_args = eval_args
268-
self.verb_maps = verb_maps
269-
self.noun_maps = noun_maps
270-
self.vn_list = list(self.label_mapping.keys())
271-
272-
self.labels = labels
273-
self.topk_predictions = topk_predictions
274-
self.ann_root = Path(metadata).parent
275-
self.mc_generator = AvionMultiChoiceGenerator(self.ann_root)
276-
self.rank = dist.get_rank()
277-
self.prediction_analysis = PredictionAnalysis(rank = self.rank, save_folder = eval_result_folder)
278-
self.action_representation = action_representation
279-
self.n_narrations = n_narrations
280-
self.mapping_vn2narration = mapping_vn2narration
281-
self.avion_predictions = avion_predictions
282-
283-
def __getitem__(self, i):
284-
frames, label, time_meta = self.get_raw_item(
285-
i, is_training=self.is_training,
286-
chunk_len=self.chunk_len,
287-
num_clips=self.num_clips,
288-
clip_length=self.clip_length,
289-
clip_stride=self.clip_stride,
290-
threads=self.threads,
291-
fast_rrc=self.fast_rrc,
292-
rrc_params=self.rrc_params,
293-
fast_rcc=self.fast_rcc,
294-
rcc_params=self.rcc_params,
295-
sparse_sample=self.sparse_sample,
296-
)
297-
298-
# for llava-video to work, we also need time meta data.
299-
300-
# apply transformation
301-
if self.transform is not None:
302-
frames = self.transform(frames)
303-
narration = self.samples[i][4]
304-
avion_preds = self.avion_predictions[str(i)]['predictions']
305-
306-
data = self.mc_generator.generate_multi_choice(label,
307-
avion_preds,
308-
narration,
309-
self.topk_predictions,
310-
self.action_representation,
311-
self.n_narrations,
312-
self.labels,
313-
self.mapping_vn2narration,
314-
self.verb_maps,
315-
self.noun_maps,
316-
is_train = False,
317-
benchmark_testing = eval_args.benchmark_testing) # note we only use this dataset for evaluation for now.
318-
319-
320-
return frames, data, time_meta, i

llava/action/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,12 +542,14 @@ def test_generate(self,
542542
answer_ids = action_model_predictions[:k]
543543

544544
if benchmark_testing:
545+
print ("am i here")
545546
# if we are testing on benchmark, we need to ensure that the gt_vn is in the top k predictions
546547
# if not, we remove the last prediction and add the gt_vn
547548
if gt_vn not in answer_ids:
548549
answer_ids.pop()
549550
answer_ids.append(gt_vn)
550-
551+
else:
552+
print ("am i not here")
551553

552554
answers = []
553555
for answer_id in answer_ids:

0 commit comments

Comments
 (0)