Commit 11e0e64

Author: Ye Shaokai
Commit message: updates
1 parent 21742a4, commit 11e0e64

File tree

3 files changed: +181 -102 lines

action/ek_eval.py

Lines changed: 58 additions & 85 deletions
@@ -85,10 +85,13 @@ def video_loader(root, vid, ext, second, end_second,
         except decord.DECORDError as error:
             print(error)
             frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
-
+
         return torch.from_numpy(frames.astype(np.float32))

     else:
+        time_meta = {}
+
+        time_meta['duration'] = end_second - second
         chunk_start = int(second) // chunk_len * chunk_len
         chunk_end = int(end_second) // chunk_len * chunk_len
         while True:
@@ -109,6 +112,7 @@ def video_loader(root, vid, ext, second, end_second,
             num_segments=clip_length, jitter=jitter
         )
         all_frames = []
+        all_frame_ids = []
         # allocate absolute frame-ids into the relative ones
         for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
             rel_frame_ids = list(filter(lambda x: int(chunk * fps) <= x < int((chunk + chunk_len) * fps), frame_ids))
@@ -127,11 +131,17 @@ def video_loader(root, vid, ext, second, end_second,
             except IndexError:
                 print(root, vid, ext, second, end_second)
             all_frames.append(frames)
+            all_frame_ids.append(frame_ids)
             if sum(map(lambda x: x.shape[0], all_frames)) == clip_length:
                 break
         res = torch.from_numpy(np.concatenate(all_frames, axis=0).astype(np.float32))
+        time_meta['n_frames'] = res.shape[0]
+        all_frame_ids = np.concatenate(all_frame_ids, axis = 0)
+        frame_time = [e/fps for e in all_frame_ids]
+        frame_time = [f"{i:.2f}s" for i in frame_time]
+        time_meta['frame_time'] = frame_time
         assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids)
-        return res
+        return res, time_meta


 class VideoCaptionDatasetBase(torch.utils.data.Dataset):
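The hunks above extend video_loader so that, alongside the frame tensor, it returns a time_meta dict holding the clip duration, the number of sampled frames, and each frame's timestamp as a formatted string. A minimal sketch of what that dict looks like (the concrete values below are made up for illustration, not taken from the repo):

    # Illustrative only: shape of the time_meta dict now returned by video_loader,
    # shown for a hypothetical 2-second clip with clip_length=4.
    time_meta = {
        'duration': 2.0,                                     # end_second - second
        'n_frames': 4,                                       # res.shape[0], equals clip_length
        'frame_time': ['0.00s', '0.50s', '1.00s', '1.50s'],  # absolute frame id / fps, formatted
    }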
@@ -194,53 +204,11 @@ def get_raw_item(
         fast_rrc=False, rrc_params=(224, (0.5, 1.0)),
         fast_rcc=False, rcc_params=(224,),
     ):
-        if self.dataset == 'ego4d':
-            vid, start_second, end_second, narration = self.samples[i][:4]
-            frames = video_loader(self.root, vid, 'mp4',
-                                  start_second, end_second,
-                                  chunk_len=chunk_len,
-                                  clip_length=clip_length,
-                                  threads=threads,
-                                  fast_rrc=fast_rrc,
-                                  rrc_params=rrc_params,
-                                  fast_rcc=fast_rcc,
-                                  rcc_params=rcc_params,
-                                  jitter=is_training)
-            if isinstance(narration, list):
-                if narration_selection == 'random':
-                    narration = random.choice(narration)
-                elif narration_selection == 'concat':
-                    narration = '. '.join(narration)
-                elif narration_selection == 'list':
-                    pass
-                else:
-                    raise ValueError
-            return frames, narration
-        elif self.dataset == 'ek100_mir':
-            vid_path, start_second, end_second, fps, narration, verb, noun = self.samples[i]
-            frames = video_loader(self.root, vid_path, 'MP4',
-                                  start_second, end_second,
-                                  chunk_len=chunk_len, fps=fps,
-                                  clip_length=clip_length,
-                                  threads=threads,
-                                  fast_rrc=fast_rrc,
-                                  rrc_params=rrc_params,
-                                  fast_rcc=fast_rcc,
-                                  rcc_params=rcc_params,
-                                  jitter=is_training)
-            if is_training:
-                positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
-                if positive_list != []:
-                    pos = random.sample(positive_list, min(len(positive_list), 1))[0]
-                    if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
-                        return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
-            else:
-                return frames, (narration, 1)
-        elif self.dataset == 'ek100_cls':
+        if self.dataset == 'ek100_cls':
             vid_path, start_second, end_second, fps, narration, verb, noun = self.samples[i]
             # chunk length is the chunked video clip length
             # clip length is number of frames we want to sample from the clip
-            frames = video_loader(self.root, vid_path, 'MP4',
+            frames, time_meta = video_loader(self.root, vid_path, 'MP4',
                                   start_second, end_second,
                                   chunk_len=chunk_len, fps=fps,
                                   clip_length=clip_length,
@@ -250,7 +218,7 @@ def get_raw_item(
                                   fast_rcc=fast_rcc,
                                   rcc_params=rcc_params,
                                   jitter=is_training)
-            return frames, '{}:{}'.format(verb, noun)
+            return frames, '{}:{}'.format(verb, noun), time_meta
         else:
             raise NotImplementedError

@@ -303,7 +271,7 @@ def __init__(
         self.mc_generator = MultiChoiceGenerator(self.ann_root)

     def __getitem__(self, i):
-        frames, label = self.get_raw_item(
+        frames, label, time_meta = self.get_raw_item(
             i, is_training=self.is_training,
             chunk_len=self.chunk_len,
             num_clips=self.num_clips,
@@ -317,13 +285,15 @@ def __getitem__(self, i):
             sparse_sample=self.sparse_sample,
         )

+        # for llava-video to work, we also need time meta data.
+
         # apply transformation
         if self.transform is not None:
             frames = self.transform(frames)

         data = self.mc_generator.generate_multi_choice(label, self.topk_predictions)

-        return frames, data
+        return frames, data, time_meta



@@ -350,7 +320,7 @@ def get_args_parser():

     # llava related
     # llm size is type of string and can only be '7b' or '5b' etc.
-    parser.add_argument('--llm_size', default='7b', type=str, help='llm size')
+    parser.add_argument('--pretrained_name', default = '', type = str, help ='the name in huggingface')
     parser.add_argument('--llava_num_frames', default=16, type=int, help='number of frames for llava')
     ## avaion refinement
     parser.add_argument('--action_predictions', default=None, type=str, help='path to action predictions')
@@ -362,13 +332,15 @@ def get_args_parser():
 def prepare_llava(pretrained):

     import warnings
-    from llava.model.builder import load_pretrained_model
     warnings.filterwarnings("ignore")
-    # Load the OneVision model
+    from llava.model.builder import load_pretrained_model
     model_name = "llava_qwen"

     device_map = "auto"
-    tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
+    print ('pretrained???', pretrained)
+    #tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
+    tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map) # Add any other thing you want to pass in llava_model_args
+

     return tokenizer, model, image_processor, max_length

@@ -392,7 +364,9 @@ def get_topk_predictions(data, idx, k):

     return mc_data

-def ensemble_llava_evaluation(gt_name,
+def ensemble_llava_evaluation(
+                              pretrained_name,
+                              gt_name,
                               frames,
                               tokenizer,
                               model,
@@ -402,6 +376,7 @@ def ensemble_llava_evaluation(gt_name,
                              num_frames,
                              temperature = 0,
                              ensemble_k = 1,
+                              time_meta = None,
                              is_test = False
                              ):
     """
@@ -424,20 +399,24 @@ def ensemble_llava_evaluation(gt_name,
            rank0_print ('generated new option sequence')
            rank0_print (options)

-        pred = llava_inference(frames,
-                               tokenizer,
-                               model,
-                               image_processor,
-                               mc_data,
-                               clip_length = clip_length,
-                               num_frames=num_frames,
-                               temperature = temperature,
-                               is_test = is_test
+        pred = llava_inference(
+                               pretrained_name,
+                               frames,
+                               tokenizer,
+                               model,
+                               image_processor,
+                               mc_data,
+                               clip_length = clip_length,
+                               num_frames=num_frames,
+                               temperature = temperature,
+                               is_test = is_test,
+                               time_meta = time_meta
                               )

        rank0_print ('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
-        sep = pred.index('.')
-        pred = pred[sep+1:].strip()
+        if '.' in pred:
+            sep = pred.index('.')
+            pred = pred[sep+1:].strip()
        preds.append(pred)

    counter = Counter(preds)
@@ -482,14 +461,9 @@ def evaluate_on_EK100(eval_args,

    running_corrects = 0
    total_samples = 0
-
-    if not eval_args.action_predictions:
-        log_filename = f'llava_ov_{eval_args.llava_num_frames}f_{eval_args.llm_size}.log'
-    else:
-        log_filename = f'llava_ov_{eval_args.llava_num_frames}f_{eval_args.llm_size}_action_{eval_args.topk_predictions}.log'
-
+
    # Set up logging
-    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename=log_filename, filemode='w')
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filemode='w')

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
@@ -502,23 +476,24 @@ def evaluate_on_EK100(eval_args,

    logger = logging.getLogger(__name__)

-    pretrained = f"lmms-lab/llava-onevision-qwen2-{eval_args.llm_size}-ov"
+    pretrained = f"lmms-lab/{eval_args.pretrained_name}".strip()
+    print ('pretrained', pretrained)

    # so we know it's evaluation during training
    finish_early = model is not None

    if model is None:
-        if hasattr(eval_args, "llava_checkpoint"):
+        if args.llava_checkpoint is not None:
            pretrained = eval_args.llava_checkpoint
-        tokenizer, model, image_processor, max_length = prepare_llava(pretrained)
+        tokenizer, model, image_processor, _ = prepare_llava(pretrained)

    if eval_args.action_predictions:
        with open(eval_args.action_predictions, 'r') as f:
            predictions = json.load(f)

    avaion_correct = 0

-    for idx, (frames, mc_data) in tqdm(enumerate(val_dataloader)):
+    for idx, (frames, mc_data, time_meta) in tqdm(enumerate(val_dataloader)):

        gt_name = mc_data['gt_answer_name'][0][0]

@@ -531,24 +506,22 @@ def evaluate_on_EK100(eval_args,
        # we don't want to evaluate the whole thing
        # let's evaluate 1000 samples to get the complete picture
        if finish_early and idx>999:
-            break
-
-        # pred = llava_inference(frames, tokenizer, model, image_processor, mc_data, clip_length = eval_args.clip_length, num_frames=eval_args.llava_num_frames)
-
-        # # if valid letter is found in the prediction, then we will use that as the prediction
-        # rank0_print ('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
+            break

        # Update running corrects and total samples
-        running_corrects += ensemble_llava_evaluation(gt_name,
+        running_corrects += ensemble_llava_evaluation(
+                                                      eval_args.pretrained_name,
+                                                      gt_name,
                                                      frames,
                                                      tokenizer,
                                                      model,
                                                      image_processor,
                                                      mc_data,
                                                      eval_args.clip_length,
                                                      eval_args.llava_num_frames,
-                                                      temperature = 2.0,
-                                                      ensemble_k = 5,
+                                                      temperature = 0,
+                                                      ensemble_k = 1,
+                                                      time_meta = time_meta,
                                                      is_test = not finish_early)

        total_samples += 1
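The commit also replaces the --llm_size flag with --pretrained_name: instead of assembling lmms-lab/llava-onevision-qwen2-{llm_size}-ov, evaluate_on_EK100 now builds the Hugging Face model id as lmms-lab/{pretrained_name}. A small sketch of the mapping, assuming the old default of --llm_size 7b (the rest of the launch command is not part of this diff):

    # Illustrative only: passing --pretrained_name llava-onevision-qwen2-7b-ov
    # reproduces the repo id that the removed --llm_size flag used to build.
    pretrained_name = 'llava-onevision-qwen2-7b-ov'   # hypothetical flag value
    pretrained = f"lmms-lab/{pretrained_name}".strip()
    assert pretrained == "lmms-lab/llava-onevision-qwen2-7b-ov"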
