Commit 223b0e8

several fixes to temporal dpo

Author: Ye Shaokai
1 parent 11910e9 · commit 223b0e8

6 files changed (+147, -19 lines)


.vscode/launch.json

Lines changed: 5 additions & 4 deletions
@@ -250,7 +250,7 @@
 // "--model_name_or_path", "lmms-lab/llava-onevision-qwen2-0.5b-ov",
 // "--version", "qwen_1_5",
 // "--data_path", "scripts/train/simple_avion_top5_gt_and_direct.yaml",
-// "--video_folder", "/data/shaokai/",
+// "--video_folder", "/data/shaokai/EK100_512/",
 // "--mm_tunable_parts", "mm_vision_tower,mm_mlp_adapter,mm_language_model",
 // "--mm_vision_tower_lr", "2e-6",
 // "--vision_tower", "google/siglip-so400m-patch14-384",
@@ -288,7 +288,7 @@
 // "--torch_compile_backend", "inductor",
 // "--dataloader_drop_last", "True",
 // "--frames_upbound", "4",
-// "--root", "/data/shaokai/EK100",
+// "--root", "/data/shaokai/EK100_512/EK100",
 // "--action_predictions", "/data/shaokai/AVION_PREDS/avion_pred_ids_val.json",
 // "--val_metadata", "/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv",
 // "--llava_num_frames", "4",
@@ -345,8 +345,9 @@
 "--action_predictions","/data/shaokai/TIM_PREDS/tim_pred_ids_val.json",
 "--action_representation", "official_key",
 "--topk_predictions", "5",
-"--test_type", "base",
-"--output_dir", "test_0.5b_direct",
+"--test_type", "temporal_cot",
+"--output_dir", "test_0.5b_direct",
+"--learn_neighbor_actions"
 ],
 "console": "integratedTerminal",
 "justMyCode": false,

llava/action/dataset.py

Lines changed: 96 additions & 0 deletions
@@ -221,3 +221,99 @@ def __getitem__(self, i):
 
         return frames, data, time_meta, i
 
+
+
+
+class VideoTemporalMultiChoiceDataset(VideoCaptionDatasetBase):
+    def __init__(
+        self, dataset, root, metadata, transform=None,
+        is_training=True, label_mapping=None,
+        num_clips=1,
+        chunk_len=300,
+        clip_length=32, clip_stride=2,
+        threads=1,
+        fast_rrc=False,
+        rrc_params=(224, (0.5, 1.0)),
+        fast_rcc=False,
+        rcc_params=(224,),
+        sparse_sample=False,
+        labels = None,
+        is_trimmed=True,
+        eval_args = None,
+        topk_predictions = 5,
+        verb_maps = None,
+        noun_maps = None,
+        eval_result_folder = None,
+        action_representation = 'GT_random_narration',
+        mapping_vn2narration = None,
+        avion_predictions = None,
+        n_narrations = -1,
+    ):
+        super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)
+
+        self.transform = transform
+        self.is_training = is_training
+        self.label_mapping = label_mapping
+        self.num_clips = num_clips
+        self.chunk_len = chunk_len
+        self.clip_length = clip_length
+        self.clip_stride = clip_stride
+        self.threads = threads
+        self.fast_rrc = fast_rrc
+        self.rrc_params = rrc_params
+        self.fast_rcc = fast_rcc
+        self.rcc_params = rcc_params
+        self.sparse_sample = sparse_sample
+        self.eval_args = eval_args
+        self.verb_maps = verb_maps
+        self.noun_maps = noun_maps
+        self.vn_list = list(self.label_mapping.keys())
+
+        self.labels = labels
+        self.topk_predictions = topk_predictions
+        self.ann_root = Path(metadata).parent
+        self.mc_generator = AvionMultiChoiceGenerator(self.ann_root)
+        self.rank = dist.get_rank()
+        self.prediction_analysis = PredictionAnalysis(rank = self.rank, save_folder = eval_result_folder)
+        self.action_representation = action_representation
+        self.n_narrations = n_narrations
+        self.mapping_vn2narration = mapping_vn2narration
+        self.avion_predictions = avion_predictions
+
+    def __getitem__(self, i):
+        frames, label, time_meta = self.get_raw_item(
+            i, is_training=self.is_training,
+            chunk_len=self.chunk_len,
+            num_clips=self.num_clips,
+            clip_length=self.clip_length,
+            clip_stride=self.clip_stride,
+            threads=self.threads,
+            fast_rrc=self.fast_rrc,
+            rrc_params=self.rrc_params,
+            fast_rcc=self.fast_rcc,
+            rcc_params=self.rcc_params,
+            sparse_sample=self.sparse_sample,
+        )
+
+        # for llava-video to work, we also need time meta data.
+
+        # apply transformation
+        if self.transform is not None:
+            frames = self.transform(frames)
+        narration = self.samples[i][4]
+        avion_preds = self.avion_predictions[str(i)]['predictions']
+
+        data = self.mc_generator.generate_multi_choice(label,
+                                                       avion_preds,
+                                                       narration,
+                                                       self.topk_predictions,
+                                                       self.action_representation,
+                                                       self.n_narrations,
+                                                       self.labels,
+                                                       self.mapping_vn2narration,
+                                                       self.verb_maps,
+                                                       self.noun_maps,
+                                                       is_train = False)  # note we only use this dataset for evaluation for now.
+
+
+        return frames, data, time_meta, i

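For orientation, here is a minimal usage sketch of the new class. It is not part of this commit: it mirrors the VideoMultiChoiceDataset construction shown in ek_eval.py below, every concrete value is a placeholder, and torch.distributed must already be initialized because __init__ calls dist.get_rank().

# Sketch only (not from this commit); all paths and values are placeholders.
import json
from llava.action.dataset import VideoTemporalMultiChoiceDataset

with open('/data/shaokai/TIM_PREDS/tim_pred_ids_val.json') as f:   # placeholder predictions file
    predictions = json.load(f)

mapping_vn2act = {}   # placeholder: the real mapping comes from generate_label_map in ek_eval.py

val_dataset = VideoTemporalMultiChoiceDataset(
    'EK100',                                                                 # dataset (placeholder id)
    '/data/shaokai/EK100_512/EK100',                                         # root
    '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv',   # metadata
    transform=None,                        # the eval script would pass its GPU transform here
    is_training=False,
    label_mapping=mapping_vn2act,
    topk_predictions=5,
    action_representation='official_key',
    avion_predictions=predictions,
    eval_result_folder='test_0.5b_direct',
)

# Indexing returns the same 4-tuple as VideoMultiChoiceDataset:
# frames, mc_data, time_meta, global_index = val_dataset[0]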
llava/action/ek_eval.py

Lines changed: 5 additions & 5 deletions
@@ -14,7 +14,7 @@
 from llava.action.utils import generate_label_map, match_answer
 from collections import Counter
 import torch.distributed as dist
-from llava.action.dataset import VideoMultiChoiceDataset
+from llava.action.dataset import VideoMultiChoiceDataset, VideoTemporalMultiChoiceDataset
 import torchvision.io as io
 import re
 
@@ -124,8 +124,9 @@ def get_args_parser():
         'random_narration_cut', 'top1_narration_cut', 'topk_narration_cut_key',
         'GT_key', 'GT_random_narration', 'GT_random_narration_cut', 'gpt_narration'])
     parser.add_argument('--n_narrations', default = -1, type = int)
-    parser.add_argument('--test_type', default = 'base', type = str, choices = ['caption', 'base', 'caption_then_answer', 'direct_narration'])
+    parser.add_argument('--test_type', default = 'base', type = str, choices = ['caption', 'base', 'temporal_cot', 'caption_then_answer', 'direct_narration'])
     parser.add_argument('--learn_neighbor_actions', action='store_true', default = False)
+    parser.add_argument('--pseudo_folder', default = None, type = str)
     parser.add_argument('--output_dir', default = None, type = str)
     return parser
 
@@ -253,7 +254,7 @@ def evaluate_on_EK100(eval_args,
     if eval_args.action_predictions:
         with open(eval_args.action_predictions, 'r') as f:
             predictions = json.load(f)
-
+
     val_dataset = VideoMultiChoiceDataset(
         eval_args.dataset, eval_args.root, eval_args.val_metadata, val_transform_gpu,
         is_training=False, label_mapping=mapping_vn2act,
@@ -332,12 +333,11 @@ def collate_fn(batch):
     os.makedirs('debug_and_vis', exist_ok = True)
 
 
-    uid_pad_dict = None
     lookup_table = None
    meta_data = None
    if eval_args.learn_neighbor_actions:
        from llava.action.generate_interval_pred import get_lookup_dict
-        lookup_table = get_lookup_dict(eval_args.val_metadata)
+        lookup_table = get_lookup_dict(eval_args.val_metadata, test_type = eval_args.test_type, pseudo_folder = eval_args.pseudo_folder)
 
 
    for idx, (frames, mc_data, time_meta, global_index) in tqdm(enumerate(val_dataloader)):

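To make the new flags concrete, here is a sketch (not verbatim from the repo; the pseudo_folder path is a placeholder) of how evaluate_on_EK100 now builds the neighbor-action lookup table when --learn_neighbor_actions is set, passing the test type and pseudo-prediction folder through to get_lookup_dict.

from argparse import Namespace
from llava.action.generate_interval_pred import get_lookup_dict

# Stand-in for the parsed eval_args; only the fields used below are shown.
eval_args = Namespace(
    val_metadata='/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv',
    test_type='temporal_cot',
    learn_neighbor_actions=True,
    pseudo_folder='/path/to/first_pass_outputs',   # placeholder: folder holding prediction*.json files
)

lookup_table = None
if eval_args.learn_neighbor_actions:
    # With test_type 'base' the preceding narrations come from the GT annotations;
    # with 'temporal_cot' the two preceding narrations come from the pseudo predictions instead.
    lookup_table = get_lookup_dict(eval_args.val_metadata,
                                   test_type=eval_args.test_type,
                                   pseudo_folder=eval_args.pseudo_folder)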
llava/action/generate_interval_pred.py

Lines changed: 34 additions & 8 deletions
@@ -103,16 +103,38 @@ def build_uid_pad_dict(ann_file,
     return uid_to_neighbors
 
 
+def get_pseudo_dict(pseudo_folder, delta = 3):
+    import glob
+
+
+    files = glob.glob(os.path.join(pseudo_folder, 'prediction*.json'))
+
+    pseudo_data = {}
+    ret = {}
+    for file in files:
+        with open(file, 'r') as f:
+            pseudo_data.update(json.load(f))
+    for k,v in pseudo_data.items():
+        start_timestamp = round(float(v['start_second']),2)
+        end_timestamp = round(float(v['end_second']), 2)
+        vid = v['vid_path'].replace('/', '-')
+        uid = f"{vid}_{start_timestamp}_{end_timestamp}"
+        ret[uid] = v['llava_pred']
+
+    assert len(ret) == len(pseudo_data)
+    return ret
 
-def get_lookup_dict(ann_file, delta = 3):
+def get_lookup_dict(ann_file, test_type = 'base', delta = 3, pseudo_folder = None):
 
     vid_to_intervals, vid_to_gt_narration, _ = get_annotated_intervals(ann_file)
     table = {}
 
+    pseudo_dict = None
+    if test_type == 'temporal_cot':
+        pseudo_dict = get_pseudo_dict(pseudo_folder)
+
     for vid, intervals in vid_to_intervals.items():
-
-        #sorted_intervals = sorted(intervals, key=lambda x: x[1])
-
+
         sorted_indices = sorted(range(len(intervals)), key=lambda i: intervals[i][1])
 
         sorted_intervals = [intervals[i] for i in sorted_indices]
@@ -136,10 +158,14 @@ def get_lookup_dict(ann_file, delta = 3):
             uid2 = f"{id}_{round(start_times[i+1],2)}_{round(end_times[i+1],2)}"
             uid3 = f"{id}_{round(start_times[i+2],2)}_{round(end_times[i+2],2)}"
 
-
-            narration1 = sorted_narrations[i]
-            narration2 = sorted_narrations[i+1]
-            narration3 = sorted_narrations[i+2]
+            if test_type == 'base':
+                narration1 = sorted_narrations[i]
+                narration2 = sorted_narrations[i+1]
+                narration3 = sorted_narrations[i+2]
+            elif test_type == 'temporal_cot':
+                narration1 = pseudo_dict[uid1]
+                narration2 = pseudo_dict[uid2]
+                narration3 = sorted_narrations[i+2]
 
             table[uid3] = {'prev2_narration': narration1,
                            'prev2_offset': round(start_times[i+2] - start_times[i],2),

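One detail worth calling out: get_pseudo_dict and get_lookup_dict must agree on the uid string, otherwise pseudo_dict[uid1] raises a KeyError under temporal_cot. A small illustrative sketch of the shared convention follows; the helper and the example values are hypothetical, not repo code.

# Hypothetical helper, illustrative values only; not part of the repo.
def make_uid(vid_path: str, start_second: float, end_second: float) -> str:
    # get_pseudo_dict: '/' in the video path becomes '-', timestamps rounded to 2 decimals
    vid = vid_path.replace('/', '-')
    return f"{vid}_{round(float(start_second), 2)}_{round(float(end_second), 2)}"

# get_lookup_dict builds keys of the same shape from the annotation intervals, e.g.
#   uid2 = f"{id}_{round(start_times[i+1],2)}_{round(end_times[i+1],2)}"
# so a pseudo prediction is only found if vid_path.replace('/', '-') matches that id
# and both sides round identical start/end seconds.
print(make_uid('P01/P01_11', 10.0, 13.5))   # -> 'P01-P01_11_10.0_13.5' (made-up values)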
llava/action/llava_inference.py

Lines changed: 5 additions & 1 deletion
@@ -52,6 +52,9 @@ def llava_inference(
         question_type = "direct_narration"
     elif test_type == 'caption' or test_type == 'debug':
         question_type = "gpt-gt-reason"
+    elif test_type == 'temporal_cot':
+        question_type = 'temporal_cot'
+
     if test_type == 'caption_then_answer':
         caption_answer = llava_inference([video_frames],
                                          tokenizer,
@@ -73,7 +76,8 @@
                                          learn_neighbor_actions = learn_neighbor_actions,
                                          include_time_instruction= False)
 
-        question = f"You observed the video before and wrote down the notes: {caption_answer}. Now you watch the same video again and you can do better. " + question
+        question = f"You observed the video before and wrote down the notes: {caption_answer}. Now you watch the same video again and you can do better. " + question
+
     else:
         question = format_llava_prompt(DEFAULT_IMAGE_TOKEN,
                                        options,

llava/action/utils.py

Lines changed: 2 additions & 1 deletion
@@ -235,7 +235,8 @@ def format_task_related_prompt(question, question_type, meta_data = None, perspe
         perspective_prefix = "You are seeing this video from egocentric view and you are the person. Your hands are sometimes interacting with obects. "
     elif perspective == "third_person":
         perspective_prefix = "The video is taken from egocentric view. What action is the person performing? "
-    if question_type.startswith("mc_"):
+
+    if question_type.startswith("mc_") or question_type == 'temporal_cot':
 
         if learn_neighbor_actions and meta_data:
             prefix = f"{perspective_prefix}\n"
