Skip to content

Commit e97bccb

Browse files
author
Haozhe Qi
committed
updates
1 parent 0cf8e6d commit e97bccb

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

llava/action/ek_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def collate_fn(batch):
350350
from llava.action.generate_interval_pred import get_lookup_dict
351351
if eval_args.test_type.startswith('temporal_cot'):
352352
lookup_table = get_lookup_dict(eval_args.val_metadata,
353-
'GT_random_narration',
353+
eval_args.action_representation,
354354
test_type = eval_args.test_type,
355355
pseudo_folder = eval_args.pseudo_folder)
356356

llava/train/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1290,8 +1290,10 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
12901290
start_timestamp = round(float(self.list_data_dict[i]['start_timestamp']), 2)
12911291
end_timestamp = round(float(self.list_data_dict[i]['end_timestamp']), 2)
12921292
uid = f"{vid}_{start_timestamp}_{end_timestamp}"
1293-
if True:
1293+
if 'narration' in self.eval_args.action_representation:
12941294
meta_data = self.train_triple_lookup_narration.get(uid, None)
1295+
elif 'official_key' in self.eval_args.action_representation:
1296+
meta_data = self.train_triple_lookup_official.get(uid, None)
12951297
# if 'official_key' in sources[0]['question_type']:
12961298
# meta_data = self.train_triple_lookup_official.get(uid, None)
12971299
# elif 'GT_random_narration' in sources[0]['question_type']:

0 commit comments

Comments
 (0)