Skip to content

Commit c5fe6d3

Browse files
author
Ye Shaokai
committed
fixed evaluation
1 parent a6fa160 commit c5fe6d3

File tree

6 files changed

+261
-78
lines changed

6 files changed

+261
-78
lines changed

action/chatgpt_utils.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import base64
2+
import io
3+
import json
4+
import os
5+
import cv2
6+
import numpy as np
7+
import openai
8+
from pydantic import BaseModel
9+
from multiprocessing.pool import Pool
10+
11+
# Shared OpenAI client for this module. The API key is read from the
# OPENAI_API_KEY environment variable; os.environ.get yields None if unset.
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Model identifier sent with the chat-completion request issued by
# GPTAnnotator.annotate_with_multichoice.
GPT_MODEL = "gpt-4o-2024-08-06"
14+
15+
16+
class ImageOnlyResponse(BaseModel):
    """
    Structured response holding a single free-text explanation.
    """
    # NOTE(review): nothing in the visible code uses this model yet;
    # presumably intended for GPTAnnotator.annotate -- confirm.
    explanation: str
20+
21+
class MultiChoiceResponse(BaseModel):
    """
    The output format of the response
    """

    # Passed as ``response_format`` in GPTAnnotator.annotate_with_multichoice,
    # so the parsed reply exposes exactly this one text field.
    explanation: str
27+
28+
29+
30+
class GPTAnnotator:
    """Query an OpenAI chat model about images and a multi-choice question.

    The JSON prediction file given at construction time is parsed once and
    kept on ``self.prediction_file``.
    """

    def __init__(self, prediction_file_path):
        """Load the JSON document at ``prediction_file_path``.

        Args:
            prediction_file_path: path of a JSON file; its parsed content is
                stored on ``self.prediction_file``.
        """
        with open(prediction_file_path, 'r') as f:
            self.prediction_file = json.load(f)

    def prepare_multiple_images(self, images):
        """Turn images into ``image_url`` message parts with base64 data URLs.

        Args:
            images: iterable whose items are either ``io.BytesIO`` objects
                (their bytes are used unchanged) or ``np.ndarray`` images
                (JPEG-encoded through ``cv2.imencode``).

        Returns:
            A list with one ``{"type": "image_url", "image_url": {"url": ...}}``
            dict per input image, in input order.

        Raises:
            TypeError: if an item is neither ``io.BytesIO`` nor ``np.ndarray``.
            ValueError: if ``cv2.imencode`` reports that encoding failed.
        """
        encoded_image_list = []

        for image in images:
            if isinstance(image, io.BytesIO):
                # images from matplotlib etc.
                # NOTE(review): the payload is labelled image/jpeg below
                # without inspecting its real format -- confirm callers
                # only pass JPEG data.
                raw_bytes = image.getvalue()
            elif isinstance(image, np.ndarray):
                # images from opencv
                ok, buffer = cv2.imencode(".jpeg", image)
                if not ok:
                    raise ValueError("cv2.imencode failed to encode image as JPEG")
                raw_bytes = buffer.tobytes()
            else:
                # Previously an unsupported item either raised NameError or
                # silently re-used the previous image's encoding.
                raise TypeError(
                    "unsupported image type: {}; expected io.BytesIO or "
                    "np.ndarray".format(type(image).__name__)
                )

            encoded_image_list.append(base64.b64encode(raw_bytes).decode("utf-8"))

        multi_image_content = [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
            }
            for encoded_image in encoded_image_list
        ]

        return multi_image_content

    def annotate(self, images):
        """
        Annotate to do image caption only

        Not implemented yet: the body is empty and the call returns None.
        """
        pass

    def annotate_with_multichoice(self, images, mc_data):
        """Ask the model to explain the answer of a multi-choice question.

        Args:
            images: images accepted by ``prepare_multiple_images``.
            mc_data: multi-choice data for the sample.
                NOTE(review): currently unused -- the prompt hard-codes
                answer "D" and the user text is empty; confirm this is a
                work-in-progress placeholder.

        Returns:
            The parsed message of the first choice, produced with
            ``MultiChoiceResponse`` as the response format.
        """
        temperature = 0
        include_images = True

        system_prompt_prefix = """Inspect the images from the video and explain why the answer of the multi-choice question is D. """
        system_prompt_suffix = """Yes"""

        system_prompt = system_prompt_prefix + system_prompt_suffix

        system_message = [{"role": "system", "content": system_prompt}]

        if include_images:
            multi_image_content = self.prepare_multiple_images(images)
            # An (empty) text part comes first, followed by the image parts.
            multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content
            user_message = [{"role": "user", "content": multi_modal_content}]
        else:
            user_message = [{"role": "user", "content": ""}]

        response = client.beta.chat.completions.parse(
            model=GPT_MODEL,
            messages=system_message + user_message,
            response_format=MultiChoiceResponse,
            temperature=temperature,
        )

        return response.choices[0].message.parsed
