Skip to content

Commit f1235c1

Browse files
author
Haozhe Qi
committed
udates
1 parent f2d9dd8 commit f1235c1

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

llava/action/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def format_task_related_prompt(question, question_type, meta_data = None, perspe
270270

271271
elif question_type == "direct_narration":
272272

273-
if learn_neighbor_actions == "prior" and meta_data and random.random() < 0.5:
273+
if learn_neighbor_actions == "prior" and meta_data and random.random() < 0.1:
274274
ret = f"{perspective_prefix} {prev2_offset} seconds ago, you started an action {prev2_narration}. {prev1_offset} seconds ago, you started an action {prev1_narration}. What action are you currently performing? Give a short sentence such as 'move knife'. "
275275
else:
276276
ret = f"{perspective_prefix} What action are you performing? Give a short sentence such as 'move knife'."

llava/train/train.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,8 @@ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer,
989989

990990
from llava.action.generate_interval_pred import get_lookup_dict
991991

992-
self.train_triple_lookup = get_lookup_dict(os.path.join(self.EK100_anno_root, 'EPIC_100_train.csv'), self.eval_args.action_representation)
992+
self.train_triple_lookup_official = get_lookup_dict(os.path.join(self.EK100_anno_root, 'EPIC_100_train.csv'), 'official_key')
993+
self.train_triple_lookup_narration = get_lookup_dict(os.path.join(self.EK100_anno_root, 'EPIC_100_train.csv'), 'GT_random_narration')
993994

994995
# Handle multiple JSON files specified in the data_path
995996
if "{" in data_path and "}" in data_path:
@@ -1282,7 +1283,12 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
12821283
start_timestamp = round(float(self.list_data_dict[i]['start_timestamp']), 2)
12831284
end_timestamp = round(float(self.list_data_dict[i]['end_timestamp']), 2)
12841285
uid = f"{vid}_{start_timestamp}_{end_timestamp}"
1285-
meta_data = self.train_triple_lookup.get(uid, None)
1286+
# if True:
1287+
# meta_data = self.train_triple_lookup_narration.get(uid, None)
1288+
if 'official_key' in sources[0]['question_type']:
1289+
meta_data = self.train_triple_lookup_official.get(uid, None)
1290+
elif 'GT_random_narration' in sources[0]['question_type']:
1291+
meta_data = self.train_triple_lookup_narration.get(uid, None)
12861292

12871293

12881294
if 'EK100' not in video_file and 'EKframes' not in video_folder:

run_llmseval_clariden.sbatch

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ PYTHON_FILE="-m lmms_eval"
5555
PYTHON_ARGS=" \
5656
--model llava_vid \
5757
--model_args pretrained=experiments/LLaVA-Video-7B-Qwen2,conv_template=qwen_1_5,max_frames_num=64,mm_spatial_pool_mode=average \
58-
--tasks videomme \
58+
--tasks videomme,egoschema,nextqa \
5959
--batch_size 1 \
6060
--log_samples \
6161
--log_samples_suffix llava_vid \
62-
--output_path ./logs/
62+
--output_path ./benchmarks/
6363
--verbosity=DEBUG \
6464
"
6565

0 commit comments

Comments
 (0)