
Commit b560afa

Ye Shaokai committed: WIP
1 parent 21e6b39 commit b560afa

4 files changed: +218 -151 lines changed
action/chatgpt_utils.py

Lines changed: 89 additions & 27 deletions
@@ -2,11 +2,12 @@
 import io
 import json
 import os
-import cv2
 import numpy as np
 import openai
 from pydantic import BaseModel
 from multiprocessing.pool import Pool
+from action.utils import avion_video_loader
+import cv2
 
 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
 
@@ -26,11 +27,24 @@ class MultiChoiceResponse(BaseModel):
     explanation: str
 
 
+def split_indices(indices, num_chunks):
+    chunk_size = len(indices) // num_chunks
+    return [indices[i:i + chunk_size] for i in range(0, len(indices), chunk_size)]
 
 class GPTAnnotator:
-    def __init__(self, prediction_file_path):
-        with open(prediction_file_path, 'r') as f:
-            self.prediction_file = json.load(f)
+    def __init__(self, ann_file, data_root, clip_length=32):
+        self.ann_file = ann_file
+        self.data_root = data_root
+        self.clip_length = clip_length
+        data = []
+        with open(ann_file, 'r') as f:
+            for line in f:
+                # each line of the annotation file is one JSON record
+                _data = json.loads(line)
+                # collect the parsed record
+                data.append(_data)
+        self.data = data
+
 
     def prepare_multiple_images(self, images):
         """
@@ -62,38 +76,84 @@ def prepare_multiple_images(self, images):
         return multi_image_content
 
 
-    def annotate(self, images):
+    def extract_frames(self, data_root, vid_path, start_second, end_second):
+        frames, time_meta = avion_video_loader(data_root,
+                                               vid_path,
+                                               'MP4',
+                                               start_second,
+                                               end_second,
+                                               clip_length=self.clip_length,
+                                               threads=1,
+                                               fast_rrc=False,
+                                               fast_rcc=False,
+                                               jitter=False)
+        return frames, time_meta
+
+    def parse_conversation(self, item):
         """
-        Annotate to do image caption only
+        We should get the time steps and the duration.
+        We should also get the ground-truth and the wrong answers.
         """
-        pass
+        conversations = item['conversations']
+        human_dict = conversations[0]
+
+        # the +2 / -1 offsets skip the bracket and strip the surrounding quotes
+        option_start = human_dict['value'].index('[') + 2
+        option_end = human_dict['value'].index(']') - 1
+
+        option_text = human_dict['value'][option_start:option_end]
+        gpt_dict = conversations[1]
+        gt_answer = gpt_dict['value']
+
+        assert human_dict['from'] == 'human' and gpt_dict['from'] == 'gpt'
 
-    def annotate_with_multichoice(self, images, mc_data):
+        ret = {'options': option_text,
+               'gt_answer': gt_answer,
+               'start_second': item['start_timestamp'],
+               'end_second': item['end_timestamp']}
+
+        return ret
+
+    def annotate(self, indices):
+
+        data_batch = [self.data[i] for i in range(len(self.data)) if i in indices]
+
+        for item in data_batch:
+            start_timestamp = item['start_timestamp']
+            end_timestamp = item['end_timestamp']
+            vid_path = '{}/{}'.format(item['video'].split('-')[0], item['video'].split('-')[1])
+            frames, time_meta = self.extract_frames(self.data_root, vid_path, start_timestamp, end_timestamp)
+            data_item = self.parse_conversation(item)
+            anno = self.annotate_images(frames, data_item)
+            print(anno)
+            break
+
+    def annotate_images(self, images, data_item):
         """
         Annotate with mc_data
-
         {
-
         }
-
         """
-
+        gt_answer = data_item['gt_answer']
+        option_text = data_item['options']
+        start_second = data_item['start_second']
+        end_second = data_item['end_second']
         temperature = 0
-        include_images = True
-
-        system_prompt_prefix = """Inspect the images from the video and explain why the answer of the multi-choice question is D. """
-        system_prompt_suffix = """Yes"""
+        system_prompt_prefix = f"""
+        You are seeing video frames from an egocentric view. You are determining what action the person is performing.
+        The video's start time is {start_second} and the end time is {end_second}.
+        In a multiple-choice video question answering task, you were given the following options {option_text}, and the correct answer is {gt_answer}.
+        Please describe what you see, why the wrong answers are wrong, and why the right answer is right.
+        """
+        system_prompt_suffix = """"""
 
         system_prompt = system_prompt_prefix + system_prompt_suffix
 
         system_message = [{"role": "system", "content": system_prompt}]
 
-        if include_images:
-            multi_image_content = self.prepare_multiple_images(images)
-            multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content
-            user_message = [{"role": "user", "content": multi_modal_content}]
-        else:
-            user_message = [{"role": "user", "content": ""}]
+        multi_image_content = self.prepare_multiple_images(images)
+        multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content
+        user_message = [{"role": "user", "content": multi_modal_content}]
 
         response = client.beta.chat.completions.parse(
             model=GPT_MODEL,
@@ -114,10 +174,12 @@ def annotate_using_chatgpt():
    #pool.starmap(annotate, task_args)

    pass
-
-def annotate_from_train_conv_file(train_file_path):
-    pass
+
+
 
 if __name__ == '__main__':
-    train_file_path = '/storage-rcp-pure/upmwmathis_scratch/shaokai'
-    annotate_from_train_conv_file(train_file_path)
+    train_file_path = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
+    root = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100'
+
+
+    GPTAnnotator(train_file_path, root)
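
Note: the commented-out pool.starmap in annotate_using_chatgpt and the new split_indices helper suggest annotation is meant to be sharded across worker processes. A minimal sketch of that wiring, assuming the constructor arguments from the __main__ block above (annotate_chunk and num_chunks are illustrative names, not part of the commit; annotate() also still breaks after the first item, so this only shows the intended structure):

    from multiprocessing.pool import Pool
    from action.chatgpt_utils import GPTAnnotator, split_indices

    TRAIN_FILE = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
    DATA_ROOT = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100'

    def annotate_chunk(indices):
        # each worker builds its own annotator (and OpenAI client)
        # and processes only its shard of the dataset indices
        annotator = GPTAnnotator(TRAIN_FILE, DATA_ROOT, clip_length=32)
        annotator.annotate(indices)

    if __name__ == '__main__':
        num_chunks = 4  # illustrative worker count
        annotator = GPTAnnotator(TRAIN_FILE, DATA_ROOT)
        chunks = split_indices(list(range(len(annotator.data))), num_chunks)
        with Pool(num_chunks) as pool:
            pool.map(annotate_chunk, chunks)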

action/ek_eval.py

Lines changed: 2 additions & 121 deletions
@@ -18,7 +18,7 @@
 import json
 import logging
 from llava.utils import rank0_print
-from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions
+from action.utils import generate_label_map, MultiChoiceGenerator, match_answer, parse_avion_predictions, avion_video_loader
 from action.prediction_analysis import PredictionAnalysis
 import copy
 from collections import Counter
@@ -33,125 +33,6 @@ def datetime2sec(str):
     hh, mm, ss = str.split(':')
     return int(hh) * 3600 + int(mm) * 60 + float(ss)
 
-
-def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
-    frame_ids = np.convolve(np.linspace(start_frame, end_frame, num_segments + 1), [0.5, 0.5], mode='valid')
-    if jitter:
-        seg_size = float(end_frame - start_frame - 1) / num_segments
-        shift = (np.random.rand(num_segments) - 0.5) * seg_size
-        frame_ids += shift
-    return frame_ids.astype(int).tolist()
-
-
-def get_video_reader(videoname, num_threads, fast_rrc, rrc_params, fast_rcc, rcc_params):
-    video_reader = None
-    if fast_rrc:
-        video_reader = decord.VideoReader(
-            videoname,
-            num_threads=num_threads,
-            width=rrc_params[0], height=rrc_params[0],
-            use_rrc=True, scale_min=rrc_params[1][0], scale_max=rrc_params[1][1],
-        )
-    elif fast_rcc:
-        video_reader = decord.VideoReader(
-            videoname,
-            num_threads=num_threads,
-            width=rcc_params[0], height=rcc_params[0],
-            use_rcc=True,
-        )
-    else:
-        video_reader = decord.VideoReader(videoname, num_threads=num_threads)
-    return video_reader
-
-
-def video_loader(root, vid, ext, second, end_second,
-                 chunk_len=300, fps=30, clip_length=32,
-                 threads=1,
-                 fast_rrc=False, rrc_params=(224, (0.5, 1.0)),
-                 fast_rcc=False, rcc_params=(224, ),
-                 jitter=False):
-    assert fps > 0, 'fps should be greater than 0'
-    if chunk_len == -1:
-        vr = get_video_reader(
-            osp.join(root, '{}.{}'.format(vid, ext)),
-            num_threads=threads,
-            fast_rrc=fast_rrc, rrc_params=rrc_params,
-            fast_rcc=fast_rcc, rcc_params=rcc_params,
-        )
-        end_second = min(end_second, len(vr) / fps)
-
-        # calculate frame_ids
-        frame_offset = int(np.round(second * fps))
-        total_duration = max(int((end_second - second) * fps), clip_length)
-        frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
-
-        # load frames
-        assert max(frame_ids) < len(vr)
-        try:
-            frames = vr.get_batch(frame_ids).asnumpy()
-        except decord.DECORDError as error:
-            print(error)
-            frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
-
-        return torch.from_numpy(frames.astype(np.float32))
-
-    else:
-        time_meta = {}
-
-        time_meta['duration'] = end_second - second
-        chunk_start = int(second) // chunk_len * chunk_len
-        chunk_end = int(end_second) // chunk_len * chunk_len
-        while True:
-            video_filename = osp.join(root, '{}.{}'.format(vid, ext), '{}.{}'.format(chunk_end, ext))
-
-            if not osp.exists(video_filename):
-                # print("{} does not exists!".format(video_filename))
-                chunk_end -= chunk_len
-            else:
-                vr = decord.VideoReader(video_filename)
-                end_second = min(end_second, (len(vr) - 1) / fps + chunk_end)
-                assert chunk_start <= chunk_end
-                break
-        # calculate frame_ids
-        frame_ids = get_frame_ids(
-            int(np.round(second * fps)),
-            int(np.round(end_second * fps)),
-            num_segments=clip_length, jitter=jitter
-        )
-        all_frames = []
-        all_frame_ids = []
-        # allocate absolute frame-ids into the relative ones
-        for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
-            rel_frame_ids = list(filter(lambda x: int(chunk * fps) <= x < int((chunk + chunk_len) * fps), frame_ids))
-            rel_frame_ids = [int(frame_id - chunk * fps) for frame_id in rel_frame_ids]
-            vr = get_video_reader(
-                osp.join(root, '{}.{}'.format(vid, ext), '{}.{}'.format(chunk, ext)),
-                num_threads=threads,
-                fast_rrc=fast_rrc, rrc_params=rrc_params,
-                fast_rcc=fast_rcc, rcc_params=rcc_params,
-            )
-            try:
-                frames = vr.get_batch(rel_frame_ids).asnumpy()
-            except decord.DECORDError as error:
-                print(error)
-                frames = vr.get_batch([0] * len(rel_frame_ids)).asnumpy()
-            except IndexError:
-                print(root, vid, ext, second, end_second)
-            all_frames.append(frames)
-            all_frame_ids.append(frame_ids)
-            if sum(map(lambda x: x.shape[0], all_frames)) == clip_length:
-                break
-        res = torch.from_numpy(np.concatenate(all_frames, axis=0).astype(np.float32))
-        time_meta['n_frames'] = res.shape[0]
-        all_frame_ids = np.concatenate(all_frame_ids, axis=0)
-        frame_time = [e / fps for e in all_frame_ids]
-        frame_time -= frame_time[0]
-        frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
-        time_meta['frame_time'] = frame_time
-        assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids)
-        return res, time_meta
-
-
 class VideoCaptionDatasetBase(torch.utils.data.Dataset):
     def __init__(self, dataset, root, metadata, is_trimmed=True):
         self.dataset = dataset
@@ -216,7 +97,7 @@ def get_raw_item(
         vid_path, start_second, end_second, fps, narration, verb, noun = self.samples[i]
         # chunk length is the chunked video clip length
         # clip length is number of frames we want to sample from the clip
-        frames, time_meta = video_loader(self.root, vid_path, 'MP4',
+        frames, time_meta = avion_video_loader(self.root, vid_path, 'MP4',
                                          start_second, end_second,
                                          chunk_len=chunk_len, fps=fps,
                                          clip_length=clip_length,
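
Note: the ~120 deleted lines are local copies of get_frame_ids, get_video_reader, and video_loader; the file now calls the shared avion_video_loader from action/utils.py instead. A minimal usage sketch, assuming the signature visible at the two call sites in this commit and that time_meta keeps the same keys ('duration', 'n_frames', 'frame_time') as the deleted loader; the path, clip name, and timestamps are illustrative:

    from action.utils import avion_video_loader

    # sample a 32-frame clip between two timestamps of a chunked MP4
    frames, time_meta = avion_video_loader(
        '/path/to/EK100',   # data root (hypothetical)
        'P01/P01_01',       # vid_path, '{participant}/{video}' (hypothetical)
        'MP4',
        10.0,               # start_second
        14.5,               # end_second
        clip_length=32,
        threads=1,
        fast_rrc=False,
        fast_rcc=False,
        jitter=False,
    )
    print(time_meta['duration'], time_meta['n_frames'])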

action/prediction_analysis.py

Lines changed: 8 additions & 3 deletions
@@ -89,12 +89,17 @@ def analysis(self):
             gt_name = items['gt_name']
             # only replacing the first :
             avion_pred = items['avion_preds']['predictions'][0].replace(':', ' ', 1)
-
-            llava_verb, llava_noun = llava_pred.split(' ')
+            avion_preds = items['avion_preds']['predictions'][:5]
+            avion_preds = [e.replace(':', ' ', 1) for e in avion_preds]
+            try:
+                llava_verb, llava_noun = llava_pred.split(' ')
+            except:
+                lst = llava_pred.split(' ')
+                llava_verb, llava_noun = lst[0], lst[1]
             avion_verb, avion_noun = avion_pred.split(' ')
             gt_verb, gt_noun = gt_name.split(' ')
 
-            if llava_pred != gt_name: 
+            if llava_pred != gt_name:
                 wrong_llava_collections[idx] = 0
             else:
                 wrong_llava_collections[idx] = 1
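
Note: the replace(':', ' ', 1) normalization and the new try/except exist because AVION predictions arrive as 'verb:noun' (where the noun may itself contain colons), while LLaVA predictions are space-separated and may contain more than two tokens. A small illustration with made-up prediction strings:

    # AVION: only the first ':' separates verb from noun
    avion_pred = 'open:bottle:wine'.replace(':', ' ', 1)   # 'open bottle:wine'
    avion_verb, avion_noun = avion_pred.split(' ')         # ('open', 'bottle:wine')

    # LLaVA: extra spaces make tuple unpacking raise ValueError,
    # which the new except branch handles by keeping the first two tokens
    llava_pred = 'open bottle of wine'
    try:
        llava_verb, llava_noun = llava_pred.split(' ')
    except ValueError:
        lst = llava_pred.split(' ')
        llava_verb, llava_noun = lst[0], lst[1]
    print(llava_verb, llava_noun)  # open bottle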
