
Commit 6ff8a9f

merge conflict
2 parents 34a8659 + 72db420

File tree: 4 files changed, +36 -12 lines

  action/chatgpt_utils.py         +29 -7
  action/ek_eval.py                +0 -2
  action/llava_ov_inference.py     +6 -2
  llava/train/train.py             +1 -1

action/chatgpt_utils.py

Lines changed: 29 additions & 7 deletions

@@ -150,14 +150,21 @@ class GPTInferenceAnnotator(ChatGPT):
     Given the images, this class will annotate the video frames
     """

-    def __init__(self, root, prediction_save_folder, clip_length = 4, debug = False):
+    def __init__(self,
+                 root,
+                 prediction_save_folder,
+                 clip_length = 4,
+                 debug = False,
+                 topk = 10
+                 ):
         super().__init__(clip_length = clip_length)
         self.root = root
         self.prediction_save_folder = prediction_save_folder
         self.prediction_analysis = PredictionAnalysis(self.prediction_save_folder)
         self.prediction_analysis.load()
         self.data = self.prediction_analysis.data
         self.debug = debug
+        self.topk = topk

     def multi_process_run(self):
         prediction_analysis = PredictionAnalysis(self.prediction_save_folder)
@@ -187,7 +194,14 @@ def multi_process_run(self):
     def parse_item(self, item):

         gt_name = item['gt_name']
-        avion_predictions = item['avion_preds']['predictions']
+        avion_predictions = item['avion_preds']['predictions']
+        assert self.topk <= len(avion_predictions)
+        avion_predictions = avion_predictions[:self.topk]
+        # _avion_predictions = [e.replace(':', ' ', 1) for e in avion_predictions]
+        # if gt_name not in _avion_predictions:
+        #     print ('gt_name not in avion_predictions')
+        # else:
+        #     print ('gt_name in avion_predictions')

         vid_path = item['vid_path'][0]
         start_second = item['start_second']
@@ -453,12 +467,17 @@ def explore_wrong_examples(root, prediction_save_folder, debug = False):
                                       debug = debug)
     annotator.explore_wrong_examples()

-def multi_process_inference(root, prediction_save_folder, debug = False):
+def multi_process_inference(root,
+                            prediction_save_folder,
+                            clip_length = 4,
+                            topk = 10,
+                            debug = False):

     annotator = GPTInferenceAnnotator(root,
                                       prediction_save_folder,
-                                      clip_length = 32,
-                                      debug = debug)
+                                      clip_length = clip_length,
+                                      debug = debug,
+                                      topk = topk)

     annotator.multi_process_run()

@@ -488,5 +507,8 @@ def calculate_gpt_accuracy(path):

     #multi_process_annotate(train_file_path, root)
     #explore_wrong_examples(root, pred_folder)
-    multi_process_inference(root, pred_folder, debug = True)
-    #calculate_gpt_accuracy('valset_chatgpt_inference_results/gpt-4o-avion_top10_4frames.json')
+    multi_process_inference(root,
+                            pred_folder,
+                            debug = False,
+                            clip_length = 4,
+                            topk = 5)
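The change threads a new topk argument from the entry point down to GPTInferenceAnnotator.parse_item, where the ranked AVION predictions are truncated before being handed to GPT. A minimal standalone sketch of that truncation, with a hypothetical item dict (the key layout follows the diff; the sample predictions are invented for illustration):

# Hypothetical data mirroring the item structure used in parse_item.
item = {
    'gt_name': 'move:knife',
    'avion_preds': {
        'predictions': ['move:knife', 'cut:onion', 'open:drawer',
                        'wash:hand', 'pick:plate', 'close:fridge',
                        'stir:pan', 'take:spoon', 'put:cup', 'turn:tap'],
    },
}

topk = 5
avion_predictions = item['avion_preds']['predictions']
assert topk <= len(avion_predictions)         # same guard as in the diff
avion_predictions = avion_predictions[:topk]  # keep only the top-k candidates
print(avion_predictions)
# ['move:knife', 'cut:onion', 'open:drawer', 'wash:hand', 'pick:plate']

Lowering topk (the entry point now passes topk = 5) shrinks the candidate list GPT must choose from, which shortens the prompt at the risk of dropping the ground-truth action from the options; the commented-out block in parse_item checks exactly that case.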

action/ek_eval.py

Lines changed: 0 additions & 2 deletions

@@ -465,9 +465,7 @@ def evaluate_on_EK100(eval_args,
     global_total_samples.add_(1)

     logger.info(f'Process {dist.get_rank()} - local_total_samples: {local_total_samples:.4f}')
-
     logger.info(f'Process {dist.get_rank()} - loca_llava_correct: {llava_correct:.4f}')
-
     logger.info(f'Process {dist.get_rank()} - local_running_corrects: {local_running_corrects:.4f}')


action/llava_ov_inference.py

Lines changed: 6 additions & 2 deletions

@@ -83,8 +83,9 @@ def llava_video_process(
     video_duration = time_meta['duration'].item()
     n_frames = time_meta['n_frames'].item()
     frame_time = time_meta['frame_time']
-    frame_time = [e[0] for e in frame_time]
-    time_instruciton = f"The video lasts for {video_duration:.2f} seconds, and {n_frames} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video."
+    print ('frame time', frame_time)
+    frame_time = frame_time[0]
+    time_instruciton = f"You are seeing a video taken from egocentric view. The video lasts for {video_duration:.2f} seconds, and {n_frames} frames are uniformly sampled from it. What is the person doing? Format your answer letter. verb noun such as A. move knife."

     frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].cuda().to(torch.bfloat16)

@@ -97,12 +98,15 @@ def llava_video_process(

     question = DEFAULT_IMAGE_TOKEN + f"{time_instruciton}\n:{options}"

+    print ('what is the question')
+    print (question)

     conv = copy.deepcopy(conv_templates[conv_template])
     conv.append_message(conv.roles[0], question)
     conv.append_message(conv.roles[1], None)
     prompt_question = conv.get_prompt()

+
     input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
     image_sizes = [frame.size for frame in video_frames]

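The new prompt turns the generic time instruction into an egocentric multiple-choice question. A self-contained sketch of how the final question string comes together, with hypothetical metadata and options (DEFAULT_IMAGE_TOKEN is the standard "<image>" placeholder from the LLaVA codebase; only the f-strings are taken from the diff):

# Hypothetical values standing in for time_meta and the candidate options.
DEFAULT_IMAGE_TOKEN = "<image>"
video_duration = 8.53
n_frames = 4
options = "A. move knife\nB. cut onion\nC. open drawer"

# Variable name kept as spelled in the source ('time_instruciton').
time_instruciton = (
    f"You are seeing a video taken from egocentric view. "
    f"The video lasts for {video_duration:.2f} seconds, and {n_frames} "
    f"frames are uniformly sampled from it. What is the person doing? "
    f"Format your answer letter. verb noun such as A. move knife."
)

question = DEFAULT_IMAGE_TOKEN + f"{time_instruciton}\n:{options}"
print(question)  # this is what the added debug prints emit at runtime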

llava/train/train.py

Lines changed: 1 addition & 1 deletion

@@ -1232,7 +1232,7 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]:
         processor = self.data_args.image_processor
         image = processor.preprocess(video, return_tensors="pt")["pixel_values"]
         if self.data_args.add_time_instruction:
-            time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {num_frames_to_sample} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video."
+            time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {num_frames_to_sample} frames are uniformly sampled from it. Please answer the following questions related to this video."
             sources[0]["conversations"][0]["value"] = f'{DEFAULT_IMAGE_TOKEN}\n{time_instruciton}\n{sources[0]["conversations"][0]["value"].replace(DEFAULT_IMAGE_TOKEN, "")}'
         image = [(image, video[0].size, "video")]
         sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
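On the training side the only change is dropping the per-frame timestamp list from the instruction. A quick before/after rendering with hypothetical values shows what the model now sees:

video_time = 8.53              # hypothetical clip duration in seconds
num_frames_to_sample = 4
frame_time = "0.00s, 2.13s, 4.27s, 6.40s"  # hypothetical timestamps

old = (f"The video lasts for {video_time:.2f} seconds, and "
       f"{num_frames_to_sample} frames are uniformly sampled from it. "
       f"These frames are located at {frame_time}."
       f"Please answer the following questions related to this video.")

new = (f"The video lasts for {video_time:.2f} seconds, and "
       f"{num_frames_to_sample} frames are uniformly sampled from it. "
       f"Please answer the following questions related to this video.")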
