Commit 7665a5e ("update")
1 parent 8891c50

7 files changed: 25 additions, 23 deletions


.vscode/launch.json

Lines changed: 2 additions & 1 deletion
@@ -299,7 +299,8 @@
       "--vision_supervision", "three_tokens",
       "--vision_token_training", "all_layers",
       "--action_types", "97,300,3806",
-      "--perspective", "first_person"
+      "--learn_neighbor_actions", "prior",
+      "--test_type", "temporal_cot"
     ],
     "console": "integratedTerminal",
     "justMyCode": false,

llava/action/ek_eval.py

Lines changed: 2 additions & 2 deletions
@@ -125,7 +125,7 @@ def get_args_parser():
                                  'GT_key', 'GT_random_narration', 'GT_random_narration_cut', 'gpt_narration'])
     parser.add_argument('--n_narrations', default = -1, type = int)
     parser.add_argument('--test_type', default = 'base', type = str, choices = ['caption', 'base', 'temporal_cot', 'caption_then_answer', 'direct_narration'])
-    parser.add_argument('--learn_neighbor_actions', action='store_true', default = False)
+    parser.add_argument('--learn_neighbor_actions', type= str, default = "")
     parser.add_argument('--pseudo_folder', default = None, type = str)
     parser.add_argument('--output_dir', default = None, type = str)
     parser.add_argument("--perspective", default = "first_person", type = str)
@@ -168,7 +168,7 @@ def ensemble_llava_evaluation(
     clip_length,
     num_frames,
     test_type = 'base',
-    learn_neighbor_actions = False,
+    learn_neighbor_actions = "",
     time_meta = None,
     meta_data = None,
     perspective = "first_person"
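
Since --learn_neighbor_actions now takes a string mode rather than acting as a boolean switch, an empty string leaves the feature off and a value such as "prior" selects a mode. A minimal standalone argparse sketch of that behavior (not code from this repo; the argument list passed to parse_args is illustrative only):

# Standalone sketch; only the two arguments mirror this commit.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--test_type', default='base', type=str,
                    choices=['caption', 'base', 'temporal_cot', 'caption_then_answer', 'direct_narration'])
parser.add_argument('--learn_neighbor_actions', type=str, default="")

args = parser.parse_args(['--test_type', 'temporal_cot', '--learn_neighbor_actions', 'prior'])
print(args.learn_neighbor_actions)  # "prior"; an empty string (the default) disables the behavior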

llava/action/llava_inference.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def llava_inference(
     temperature = 0,
     test_type = 'base',
     time_meta = None,
-    learn_neighbor_actions = False,
+    learn_neighbor_actions = "",
     meta_data = None,
     perspective = "first_person"
 ):

llava/action/make_visualizations.py

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ def visualize_with_llava(pretrained_path, uid, question_type, gen_type):
         num_frames=n_frames,
         temperature = 0,
         time_meta = time_meta,
-        learn_neighbor_actions = False,
+        learn_neighbor_actions = "",
         meta_data = None,
         perspective = perspective
     )

llava/action/utils.py

Lines changed: 17 additions & 15 deletions
@@ -224,7 +224,7 @@ def remove_sub_nouns_with_doc(doc, verb: str, noun: str) -> str:
     return processed_text


-def format_task_related_prompt(question, question_type, meta_data = None, perspective = "first_person", learn_neighbor_actions = False):
+def format_task_related_prompt(question, question_type, meta_data = None, perspective = "first_person", learn_neighbor_actions = ""):
     """
     Task related prompt is impacted by the question_type.
     We currently support mc_{action_representation} and gpt-gt-reason
@@ -238,7 +238,20 @@ def format_task_related_prompt(question, question_type, meta_data = None, perspe

     if question_type.startswith("mc_") or question_type == 'temporal_cot':

-        if learn_neighbor_actions and meta_data:
+        if question_type.startswith("mc_") and learn_neighbor_actions == "prior" and meta_data and random.random() < 0.3:
+            # this means it's training time and we are learning the prior actions
+            prefix = f"{perspective_prefix}\n"
+            assert isinstance(question, list)
+            suffix = ", ".join(question)
+            prev2_narration = meta_data['prev2_narration']
+            prev2_offset = meta_data['prev2_offset']
+            prev1_narration = meta_data['prev1_narration']
+            prev1_offset = meta_data['prev1_offset']
+            cur_narration = meta_data['cur_narration']
+            suffix = f"{prev2_offset} seconds ago, you started an action {prev2_narration}. {prev1_offset} seconds ago, you started an action {prev1_narration}. What action are you currently performing? Here are the options of actions you can select:\n" + suffix
+            ret = prefix + suffix
+        elif question_type == "temporal_cot" and learn_neighbor_actions == "prior" and meta_data:
+            # means it's test time
             prefix = f"{perspective_prefix}\n"
             assert isinstance(question, list)
             suffix = ", ".join(question)
@@ -264,18 +277,7 @@ def format_task_related_prompt(question, question_type, meta_data = None, perspe
         ret = question
     elif question_type == "gpt-gt-reason" or question_type == "caption":
         ret = f"{perspective_prefix} Describe in details what you see from the video frames. You must talk in the first person perspective. Try to focus on what you are doing. "
-
-    elif question_type == "triple_direct_answer":
-        assert meta_data
-        duration1 = meta_data[0]['duration']
-        duration2 = meta_data[1]['duration']
-        duration3 = meta_data[2]['duration']
-        prompt = f"The video consists of 3 sequential actions. What are the actions? Format your answer as action1, action2, action3."
-        ret = f"{perspective_prefix}{prompt}"
-
-
-    elif question_type == "validation":
-        ret = f"Ask yourself questions to validate your notes."
+

     elif question_type == "gpt-gt-strong-reason":
         ret = f"{perspective_prefix} Describe in details what you see and answer the multi-choice question. Explain why wrong answers are wrong and why the correct answer is correct. "
@@ -328,7 +330,7 @@ def format_llava_prompt(image_token,
                         include_time_instruction = False,
                         include_frame_time = False,
                         meta_data = None,
-                        learn_neighbor_actions = False,
+                        learn_neighbor_actions = "",
                         perspective = "first_person"
                         ):
     """

llava/model/language_model/llava_qwen.py

Lines changed: 1 addition & 1 deletion
@@ -244,7 +244,7 @@ def forward(
                 pass
             # by default, distilaltion uses all layers
             # First check if any process has valid examples across all triples
-            world_has_valid = torch.tensor(actions[:, 0].any() > 0, device=actions.device)
+            world_has_valid = torch.tensor(actions[:, 0].any() >= 0, device=actions.device)
             torch.distributed.all_reduce(world_has_valid, op=torch.distributed.ReduceOp.MAX)

             if world_has_valid: # If any process has valid examples
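
The switch from > 0 to >= 0 changes the local check: actions[:, 0].any() is a boolean tensor, and while comparing it with > 0 is true only when the first column has a nonzero entry, comparing with >= 0 is always true, so world_has_valid comes out true after the MAX all-reduce. A standalone illustration of the comparison semantics (not code from this repo; the tensor values are made up):

# Standalone check of the boolean-tensor comparison; placeholder data.
import torch

actions = torch.tensor([[0, 1], [0, 2]])
flag = actions[:, 0].any()   # tensor(False): no nonzero entries in column 0
print(flag > 0)              # tensor(False): the old check
print(flag >= 0)             # tensor(True): the new check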

llava/train/train.py

Lines changed: 1 addition & 2 deletions
@@ -201,7 +201,7 @@ class EK100EvalArguments:
     action_representation: str = "GT_random_narration_cut"
     n_narrations: int = -1
     test_type: str = 'base'
-    learn_neighbor_actions: bool = False
+    learn_neighbor_actions: str = "" # "prior", "triple_direct"
     perspective: str = "first_person"
     pseudo_folder: str = ""
     benchmark_testing: bool = False
@@ -990,7 +990,6 @@ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer,
         from llava.action.generate_interval_pred import get_lookup_dict

         self.train_triple_lookup = get_lookup_dict(os.path.join(self.EK100_anno_root, 'EPIC_100_train.csv'))
-        #self.val_triple_lookup = get_lookup_dict(os.path.join(self.EK100_anno_root, 'EPIC_100_validation.csv'))

         # Handle multiple JSON files specified in the data_path
         if "{" in data_path and "}" in data_path:
