Commit ff72c19

Author: Ye Shaokai
able to do distributed evaluation
1 parent f3b6f6b commit ff72c19

File tree

4 files changed (+46, -25 lines)

action/ek_eval.py

Lines changed: 37 additions & 15 deletions
@@ -8,7 +8,7 @@
 import torch
 import argparse
 import decord
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
 from pathlib import Path
 import sys
@@ -21,6 +21,9 @@
 from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions
 import copy
 from collections import Counter
+import torch.distributed as dist
+
+dist.init_process_group(backend='nccl')

 def datetime2sec(str):
     hh, mm, ss = str.split(':')
@@ -138,7 +141,8 @@ def video_loader(root, vid, ext, second, end_second,
     time_meta['n_frames'] = res.shape[0]
     all_frame_ids = np.concatenate(all_frame_ids, axis = 0)
     frame_time = [e/fps for e in all_frame_ids]
-    frame_time = [f"{i:.2f}s" for i in frame_time]
+    frame_time = [t - frame_time[0] for t in frame_time]
+    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
     time_meta['frame_time'] = frame_time
     assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids)
     return res, time_meta
@@ -342,7 +346,7 @@ def prepare_llava(pretrained):
     if 'video' in pretrained:
         overwrite_config = {'tie_word_embeddings': False, 'use_cache': True, "vocab_size": 152064}

-    print ('overwrite config', overwrite_config)
+
     tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained,
                                                                           None,
                                                                           model_name,
@@ -404,9 +408,7 @@ def ensemble_llava_evaluation(
         random.shuffle(options)
         for idx, (option, letter) in enumerate(zip(options, letters)):
             sep = option.index('.')
-            options[idx] = f'{letter}.{option[sep+1:]}'
-        rank0_print ('generated new option sequence')
-        rank0_print (options)
+            options[idx] = f'{letter}.{option[sep+1:]}'

         pred = llava_inference(
             pretrained_name,
@@ -466,10 +468,12 @@ def evaluate_on_EK100(eval_args,
     )

-    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
+    if dist.is_initialized():
+        sampler = DistributedSampler(val_dataset, shuffle=False)
+    else:
+        sampler = None

-    running_corrects = 0
-    total_samples = 0
+    val_dataloader = DataLoader(val_dataset, sampler=sampler, batch_size=1, shuffle=False)

     # Set up logging
     logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filemode='w')
@@ -500,8 +504,12 @@ def evaluate_on_EK100(eval_args,
         with open(eval_args.action_predictions, 'r') as f:
             predictions = json.load(f)

-        avaion_correct = 0
-
+
+
+    avaion_correct = torch.tensor(0, device='cuda')
+    running_corrects = torch.tensor(0, device='cuda')
+    total_samples = torch.tensor(0, device='cuda')
+
     for idx, (frames, mc_data, time_meta) in tqdm(enumerate(val_dataloader)):

         gt_name = mc_data['gt_answer_name'][0][0]
@@ -538,14 +546,28 @@ def evaluate_on_EK100(eval_args,
         # Calculate and log running mean accuracy
         running_accuracy = running_corrects / total_samples

-        logger.info(f'running accuracy: {running_accuracy:.4f}')
+        logger.info(f'Process {dist.get_rank()} - Running accuracy: {running_accuracy:.4f}')
         if eval_args.action_predictions:
             avaion_accuracy = avaion_correct / total_samples


-    logger.info(f'Running avaion accuracy after {total_samples} samples: {avaion_accuracy:.4f}')
-    logger.info(f'Final accuracy: {running_accuracy:.4f}')
-    return running_accuracy
+    dist.all_reduce(running_corrects, op=dist.ReduceOp.SUM)
+    dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)
+    if eval_args.action_predictions:
+        dist.all_reduce(avaion_correct, op=dist.ReduceOp.SUM)
+
+    # Calculate global accuracy after reduction
+    global_accuracy = running_corrects.item() / total_samples.item()
+    if eval_args.action_predictions:
+        global_avaion_accuracy = avaion_correct.item() / total_samples.item()
+
+    # Ensure only the main process (rank 0) prints the final result
+    if dist.get_rank() == 0:
+        if eval_args.action_predictions:
+            logger.info(f'Global Avaion Accuracy: {global_avaion_accuracy:.4f}')
+        logger.info(f'Final Global Accuracy: {global_accuracy:.4f}')
+
+    return global_accuracy


 if __name__ == '__main__':
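
Taken together, the ek_eval.py changes follow the standard PyTorch recipe for multi-GPU evaluation: shard the dataset with DistributedSampler, keep counters as CUDA tensors, all_reduce them at the end, and report from rank 0. Below is a minimal self-contained sketch of that recipe, not the repository's code: the toy TensorDataset and the random "prediction" stand in for the real dataset and model call, and the script assumes it is launched with torchrun --nproc_per_node=N, which sets the environment variables init_process_group reads.

import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def main():
    # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK and MASTER_ADDR;
    # init_process_group picks them up from the environment.
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))

    # Toy stand-in for val_dataset: 100 (feature, label) pairs.
    dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

    # shuffle=False gives each rank a fixed, disjoint shard for evaluation.
    sampler = DistributedSampler(dataset, shuffle=False)
    loader = DataLoader(dataset, sampler=sampler, batch_size=1)

    # Counters live on the GPU so all_reduce can sum them across ranks.
    corrects = torch.tensor(0, device='cuda')
    total = torch.tensor(0, device='cuda')

    for x, y in loader:
        pred = torch.randint(0, 2, y.shape, device='cuda')  # hypothetical model call
        corrects += (pred == y.cuda()).sum()
        total += y.numel()

    # After reduction every rank holds the global sums.
    dist.all_reduce(corrects, op=dist.ReduceOp.SUM)
    dist.all_reduce(total, op=dist.ReduceOp.SUM)

    if dist.get_rank() == 0:
        print(f'global accuracy: {corrects.item() / total.item():.4f}')
    dist.destroy_process_group()


if __name__ == '__main__':
    main()

One caveat: DistributedSampler pads shards to equal length by repeating samples, so the reduced total can count a handful of items twice; constructing the sampler with drop_last=True (or sharding manually) avoids that when exact counts matter.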

action/llava_ov_inference.py

Lines changed: 0 additions & 8 deletions
@@ -83,14 +83,6 @@ def llava_video_process(
     n_frames = time_meta['n_frames'].item()
     frame_time = time_meta['frame_time']
     frame_time = [e[0] for e in frame_time]
-
-    print ("what is meta")
-    print ('n_frame', n_frames)
-    print ('true video frames', len(video_frames))
-    print ('frame_time', frame_time)
-    print ('video_duration', video_duration)
-    print ('is_test', is_test)
-
     time_instruciton = f"The video lasts for {video_duration:.2f} seconds, and {n_frames} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video."

     frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].cuda().to(torch.bfloat16)
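
For orientation, the time_instruciton f-string that remains after the debug prints are dropped interpolates three values: the clip duration, the frame count, and the comma-joined timestamps. A tiny sketch with made-up values, showing the prompt it assembles:

video_duration = 12.40
n_frames = 4
frame_time = "0.00s,4.13s,8.27s,12.40s"  # made-up timestamps in the upstream format

time_instruciton = (
    f"The video lasts for {video_duration:.2f} seconds, and {n_frames} frames are "
    f"uniformly sampled from it. These frames are located at {frame_time}."
    f"Please answer the following questions related to this video."
)
print(time_instruciton)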

llava/train/train.py

Lines changed: 1 addition & 1 deletion
@@ -1202,7 +1202,7 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
         total_frames = len(frame_files)
         sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)

-
+
         frame_time = [i/2 for i in sampled_indices]
         frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
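
The surrounding _get_item logic selects num_frames_to_sample evenly spaced frames via np.linspace and turns indices into seconds; the i/2 suggests the frames on disk were dumped at 2 fps, which is an assumption here rather than something the diff states. A standalone sketch with toy numbers:

import numpy as np

total_frames = 20          # frames available on disk for this clip
num_frames_to_sample = 6
fps_of_extraction = 2      # assumption: the i/2 in train.py implies 2 fps frame dumps

# Evenly spaced indices over the available frames, endpoints included.
sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)

# Convert indices to seconds and join into the "0.00s,..." string the prompt expects.
frame_time = [i / fps_of_extraction for i in sampled_indices]
frame_time = ",".join([f"{i:.2f}s" for i in frame_time])

print(sampled_indices)  # [ 0  3  7 11 15 19]
print(frame_time)       # 0.00s,1.50s,3.50s,5.50s,7.50s,9.50s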

llava/utils.py

Lines changed: 8 additions & 1 deletion
@@ -146,23 +146,30 @@ def process_EK100_video_with_decord(video_file, data_args, start_second, end_sec

     # calculate frame_ids
     frame_ids = get_frame_ids(start_frame, end_frame, num_segments=data_args.frames_upbound, jitter=False)
-    frame_time = [i/fps for i in frame_ids]
+
+

     all_frames = []
+    all_frame_ids = []
     # allocate absolute frame-ids into the relative ones
     for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
         rel_frame_ids = list(filter(lambda x: int(chunk * fps) <= x < int((chunk + chunk_len) * fps), frame_ids))
         rel_frame_ids = [int(frame_id - chunk * fps) for frame_id in rel_frame_ids]
         vr = VideoReader(os.path.join(video_file, '{}.MP4'.format(chunk)), ctx=cpu(0), num_threads=1)
         frames = vr.get_batch(rel_frame_ids).asnumpy()
         all_frames.append(frames)
+        all_frame_ids.append([i for i in frame_ids if int(chunk * fps) <= i < int((chunk + chunk_len) * fps)])  # only the ids that fall in this chunk
         vr.seek(0)
         if sum(map(lambda x: x.shape[0], all_frames)) == data_args.frames_upbound:
             break

     video = np.concatenate(all_frames, axis=0).astype(np.float32)

+    all_frame_ids = np.concatenate(all_frame_ids, axis = 0)
+    frame_time = [e/fps for e in all_frame_ids]
+    frame_time = [t - frame_time[0] for t in frame_time]
     frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+
     num_frames_to_sample = len(frame_ids)

     return video, video_time, frame_time, num_frames_to_sample
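
The loop this hunk extends reads an EK100 video stored as consecutive fixed-length MP4 chunks (named '{}.MP4'.format(chunk) by starting second) and must rebase each absolute frame id onto the chunk file that contains it. A self-contained sketch of just that bookkeeping, with a hypothetical helper name and toy numbers:

import numpy as np


def split_ids_into_chunks(frame_ids, fps, chunk_start, chunk_end, chunk_len):
    """Map absolute frame ids onto per-chunk relative ids, mirroring the loop above.

    Assumes the video is stored as consecutive MP4 files, each covering
    chunk_len seconds and named by its starting second (e.g. 0.MP4, 15.MP4).
    Returns (chunk_second, relative_ids) pairs, skipping empty chunks.
    """
    out = []
    for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
        lo, hi = int(chunk * fps), int((chunk + chunk_len) * fps)
        abs_ids = [i for i in frame_ids if lo <= i < hi]   # ids that land in this chunk
        rel_ids = [int(i - chunk * fps) for i in abs_ids]  # rebase onto the chunk file
        if rel_ids:
            out.append((chunk, rel_ids))
    return out


# Toy check: 30 fps, 15 s chunks, ids straddling the first chunk boundary.
print(split_ids_into_chunks([440, 445, 450, 455], fps=30, chunk_start=0, chunk_end=15, chunk_len=15))
# [(0, [440, 445]), (15, [0, 5])]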
