
Commit 45f6eec

Author: Ye Shaokai
Commit message: added new file

1 parent 2da1b64 commit 45f6eec

File tree: 1 file changed, +232 −0 lines changed

@@ -0,0 +1,232 @@
## first run inference on the dataset with one frame only
## cannot be the whole dataset but a subset of it
## then run inference on the dataset with 8 frames
## then find the clips where the one-frame prediction is wrong and the 8-frame prediction is right
## collect those examples as data for DPO
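
# The 1-frame vs 8-frame comparison described above is not implemented in this
# file. A minimal sketch of it, assuming each inference pass produces a dict
# mapping clip id -> predicted caption (all names below are hypothetical):
def collect_dpo_pairs(one_frame_preds, eight_frame_preds, gt_answers):
    pairs = []
    for k, gt in gt_answers.items():
        # keep clips where 1 frame was wrong but 8 frames were right
        if one_frame_preds.get(k) != gt and eight_frame_preds.get(k) == gt:
            pairs.append({'clip': k,
                          'chosen': eight_frame_preds[k],   # preferred response
                          'rejected': one_frame_preds[k]})  # dispreferred response
    return pairs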

from llava.action.chatgpt_utils import ChatGPT
import os
import csv
import random
from tqdm import tqdm
from pydantic import BaseModel
import traceback
from concurrent.futures import ProcessPoolExecutor
import openai


client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))


GPT_MODEL = 'gpt-4o'

class CaptionResponse(BaseModel):
    """
    The GT is already known; the response adds more information to the GT.
    """
    caption: str


def datetime2sec(timestamp):
    # convert an 'HH:MM:SS.ss' timestamp to seconds, e.g. '00:01:30.50' -> 90.5
    hh, mm, ss = timestamp.split(':')
    return int(hh) * 3600 + int(mm) * 60 + float(ss)

class CaptionInference(ChatGPT):
    def __init__(self,
                 root,
                 annotation_file,
                 clip_length=4,
                 debug=False):
        self.root = root
        self.annotation_file = annotation_file
        self.clip_length = clip_length
        self.debug = debug
        self.question_type = 'gpt-gt-reason'

        self.data = self.init_data()

        print(len(self.data))

    def select_train_subset(self):
        with open(self.annotation_file, 'r') as f:
            csv_reader = list(csv.reader(f))
        header = csv_reader[0]  # header row (unused)
        data = csv_reader[1:]
        N = len(data)
        print('N', N)
        # sample a random 20% subset of the data; seed for reproducibility
        random.seed(0)
        indices = random.sample(range(N), int(N * 0.2))
        return indices

    def init_data(self):
        ret = {}
        csv_reader = csv.reader(open(self.annotation_file))
        _ = next(csv_reader)  # skip the header

        indices = self.select_train_subset()
        count = 0
        for idx, row in enumerate(csv_reader):
            if idx not in indices:
                continue
            # EPIC-100 annotation columns: row[1]=participant_id, row[2]=video_id,
            # row[4]/row[5]=start/stop timestamps, row[8]=narration,
            # row[10]=verb_class, row[12]=noun_class
            narration = row[8]
            pid, vid = row[1:3]
            start_second, end_second = datetime2sec(row[4]), datetime2sec(row[5])
            vid_path = '{}/{}'.format(pid, vid)
            verb, noun = int(row[10]), int(row[12])
            gt_vn = '{}:{}'.format(verb, noun)

            ret[count] = {
                'gt_answer': narration,
                'start_second': start_second,
                'end_second': end_second,
                'vid_path': vid_path
            }
            count += 1
        return ret

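    # A sample entry of the returned dict (values illustrative, not taken from
    # the actual annotations):
    # {'gt_answer': 'open fridge', 'start_second': 12.3,
    #  'end_second': 14.1, 'vid_path': 'P01/P01_01'}
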
    def multi_process_run(self, n_samples=-1):
        # initialize indices for the whole dataset, optionally truncated
        indices = list(range(len(self.data)))
        if n_samples != -1:
            indices = indices[:n_samples]

        num_chunks = os.cpu_count() if not self.debug else 1

        indices_groups = self.split_indices(indices, num_chunks)

        with ProcessPoolExecutor(max_workers=num_chunks) as executor:
            # run each index group in its own worker process
            futures = [executor.submit(self.run, group) for group in indices_groups]

            # wait for all futures to complete and merge their result dicts
            combined_results = {}
            for future in futures:
                result_dict = future.result()
                combined_results.update(result_dict)

        if self.debug:
            print(combined_results)
        return combined_results

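    # `split_indices` is inherited from the ChatGPT base class. A minimal
    # sketch of the behavior this file relies on (chunking the indices into
    # num_chunks roughly equal groups), in case the base implementation is
    # unavailable:
    #
    # def split_indices(self, indices, num_chunks):
    #     chunk = max(1, (len(indices) + num_chunks - 1) // num_chunks)
    #     return [indices[i:i + chunk] for i in range(0, len(indices), chunk)]
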
    def predict_images(self, images, parsed_item):
        """
        Predict the action from the images
        """
        from llava.action.utils import format_task_related_prompt
        # init_data does not populate 'options'; fall back to the GT narration
        # so the prompt formatter still receives the action text
        options = parsed_item.get('options', parsed_item['gt_answer'])
        start_second = 0
        end_second = parsed_item['end_second'] - parsed_item['start_second']
        temperature = 0
        video_duration = end_second - start_second
        n_frames = len(images)

        task_related_prompt = format_task_related_prompt(options, self.question_type, perspective='first_person')

        time_instruction = f"The provided video lasts for {video_duration:.3f} seconds. "

        system_prompt = time_instruction + task_related_prompt

        format_prompt = """
**Return only a JSON object** with the following two properties:

- `"answer"`: the answer to the question.
- `"caption"`: a detailed caption of the video, used to support the answer.
"""

        if 'o1' in GPT_MODEL:
            system_prompt += format_prompt

        print(system_prompt)

        # o1 models have different role and sampling constraints
        if GPT_MODEL == 'o1-mini':
            system_role = "user"
            temperature = 1
        elif GPT_MODEL == 'o1':
            system_role = "developer"
        else:
            system_role = "system"

        system_message = [{"role": system_role, "content": system_prompt}]

        multi_image_content = self.prepare_multiple_images(images)
        multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content
        user_message = [{"role": "user", "content": multi_modal_content}]

        kwargs = {'model': GPT_MODEL,
                  'messages': system_message + user_message,
                  'response_format': CaptionResponse,
                  'temperature': temperature}

        # the o1 family does not support structured output; o1 also fixes temperature
        if 'o1' in GPT_MODEL:
            kwargs.pop('response_format')
        if GPT_MODEL == 'o1':
            kwargs.pop('temperature')
            # kwargs['reasoning_effort'] = 'high'
        if 'o1' not in GPT_MODEL:
            # structured output
            response = client.beta.chat.completions.parse(**kwargs)
        else:
            response = client.chat.completions.create(**kwargs)

        total_cost = self.calculate_cost(response)

        ret = response.choices[0].message.parsed if 'o1' not in GPT_MODEL else response.choices[0].message

        return ret

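    # Note for the o1 path: the returned message is raw, so its content would
    # still need to be decoded into the JSON shape requested by format_prompt
    # before `.caption` can be read downstream, e.g. (sketch):
    #
    # import json
    # parsed = json.loads(response.choices[0].message.content)
    # caption, answer = parsed['caption'], parsed['answer']
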
    def run(self, indices=None):
        if indices is None:
            data_batch = dict(self.data)
        else:
            data_batch = {i: self.data[i] for i in indices if i in self.data}
        ret = {}
        for k, v in tqdm(data_batch.items()):
            start_timestamp = v['start_second']
            end_timestamp = v['end_second']
            vid_path = v['vid_path']

            frames, time_meta = self.extract_frames(vid_path, start_timestamp, end_timestamp)
            try:
                parsed_answer = self.predict_images(frames, v)
            except Exception as e:
                # print the full stack trace and skip this clip
                traceback.print_exc()
                print("An exception occurred: ", e)
                continue

            caption = parsed_answer.caption
            print('caption:', caption)
            # collect the caption so multi_process_run can merge worker results
            ret[k] = caption

            if self.debug:
                break

        return ret


if __name__ == '__main__':
    video_root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
    anno_root = '/data/shaokai/epic-kitchens-100-annotations/'
    clip_length = 8

    cap = CaptionInference(video_root, os.path.join(anno_root, 'EPIC_100_train.csv'), clip_length, debug=True)

    # cap.multi_process_run(n_samples=2)
    cap.run()
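
# Per the plan at the top of the file, this script would be run twice (e.g.
# clip_length=1 and clip_length=8) and the two result dicts compared, for
# instance with the collect_dpo_pairs sketch above.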
