Commit 9b01e0d

WIP

1 parent 4cfe5ae commit 9b01e0d

File tree

4 files changed: +198 -88 lines changed


action/chatgpt_utils.py

Lines changed: 127 additions & 32 deletions
@@ -6,26 +6,28 @@
 import openai
 from pydantic import BaseModel
 from concurrent.futures import ProcessPoolExecutor
-from action.utils import avion_video_loader
+from action.utils import avion_video_loader, create_multi_choice_from_avion_predictions
 import torch
 import cv2
 from pathlib import Path
+from action.prediction_analysis import PredictionAnalysis
 
 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
 
 GPT_MODEL = "gpt-4o-2024-08-06"
 
 
-class ImageOnlyResponse(BaseModel):
+class GT_Agnostic_Response(BaseModel):
     """
+    The GT was not known. The response is to generate a new answer
     """
     explanation: str
+    answer: str
 
-class MultiChoiceResponse(BaseModel):
+class GT_Augmentation_Response(BaseModel):
     """
-    The output format of the response
+    The GT was known. The response is to add more information to the GT
     """
-
     explanation: str
 def split_indices(indices, num_chunks):
     # Calculate the size of each chunk and the remainder
@@ -51,22 +53,21 @@ def __init__(self, ann_file, data_root, clip_length = 32):
         data = []
         with open(ann_file, 'r') as f:
             for line in f:
-                # Parse the JSON data
-                _data = json.loads(line)
-                # Process your data
-                data.append(_data)
+                data.append(json.loads(line))
         self.data = data
-
-
+
     def prepare_multiple_images(self, images):
         """
 
         """
         encoded_image_list = []
+
        for image in images:
 
            if isinstance(image, torch.Tensor):
                image = image.cpu().detach().numpy()
+
+
            # images from matplotlib etc.
            if isinstance(image, io.BytesIO):
                image_bytes = image
@@ -104,10 +105,11 @@ def extract_frames(self, data_root, vid_path, start_second, end_second):
                                        jitter = False)
         return frames, time_meta
 
-    def parse_conversation(self, item):
+
+
+    def parse_conversation_from_train_convs(self, item):
         """
-        We should get time steps, duration
-        We shoudd also get gt and wrong answers
+        The item has the structure of convs defined in the train anno.
         """
         conversations = item['conversations']
         human_dict = conversations[0]
@@ -141,19 +143,53 @@ def annotate(self, indices):
             end_timestamp = item['end_timestamp']
             vid_path = '{}/{}'.format(item['video'].split('-')[0], item['video'].split('-')[1])
             frames, time_meta = self.extract_frames(self.data_root, vid_path, start_timestamp, end_timestamp)
-            parsed_item = self.parse_conversation(item)
-            gpt_answer = self.annotate_images(frames, parsed_item).explanation
+            parsed_item = self.parse_conversation_from_train_convs(item)
+            gpt_answer = self.annotate_images_from_train_anno(frames, parsed_item).explanation
             item['conversations'][1]['value'] = gpt_answer
             ret[index] = item
             break
 
-        return ret
+        return ret
 
-    def annotate_images(self, images, data_item):
+    def predict_images(self, images, data_item):
         """
-        Annotate with mc_data
-        {
-        }
+        Predict the action from the images
+        """
+
+        option_text = data_item['options']
+        start_second = data_item['start_second']
+        end_second = data_item['end_second']
+        temperature = 0
+        system_prompt_prefix = f"""
+        You are seeing video frames from an egocentric view of a person. Pretend that you are the person. Your task is to describe what action you are performing.
+        To assist you for how to describe the action, the video's start time is {start_second} and the end time is {end_second} and the duration is {end_second - start_second} seconds.
+        You were given multiple choice options {option_text}. Pick the correct one and put that into the answer. Note in the answer do not include the option letter, just the name of the action.
+        Also explain why the correct answer is correct and why the other options are incorrect.
+        """
+
+        system_prompt_suffix = """"""
+
+        system_prompt = system_prompt_prefix + system_prompt_suffix
+
+        system_message = [{"role": "system", "content": system_prompt}]
+
+        multi_image_content = self.prepare_multiple_images(images)
+        multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content
+        user_message = [{"role": "user", "content": multi_modal_content}]
+
+        response = client.beta.chat.completions.parse(
+            model=GPT_MODEL,
+            messages=system_message + user_message,
+            response_format = GT_Agnostic_Response,
+            temperature = temperature
+        )
+
+        return response.choices[0].message.parsed
+
+
+    def annotate_images_from_train_anno(self, images, data_item):
+        """
+        Assuming that data_item already has the multi-choice options and the gt_answer
         """
         gt_answer = data_item['gt_answer']
         option_text = data_item['options']
@@ -190,7 +226,7 @@ def annotate_images(self, images, data_item):
         response = client.beta.chat.completions.parse(
             model=GPT_MODEL,
             messages=system_message + user_message,
-            response_format = MultiChoiceResponse,
+            response_format = GT_Augmentation_Response,
             temperature = temperature
         )
 
@@ -203,14 +239,7 @@ def process_subset(indices_subset, train_file_path, root):
     return annotator.annotate(indices_subset)
 
 
-if __name__ == '__main__':
-    #train_file_path = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
-    #root = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100'
-    train_file_path = '/data/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
-    root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
-
-    num_cores = 2 #os.cpu_count()
-
+def multi_process_annotate(train_file_path, root, num_cores):
     print (f'Using {num_cores} cores thus splitting the data into {num_cores} chunks')
 
     with open(train_file_path, 'r') as f:
@@ -238,7 +267,7 @@ def process_subset(indices_subset, train_file_path, root):
     keys = sorted(list(combined_results.keys()))
 
     print ('resulted number of keys', len(keys))
-
+
     result = []
     for key in keys:
         result.append(combined_results[key])
@@ -247,4 +276,70 @@ def process_subset(indices_subset, train_file_path, root):
 
     with open(anno_root / 'gpt_annotated.jsonl', 'w') as f:
         for item in result:
-            f.write(json.dumps(item) + '\n')
+            f.write(json.dumps(item) + '\n')
+
+def explore_wrong_examples(train_file_path, root, prediction_save_folder):
+
+    annotator = GPTAnnotator(train_file_path, root)
+    prediction_analysis = PredictionAnalysis(prediction_save_folder)
+    wrong_examples = prediction_analysis.get_wrong_examples()
+    data_root = root
+
+    count = 0
+    for k,v in wrong_examples.items():
+
+
+        gt_name = v['gt_name']
+        avion_predictions = v['avion_preds']['predictions']
+        _avion_predictions = [e.replace(':', ' ', 1) for e in avion_predictions]
+        if gt_name not in _avion_predictions:
+            print ('gt_name not in avion_predictions')
+            continue
+        else:
+            count+=1
+            if count <= 2:
+                continue
+            if count > 6:
+                break
+            print ('gt_name in avion_predictions')
+
+            vid_path = v['vid_path'][0]
+            start_second = v['start_second']
+            end_second = v['end_second']
+
+            frames, time_meta = annotator.extract_frames(data_root, vid_path, start_second, end_second)
+
+            option_text = create_multi_choice_from_avion_predictions(avion_predictions, len(avion_predictions))['options'][0]
+            parsed_item = {
+                'options': option_text,
+                'gt_answer': gt_name,
+                'start_second': start_second,
+                'end_second': end_second
+            }
+
+            parsed_answer = annotator.predict_images(frames, parsed_item)
+
+            predicted_answer = parsed_answer.answer
+            explanation = parsed_answer.explanation
+
+            print ('gt_name', gt_name)
+            print ('avion_predictions', avion_predictions)
+            print ('llava_pred', v['llava_pred'])
+            print ('chatgpt answer', predicted_answer)
+            print ('explanation', explanation)
+
+
+
+
+
+if __name__ == '__main__':
+    #train_file_path = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
+    #root = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100'
+    train_file_path = '/data/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
+    root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
+    pred_folder = '/data/epic_kitchen/llavavideo_avion_mc_top10_5epoch_preds'
+    num_cores = 2 #os.cpu_count()
+    #multi_process_annotate(train_file_path, root, num_cores)
+
+    explore_wrong_examples(train_file_path, root, pred_folder)
+
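Note: the new explore_wrong_examples() consumes records returned by PredictionAnalysis.get_wrong_examples(). The sketch below shows the record shape as inferred purely from the fields the function reads; the real structure is defined in action/prediction_analysis.py, which is not part of this commit, and the values are illustrative placeholders, not real data.

    # Hypothetical shape of one wrong-example record, reconstructed only from the
    # field accesses in explore_wrong_examples(); all values are placeholders.
    wrong_example = {
        'gt_name': 'open drawer',                           # ground-truth action name ('verb noun')
        'llava_pred': 'close drawer',                       # the LLaVA prediction that was marked wrong
        'avion_preds': {
            'predictions': ['open:drawer', 'close:drawer']  # AVION top-k in 'verb:noun' form
        },
        'vid_path': ['P01/P01_01'],                         # first element is passed to extract_frames
        'start_second': 10.5,
        'end_second': 13.0,
    }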

action/ek_eval.py

Lines changed: 4 additions & 22 deletions
@@ -18,16 +18,16 @@
 import json
 import logging
 from llava.utils import rank0_print
-from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions, avion_video_loader
+from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions, avion_video_loader, create_multi_choice_from_avion_predictions
 from action.prediction_analysis import PredictionAnalysis
 import copy
 from collections import Counter
 import torch.distributed as dist
 
 if not dist.is_initialized():
     dist.init_process_group(backend='nccl')
-    rank = dist.get_rank()
-    torch.cuda.set_device(rank)
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank)
 
 def datetime2sec(str):
     hh, mm, ss = str.split(':')
@@ -250,25 +250,7 @@ def prepare_llava(pretrained):
 
     return tokenizer, model, image_processor, max_length
 
-def get_topk_predictions(data, idx, k):
 
-    letters = [chr(65+i) for i in range(26)][:k]
-    options = list(range(26))[:k]
-
-    predictions = data[str(idx)]['predictions'][:k]
-    predictions = parse_avion_predictions(predictions)
-
-    for i in range(len(options)):
-        options[i] = f'{letters[i]}. {predictions[i]}'
-
-    mc_data = {
-        'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer.'},
-        'options': {0: options},
-        'valid_letters': letters,
-        'avion_pred': predictions[0]
-    }
-
-    return mc_data
 
 def ensemble_llava_evaluation(
     pretrained_name,
@@ -415,7 +397,7 @@ def evaluate_on_EK100(eval_args,
             local_total_samples = torch.tensor(0.0, device=device)
 
             if eval_args.action_predictions:
-                mc_data = get_topk_predictions(predictions, global_index, eval_args.topk_predictions)
+                mc_data = create_multi_choice_from_avion_predictions(predictions[global_index], eval_args.topk_predictions)
                 avion_pred = mc_data['avion_pred']
                 if gt_name == avion_pred:
                     local_avion_correct.add_(1)
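Note: the deleted get_topk_predictions() is superseded by create_multi_choice_from_avion_predictions() from action/utils.py, which is not shown in this commit. The sketch below is an assumption about what that helper does, pieced together from the removed function and the two call sites (mc_data['avion_pred'] here, ['options'][0] in chatgpt_utils.py); the real signature, argument handling and return layout may differ.

    # Sketch only: assumed behavior of the shared helper living in action/utils.py.
    # It is modeled on the deleted get_topk_predictions, but takes the AVION
    # prediction list directly instead of the full predictions dict and an index.
    def create_multi_choice_from_avion_predictions(avion_predictions, k):
        letters = [chr(65 + i) for i in range(26)][:k]
        # parse_avion_predictions (same module) is assumed to turn 'verb:noun'
        # strings into readable action names.
        predictions = parse_avion_predictions(avion_predictions[:k])
        options = [f'{letters[i]}. {predictions[i]}' for i in range(len(predictions))]
        return {
            'question': {0: 'the video is an egocentric view of a person. What is the person doing? Pick the letter that has the correct answer.'},
            'options': {0: options},
            'valid_letters': letters,
            'avion_pred': predictions[0],
        }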