106+
107+
108+
def annotate_using_chatgpt():
    """
    Multi processing to speed up

    Placeholder: a worker pool is opened and closed again without any task
    being submitted, so calling this currently performs no annotation.
    """
    with Pool() as pool:
        pass
        # TODO: dispatch the annotation tasks across the pool, e.g.
        # pool.starmap(annotate, task_args)

    pass
117+
118+
def annotate_from_train_conv_file(train_file_path):
    """Placeholder: accepts ``train_file_path`` but does nothing with it yet."""
    pass
120+
121+
if __name__ == '__main__':
    # NOTE(review): hard-coded, machine-specific path; the callee is still a
    # stub, so running this module as a script currently has no effect.
    train_file_path = '/storage-rcp-pure/upmwmathis_scratch/shaokai'
    annotate_from_train_conv_file(train_file_path)

action/ek_eval.py

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -525,67 +525,68 @@ def evaluate_on_EK100(eval_args,
525525

526526
for idx, (frames, mc_data, time_meta, global_index) in tqdm(enumerate(val_dataloader)):
527527

528-
global_index = global_index.item()
528+
with torch.no_grad():
529+
global_index = global_index.item()
529530

530-
gt_name = mc_data['gt_answer_name'][0][0]
531-
local_avion_correct = torch.tensor(0.0, device=device)
532-
local_running_corrects = torch.tensor(0.0, device=device)
533-
local_total_samples = torch.tensor(0.0, device=device)
534-
535-
if eval_args.action_predictions:
536-
mc_data = get_topk_predictions(predictions, global_index, eval_args.topk_predictions)
537-
avion_pred = mc_data['avion_pred']
538-
if gt_name == avion_pred:
539-
local_avion_correct.add_(1)
540-
global_avion_correct.add_(1)
541-
542-
# we don't want to evaluate the whole thing
543-
# let's evaluate 1000 samples to get the complete picture
544-
if finish_early and idx> (1000 / dist.get_world_size()):
545-
break
546-
547-
# Update running corrects and total samples
531+
gt_name = mc_data['gt_answer_name'][0][0]
532+
local_avion_correct = torch.tensor(0.0, device=device)
533+
local_running_corrects = torch.tensor(0.0, device=device)
534+
local_total_samples = torch.tensor(0.0, device=device)
535+
536+
if eval_args.action_predictions:
537+
mc_data = get_topk_predictions(predictions, global_index, eval_args.topk_predictions)
538+
avion_pred = mc_data['avion_pred']
539+
if gt_name == avion_pred:
540+
local_avion_correct.add_(1)
541+
global_avion_correct.add_(1)
542+
543+
# we don't want to evaluate the whole thing
544+
# let's evaluate 1000 samples to get the complete picture
545+
if finish_early and idx> (1000 / dist.get_world_size()):
546+
break
547+
548+
# Update running corrects and total samples
549+
550+
llava_correct, llava_pred = ensemble_llava_evaluation(
551+
eval_args.pretrained_name,
552+
gt_name,
553+
frames,
554+
tokenizer,
555+
model,
556+
image_processor,
557+
mc_data,
558+
eval_args.clip_length,
559+
eval_args.llava_num_frames,
560+
temperature = 0,
561+
ensemble_k = 1,
562+
time_meta = time_meta,
563+
is_test = not finish_early)
564+
565+
# log the predictions into prediction analysis
548566

549-
llava_correct, llava_pred = ensemble_llava_evaluation(
550-
eval_args.pretrained_name,
551-
gt_name,
552-
frames,
553-
tokenizer,
554-
model,
555-
image_processor,
556-
mc_data,
557-
eval_args.clip_length,
558-
eval_args.llava_num_frames,
559-
temperature = 0,
560-
ensemble_k = 1,
561-
time_meta = time_meta,
562-
is_test = not finish_early)
563-
564-
# log the predictions into prediciton analysis
565-
566-
val_dataset.prediction_analysis.log(global_index,
567-
llava_pred,
568-
gt_name,
569-
predictions[str(global_index)],
570-
time_meta['start_second'].item(),
571-
time_meta['end_second'].item(),
572-
time_meta['vid_path'],
573-
dataset_name = 'EK100')
567+
val_dataset.prediction_analysis.log(global_index,
568+
llava_pred,
569+
gt_name,
570+
predictions[str(global_index)],
571+
time_meta['start_second'].item(),
572+
time_meta['end_second'].item(),
573+
time_meta['vid_path'],
574+
dataset_name = 'EK100')
574575

575576

576577

577578

578-
local_running_corrects.add_(llava_correct)
579-
global_running_corrects.add_(llava_correct)
580-
581-
local_total_samples.add_(1)
582-
global_total_samples.add_(1)
579+
local_running_corrects.add_(llava_correct)
580+
global_running_corrects.add_(llava_correct)
581+
582+
local_total_samples.add_(1)
583+
global_total_samples.add_(1)
583584

584-
logger.info(f'Process {dist.get_rank()} - local_total_samples: {local_total_samples:.4f}')
585+
logger.info(f'Process {dist.get_rank()} - local_total_samples: {local_total_samples:.4f}')
585586

586-
logger.info(f'Process {dist.get_rank()} - loca_llava_correct: {llava_correct:.4f}')
587+
logger.info(f'Process {dist.get_rank()} - loca_llava_correct: {llava_correct:.4f}')
587588

588-
logger.info(f'Process {dist.get_rank()} - local_running_corrects: {local_running_corrects:.4f}')
589+
logger.info(f'Process {dist.get_rank()} - local_running_corrects: {local_running_corrects:.4f}')
589590

590591

591592
# Calculate and log running mean accuracy

action/llava_ov_inference.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,16 @@ def llava_ov_process(video_frames,
4343
image_sizes = [frame.size for frame in video_frames]
4444

4545
# Generate response
46-
cont = model.generate(
47-
input_ids,
48-
images=image_tensors,
49-
image_sizes=image_sizes,
50-
do_sample=False,
51-
temperature=temperature,
52-
max_new_tokens=4096,
53-
modalities=["video"],
54-
)
46+
with torch.no_grad():
47+
cont = model.generate(
48+
input_ids,
49+
images=image_tensors,
50+
image_sizes=image_sizes,
51+
do_sample=False,
52+
temperature=temperature,
53+
max_new_tokens=4096,
54+
modalities=["video"],
55+
)
5556

5657
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
5758
return text_outputs[0]

action/prediction_analysis.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import glob
33
import os
4+
import numpy as np
45
class PredictionAnalysis:
56
"""
67
We save data that can be used for ad-hoc analysis
@@ -24,7 +25,7 @@ def __init__(self, save_folder = '.', rank = 0):
2425
self.rank = rank
2526
self.prefix = 'prediction_analysis_buf'
2627
self.save_path = os.path.join(save_folder, f'{self.prefix}_rank{rank}.json')
27-
self.data = {}
28+
self.data = {}
2829
def log(self,
2930
global_index,
3031
llava_pred,
@@ -62,10 +63,11 @@ def load(self):
6263
with open(file, 'r') as f:
6364
_data = json.load(f)
6465
self.data.update(_data)
66+
print ('length', len(self.data))
67+
assert len(self.data) == 9668
68+
#print (sorted(list(self.data.keys()), key = lambda x: int(x)))
6569

66-
print (sorted(list(self.data.keys()), key = lambda x: int(x)))
67-
68-
def wrong_verb(self):
70+
def analysis(self):
6971

7072
N = len(self.data)
7173
llava_wrong_verb_collections = []
@@ -76,27 +78,83 @@ def wrong_verb(self):
7678
avion_wrong_noun_collections = []
7779
avion_wrong_verb_noun_collections = []
7880

79-
wrong_llava_collections = []
80-
wrong_avion_collections = []
81+
wrong_llava_collections = [0] * N
82+
wrong_avion_collections = [0] * N
8183

8284
indices = sorted(list(self.data.keys()), key = lambda x: int(x))
8385

84-
for index in indices:
86+
for idx, index in enumerate(indices):
8587
items = self.data[index]
8688
llava_pred = items['llava_pred']
8789
gt_name = items['gt_name']
8890
# only replacing the first :
8991
avion_pred = items['avion_preds']['predictions'][0].replace(':', ' ', 1)
9092

93+
llava_verb, llava_noun = llava_pred.split(' ')
94+
avion_verb, avion_noun = avion_pred.split(' ')
95+
gt_verb, gt_noun = gt_name.split(' ')
96+
9197
if llava_pred != gt_name:
92-
wrong_llava_collections.append((llava_pred, gt_name))
98+
if set(llava_pred).intersection(set(gt_name)) == set(gt_name):
99+
print ('what is going on')
100+
print ('nooo', llava_pred, gt_name)
101+
#wrong_llava_collections.append((llava_pred, gt_name))
102+
#print (llava_pred, gt_name)
103+
wrong_llava_collections[idx] = 0
104+
else:
105+
wrong_llava_collections[idx] = 1
93106
if avion_pred!= gt_name:
94-
# pred, gt
95-
wrong_avion_collections.append((avion_pred, gt_name))
107+
wrong_avion_collections[idx] = 0
108+
else:
109+
wrong_avion_collections[idx] = 1
110+
96111

112+
if llava_verb == gt_verb and llava_noun!=gt_noun:
113+
llava_wrong_noun_collections.append((llava_pred, gt_name))
114+
if llava_noun == gt_noun and llava_verb!=gt_verb:
115+
llava_wrong_verb_collections.append((llava_pred, gt_name))
116+
if llava_noun!= gt_noun and llava_verb!=gt_verb:
117+
llava_wrong_verb_noun_collections.append((llava_pred, gt_name))
118+
119+
if avion_verb == gt_verb and avion_noun!=gt_noun:
120+
avion_wrong_noun_collections.append((avion_pred, gt_name))
121+
if avion_noun == gt_noun and avion_verb!=gt_verb:
122+
avion_wrong_verb_collections.append((avion_pred, gt_name))
123+
if avion_noun!= gt_noun and avion_verb!=gt_verb:
124+
avion_wrong_verb_noun_collections.append((avion_pred, gt_name))
125+
126+
wrong_llava_collections = np.array(wrong_llava_collections)
127+
wrong_avion_collections = np.array(wrong_avion_collections)
128+
llava_wrong_noun_collections = np.array(llava_wrong_noun_collections)
129+
llava_wrong_verb_collections = np.array(llava_wrong_verb_collections)
130+
llava_wrong_verb_noun_collections = np.array(llava_wrong_verb_noun_collections)
131+
avion_wrong_noun_collections = np.array(avion_wrong_noun_collections)
132+
avion_wrong_verb_collections = np.array(avion_wrong_verb_collections)
133+
avion_wrong_verb_noun_collections = np.array(avion_wrong_verb_noun_collections)
134+
135+
# first, the correlation between avion and llava
136+
correlation = np.corrcoef(wrong_llava_collections, wrong_avion_collections)[0, 1]
137+
138+
print("Correlation:", correlation)
139+
140+
print ('llava top1 action accuracy {:.3f}'.format(np.sum(wrong_llava_collections == 1) / len(wrong_llava_collections)))
141+
print ('avion top1 action accuracy {:.3f}'.format(np.sum(wrong_avion_collections == 1) / len(wrong_avion_collections)))
142+
143+
print ('llava percentage of wrong noun {:.2f}'.format(len(llava_wrong_noun_collections) / np.sum(wrong_llava_collections == 0)))
144+
print ('llava percentage of wrong verb {:.2f}'.format(len(llava_wrong_verb_collections) / np.sum(wrong_llava_collections == 0)))
145+
print ('llava percentage of both verb noun wrong {:.2f}'.format(len(llava_wrong_verb_noun_collections) / np.sum(wrong_llava_collections == 0)))
146+
147+
148+
print ('avion percentage of wrong noun {:.2f}'.format(len(avion_wrong_noun_collections) / np.sum(wrong_avion_collections == 0)))
149+
print ('avion percentage of wrong verb {:.2f}'.format(len(avion_wrong_verb_collections) / np.sum(wrong_avion_collections == 0)))
150+
print ('avion percentage of both verb noun wrong {:.2f}'.format(len(avion_wrong_verb_noun_collections) / np.sum(wrong_avion_collections == 0)))
151+
152+
153+
97154

98155
if __name__ == '__main__':
99156

100157

101158
prediction_analysis = PredictionAnalysis(save_folder = '/storage-rcp-pure/upmwmathis_scratch/shaokai/LLaVA-NeXT')
102159
prediction_analysis.load()
160+
prediction_analysis.analysis()

0 commit comments

Comments
 (0)