Skip to content

Commit 16f503f

Browse files
author
Haozhe Qi
committed
added back temporal_cot_oracle
1 parent 6c3db97 commit 16f503f

File tree

3 files changed

+23
-26
lines changed

3 files changed

+23
-26
lines changed

llava/action/ek_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def get_args_parser():
124124
'random_narration_cut', 'top1_narration_cut', 'topk_narration_cut_key',
125125
'GT_key', 'GT_random_narration', 'GT_random_narration_cut', 'gpt_narration'])
126126
parser.add_argument('--n_narrations', default = -1, type = int)
127-
parser.add_argument('--test_type', default = 'base', type = str, choices = ['caption', 'base', 'temporal_cot', 'caption_then_answer', 'direct_narration'])
127+
parser.add_argument('--test_type', default = 'base', type = str, choices = ['caption', 'base', 'temporal_cot', 'temporal_cot_oracle', 'caption_then_answer', 'direct_narration'])
128128
parser.add_argument('--learn_neighbor_actions', type= str, default = "")
129129
parser.add_argument('--pseudo_folder', default = None, type = str)
130130
parser.add_argument('--output_dir', default = None, type = str)

llava/action/generate_interval_pred.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def get_lookup_dict(ann_file, test_type = 'base', delta = 3, pseudo_folder = Non
158158
uid2 = f"{id}_{round(start_times[i+1],2)}_{round(end_times[i+1],2)}"
159159
uid3 = f"{id}_{round(start_times[i+2],2)}_{round(end_times[i+2],2)}"
160160

161-
if test_type == 'base':
161+
if test_type == 'base' or test_type == 'temporal_cot_oracle':
162162
narration1 = sorted_narrations[i]
163163
narration2 = sorted_narrations[i+1]
164164
narration3 = sorted_narrations[i+2]

llava/action/utils.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from llava.action.render_utils import render_frame
1818
from collections import defaultdict
1919
import json
20+
from llava.utils import rank0_print
2021

2122
def remove_sub_nouns(nlp, narration, verb, nouns):
2223
narration = copy.deepcopy(narration)
@@ -111,13 +112,7 @@ def generate_label_map(anno_root, action_representation):
111112
verb_maps[str(row['id'])] = row['key']
112113
for _, row in noun_classes_pd.iterrows():
113114
elements = row['key'].split(':')
114-
noun_maps[str(row['id'])] = ' '.join(elements[1:] + [elements[0]]) if len(elements) > 1 else row['key']
115-
# print ('verb_maps')
116-
# print (verb_maps)
117-
# print ('noun_maps')
118-
# print (noun_maps)
119-
# import sys
120-
# sys.exit()
115+
noun_maps[str(row['id'])] = ' '.join(elements[1:] + [elements[0]]) if len(elements) > 1 else row['key']
121116
# Batch processing setup
122117
if 'cut' in action_representation:
123118
import spacy
@@ -235,33 +230,31 @@ def format_task_related_prompt(question, question_type, meta_data = None, perspe
235230
perspective_prefix = "You are seeing this video from egocentric view and you are the person. Your hands are sometimes interacting with objects. What action are you doing? "
236231
elif perspective == "third_person":
237232
perspective_prefix = "The video is taken from egocentric view. The person's hands are sometimes interacting with objects. What action is the person doing?"
238-
239-
if question_type.startswith("mc_") or question_type == 'temporal_cot':
233+
234+
if learn_neighbor_actions == "prior" and meta_data:
235+
prev2_narration = meta_data['prev2_narration']
236+
prev2_offset = meta_data['prev2_offset']
237+
prev1_narration = meta_data['prev1_narration']
238+
prev1_offset = meta_data['prev1_offset']
239+
cur_narration = meta_data['cur_narration']
240+
241+
if question_type.startswith("mc_") or question_type.startswith('temporal_cot'):
240242

241243
if question_type.startswith("mc_") and learn_neighbor_actions == "prior" and meta_data and random.random() < 0.3:
242244
# this means it's training time and we are learning the prior actions
243245
prefix = f"{perspective_prefix}\n"
244246
assert isinstance(question, list)
245247
suffix = ", ".join(question)
246-
prev2_narration = meta_data['prev2_narration']
247-
prev2_offset = meta_data['prev2_offset']
248-
prev1_narration = meta_data['prev1_narration']
249-
prev1_offset = meta_data['prev1_offset']
250-
cur_narration = meta_data['cur_narration']
248+
251249
suffix = f"{prev2_offset} seconds ago, you started an action {prev2_narration}. {prev1_offset} seconds ago, you started an action {prev1_narration}. What action are you currently performing? Here are the options of actions you can select:\n" + suffix
252250
ret = prefix + suffix
253-
elif question_type == "temporal_cot" and learn_neighbor_actions == "prior" and meta_data:
251+
elif question_type.startswith("temporal_cot") and learn_neighbor_actions == "prior" and meta_data:
254252
# means it's test time
255253
prefix = f"{perspective_prefix}\n"
256254
assert isinstance(question, list)
257-
suffix = ", ".join(question)
258-
prev2_narration = meta_data['prev2_narration']
259-
prev2_offset = meta_data['prev2_offset']
260-
prev1_narration = meta_data['prev1_narration']
261-
prev1_offset = meta_data['prev1_offset']
262-
cur_narration = meta_data['cur_narration']
263-
suffix = f"{prev2_offset} seconds ago, you started an action {prev2_narration}. {prev1_offset} seconds ago, you started an action {prev1_narration}. What action are you currently performing? Here are the options of actions you can select:\n" + suffix
264-
ret = prefix + suffix
255+
suffix = ", ".join(question)
256+
suffix = f"{prev2_offset} seconds ago, you started an action {prev2_narration}. {prev1_offset} seconds ago, you started an action {prev1_narration}. What action are you currently performing? Explain your thoughts and then answer the question. Here are the options of actions you can select:\n" + suffix
257+
ret = prefix + suffix
265258
else:
266259
action_rep_suffix = "Given multiple choices, format your answer briefly such as 'A. move knife'. "
267260
prefix = f"{perspective_prefix}{action_rep_suffix}\n"
@@ -271,7 +264,11 @@ def format_task_related_prompt(question, question_type, meta_data = None, perspe
271264
ret = prefix + suffix
272265

273266
elif question_type == "direct_narration":
274-
ret = f"{perspective_prefix} What action are you performing? Give a short sentence such as 'move knife'."
267+
268+
if learn_neighbor_actions == "prior" and meta_data and random.random() < 0.5:
269+
ret = f"{perspective_prefix} {prev2_offset} seconds ago, you started an action {prev2_narration}. {prev1_offset} seconds ago, you started an action {prev1_narration}. What action are you currently performing? Give a short sentence such as 'move knife'. "
270+
else:
271+
ret = f"{perspective_prefix} What action are you performing? Give a short sentence such as 'move knife'."
275272

276273
elif question_type == "temporal_detection":
277274
ret = question

0 commit comments

Comments
 (0)