Commit 06e75af

Author: Haozhe Qi
Commit message: debug
Parent: fd301b9

6 files changed: 26 additions, 33 deletions

llava/action/ek_eval.py (7 additions, 2 deletions)

@@ -128,6 +128,7 @@ def get_args_parser():
     parser.add_argument('--learn_neighbor_actions', action='store_true', default = False)
     parser.add_argument('--pseudo_folder', default = None, type = str)
     parser.add_argument('--output_dir', default = None, type = str)
+    parser.add_argument("--perspective", default = "first_person", type = str)
     return parser

 def prepare_llava(pretrained):
@@ -169,6 +170,7 @@ def ensemble_llava_evaluation(
     learn_neighbor_actions = False,
     time_meta = None,
     meta_data = None,
+    perspective = "first_person"
     ):
     """
     This function tests how consistent the model is if we shuffle the position of the answers
@@ -206,7 +208,8 @@ def ensemble_llava_evaluation(
         temperature = temperature,
         time_meta = time_meta,
         learn_neighbor_actions = learn_neighbor_actions,
-        meta_data = meta_data
+        meta_data = meta_data,
+        perspective = perspective
         )
     # remove the trailing comma if there is one
     pred = pred.rstrip(',')
@@ -386,7 +389,9 @@ def collate_fn(batch):
         test_type = eval_args.test_type,
         learn_neighbor_actions = eval_args.learn_neighbor_actions,
         time_meta = time_meta,
-        meta_data = meta_data)
+        meta_data = meta_data,
+        perspective = eval_args.perspective
+        )
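The commit threads the new flag through a chain of keyword arguments: argparse supplies perspective with a "first_person" default, and each caller forwards it explicitly to the next layer. A minimal self-contained sketch of that pattern follows; the function names mirror the diff, but the bodies are placeholders, not the repository's actual logic.

import argparse

def get_args_parser():
    parser = argparse.ArgumentParser()
    # Mirrors the flag added in the hunk above.
    parser.add_argument("--perspective", default="first_person", type=str)
    return parser

def llava_inference(meta_data=None, perspective="first_person"):
    # Placeholder body: the real function runs the model on video frames.
    return f"prompting with perspective={perspective}"

def ensemble_llava_evaluation(meta_data=None, perspective="first_person"):
    # Forward the flag one level down, as the diff does.
    return llava_inference(meta_data=meta_data, perspective=perspective)

args = get_args_parser().parse_args([])
print(ensemble_llava_evaluation(perspective=args.perspective))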

llava/action/llava_inference.py (4 additions, 1 deletion)

@@ -21,7 +21,8 @@ def llava_inference(
     test_type = 'base',
     time_meta = None,
     learn_neighbor_actions = False,
-    meta_data = None
+    meta_data = None,
+    perspective = "first_person"
     ):

     model.eval()
@@ -74,6 +75,7 @@ def llava_inference(
         "mc_top5_official_key",
         include_frame_time = False,
         learn_neighbor_actions = learn_neighbor_actions,
+        perspective = perspective,
         include_time_instruction= False)

     question = f"You observed the video before and wrote down the notes: {caption_answer}. Now you watch the same video again and you can do better. " + question
@@ -87,6 +89,7 @@ def llava_inference(
         include_frame_time = False,
         learn_neighbor_actions = learn_neighbor_actions,
         include_time_instruction= False,
+        perspective = perspective,
         meta_data=meta_data)

llava/action/utils.py (6 additions, 2 deletions)

@@ -328,13 +328,17 @@ def format_llava_prompt(image_token,
     include_time_instruction = False,
     include_frame_time = False,
     meta_data = None,
-    learn_neighbor_actions = False
+    learn_neighbor_actions = False,
+    perspective = "first_person"
     ):
     """
     baseline llava prompt: {image_token}\n{task_related_prompt}
     with time instruction: {image_token}\n{time_instruction}\n{task_related_prompt}
     """
-    task_related_prompt = format_task_related_prompt(question, question_type, meta_data = meta_data, learn_neighbor_actions = learn_neighbor_actions)
+    task_related_prompt = format_task_related_prompt(question, question_type,
+                                                     meta_data = meta_data,
+                                                     learn_neighbor_actions = learn_neighbor_actions,
+                                                     perspective = perspective)

     time_instruction = format_time_instruction(video_duration, n_frames, include_frame_time)
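format_task_related_prompt itself is not part of this commit, so how it consumes perspective is not visible here. A purely hypothetical sketch of one way the argument could shape the prompt text:

def format_task_related_prompt(question, question_type,
                               meta_data=None,
                               learn_neighbor_actions=False,
                               perspective="first_person"):
    # Hypothetical body: the real implementation lives in
    # llava/action/utils.py and is not shown in this diff.
    # question_type is accepted for signature parity but unused here.
    if perspective == "first_person":
        viewpoint = "You are watching an egocentric (first-person) video."
    else:
        viewpoint = "You are watching a third-person video."
    return f"{viewpoint}\n{question}"

print(format_task_related_prompt("What action is performed?", "mc_top5_official_key"))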

llava/train/llava_trainer.py (4 additions, 0 deletions)

@@ -496,6 +496,10 @@ def __init__(self,
         self.model_max_length = model_max_length

     def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval", eval_result_folder = None):
+
+        print ('debug')
+        print (self.eval_args)
+
         accuracy = evaluate_on_EK100(self.eval_args, self.model, self.tokenizer, eval_result_folder = eval_result_folder)
         metrics = {f"{metric_key_prefix}_EK100_accuracy": accuracy}
         self.log(metrics)
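The two print calls match the commit message ("debug") and look temporary. Since train.py elsewhere references a rank0_print helper, a guard of roughly this shape would keep such output from being duplicated across distributed workers; this is a sketch of the idea, not the repository's helper:

import torch.distributed as dist

def rank0_print(*args):
    # Only the main process prints; avoids N copies of the same
    # debug line when training with N distributed workers.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args)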

llava/train/train.py (3 additions, 18 deletions)

@@ -181,23 +181,6 @@ class TrainingArguments(transformers.TrainingArguments):
     attn_implementation: str = field(default='flash_attention_2', metadata={"help": "Use transformers attention implementation."})
     overwrite_output_dir: bool =True

-# @dataclass
-# class EvaluationArguments:
-#     eval_num_processes: int = field(default=1)
-#     task_names: str = field(default=None)
-#     model: str = field(default="llava")
-#     model_args: Optional[str] = field(default=None)
-#     num_fewshot: Optional[int] = field(default=None)
-#     batch_size: int = field(default=1)
-#     device: Optional[str] = field(default=None)
-#     limit: Optional[int] = field(default=None)
-#     check_integrity: Optional[bool] = field(default=False)
-#     show_task_to_terminal: Optional[bool] = field(default=False)
-#     log_samples: Optional[bool] = field(default=True)
-#     gen_kwargs: Optional[str] = field(default="")
-#     log_samples_suffix: Optional[str] = field(default="")
-#     output_path: Optional[str] = field(default="./logs/")
-
 # for EK100
 @dataclass
 class EK100EvalArguments:
@@ -219,6 +202,7 @@ class EK100EvalArguments:
     n_narrations: int = -1
     test_type: str = 'base'
     learn_neighbor_actions: bool = False
+    perspective: str = "first_person"

 def maybe_zero_3(param, ignore_status=False, name=None):
     from deepspeed import zero
@@ -1327,7 +1311,8 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
             include_time_instruction= self.data_args.add_time_instruction,
             meta_data = meta_data,
             include_frame_time = False,
-            learn_neighbor_actions = self.eval_args.learn_neighbor_actions)
+            learn_neighbor_actions = self.eval_args.learn_neighbor_actions,
+            perspective = self.eval_args.perspective)
         sources[0]["conversations"][0]["value"] = llava_prompt
         # rank0_print (sources[0])
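Because EK100EvalArguments is a dataclass, the new field becomes a --perspective CLI option for free if the script parses it with transformers' HfArgumentParser, the usual pattern in LLaVA-style train.py files (the parsing code itself is outside this diff). A trimmed sketch, keeping only the fields visible in the hunk above:

from dataclasses import dataclass
from transformers import HfArgumentParser

@dataclass
class EK100EvalArguments:
    # Trimmed to the fields visible in the diff.
    n_narrations: int = -1
    test_type: str = 'base'
    learn_neighbor_actions: bool = False
    perspective: str = "first_person"

(eval_args,) = HfArgumentParser(EK100EvalArguments).parse_args_into_dataclasses(
    ["--perspective", "first_person"]
)
print(eval_args.perspective)  # first_person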

run_llmseval_clariden.sbatch (2 additions, 10 deletions)

@@ -50,19 +50,11 @@ LAUNCHER="accelerate launch \
     "

 PYTHON_FILE="-m lmms_eval"
-# PYTHON_ARGS=" \
-#     --model llava_onevision \
-#     --model_args pretrained=experiments/llava-onevision-qwen2-0.5b-ov,conv_template=qwen_1_5,model_name=llava_qwen \
-#     --tasks video_dc499 \
-#     --batch_size 1 \
-#     --log_samples_suffix llava_onevision \
-#     --output_path ./logs/ \
-#     --verbosity=DEBUG \
-#     "
+

 PYTHON_ARGS=" \
     --model llava_vid \
-    --model_args pretrained=lmms-lab/LLaVA-Video-7B-Qwen2,conv_template=qwen_1_5,max_frames_num=64,mm_spatial_pool_mode=average \
+    --model_args pretrained=experiments/LLaVA-Video-7B-Qwen2,conv_template=qwen_1_5,max_frames_num=64,mm_spatial_pool_mode=average \
     --tasks videomme \
     --batch_size 1 \
     --log_samples \
