## First run inference on the dataset with one frame only
## (not on the whole dataset, just a subset of it).
## Then run inference on the same subset with 8 frames.
## Find the samples where the one-frame prediction is wrong and the 8-frame
## prediction is right, and collect those examples as data for DPO
## (a sketch of this pair-collection step appears after CaptionInference.run below).

from llava.action.chatgpt_utils import ChatGPT
import os
import csv
import random
from tqdm import tqdm
from pydantic import BaseModel
import traceback
from concurrent.futures import ProcessPoolExecutor
import openai

client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

GPT_MODEL = 'gpt-4o'

class CaptionResponse(BaseModel):
    """
    The ground-truth answer is already known; the model's response adds a
    caption with extra information to support it.
    """
    caption: str


def datetime2sec(time_str):
    """Convert an HH:MM:SS(.ms) timestamp string into seconds."""
    hh, mm, ss = time_str.split(':')
    return int(hh) * 3600 + int(mm) * 60 + float(ss)


class CaptionInference(ChatGPT):
    def __init__(self,
                 root,
                 annotation_file,
                 clip_length=4,
                 debug=False):
        self.root = root
        self.annotation_file = annotation_file
        self.clip_length = clip_length
        self.debug = debug
        self.question_type = 'gpt-gt-reason'

        self.data = self.init_data()

        print(len(self.data))

    def select_train_subset(self):
        """Return the indices of a random 20% subset of the annotation rows."""
        with open(self.annotation_file, 'r') as f:
            csv_reader = list(csv.reader(f))
            header = csv_reader[0]   # header row
            data = csv_reader[1:]    # data rows
            N = len(data)
            print('N', N)
            # Sample a fixed random 20% of the rows and return their indices.
            random.seed(0)
            indices = random.sample(range(N), int(N * 0.2))
            return indices

    def init_data(self):
        ret = {}
        csv_reader = csv.reader(open(self.annotation_file))
        _ = next(csv_reader)  # skip the header

        indices = set(self.select_train_subset())  # set for fast membership checks
        count = 0
        for idx, row in enumerate(csv_reader):
            if idx not in indices:
                continue
            narration = row[8]
            pid, vid = row[1:3]
            start_second, end_second = datetime2sec(row[4]), datetime2sec(row[5])
            vid_path = '{}/{}'.format(pid, vid)
            verb, noun = int(row[10]), int(row[12])
            gt_vn = '{}:{}'.format(verb, noun)

            ret[count] = {
                'gt_answer': narration,
                'start_second': start_second,
                'end_second': end_second,
                'vid_path': vid_path
            }
            count += 1
        return ret

    def multi_process_run(self, n_samples=-1):
        # Use the whole prepared subset unless a positive n_samples is given.
        indices = list(range(len(self.data)))
        if n_samples != -1:
            indices = indices[:n_samples]

        num_chunks = os.cpu_count() if not self.debug else 1

        indices_groups = self.split_indices(indices, num_chunks)

        with ProcessPoolExecutor(max_workers=num_chunks) as executor:
            # Each worker processes one group of indices.
            futures = [executor.submit(self.run, group) for group in indices_groups]

            # Wait for all futures to complete and merge their per-sample results.
            combined_results = {}
            for future in futures:
                result_dict = future.result()
                combined_results.update(result_dict)

            if self.debug:
                print(combined_results)

        return combined_results

    def predict_images(self, images, parsed_item):
        """
        Predict the action from the images.
        """
        from llava.action.utils import format_task_related_prompt
        # Note: the data dict built in init_data does not currently include 'options'.
        options = parsed_item['options']
        start_second = 0
        end_second = parsed_item['end_second'] - parsed_item['start_second']
        temperature = 0
        video_duration = end_second - start_second
        n_frames = len(images)

        task_related_prompt = format_task_related_prompt(options, self.question_type, perspective='first_person')

        time_instruction = f"The provided video lasts for {video_duration:.3f} seconds. "

        system_prompt = time_instruction + task_related_prompt

        format_prompt = """
**Return only a JSON object** with the following two properties:

- `"answer"`: the answer to the question.
- `"caption"`: a detailed caption of the video, used to support the answer.
"""

        # o1 models do not support structured output, so the expected JSON
        # format is spelled out in the prompt instead.
        if 'o1' in GPT_MODEL:
            system_prompt += format_prompt

        print(system_prompt)

        # o1-mini only accepts the "user" role and a fixed temperature of 1;
        # o1 uses the "developer" role; other models use "system".
        if 'o1-mini' == GPT_MODEL:
            system_role = "user"
            temperature = 1
        elif 'o1' == GPT_MODEL:
            system_role = "developer"
        else:
            system_role = "system"

        system_message = [{"role": system_role, "content": system_prompt}]

        multi_image_content = self.prepare_multiple_images(images)
        multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content
        user_message = [{"role": "user", "content": multi_modal_content}]

        kwargs = {'model': GPT_MODEL,
                  'messages': system_message + user_message,
                  'response_format': CaptionResponse,
                  'temperature': temperature}

        if 'o1' in GPT_MODEL:
            kwargs.pop('response_format')
            if 'o1' == GPT_MODEL:
                kwargs.pop('temperature')
                # kwargs['reasoning_effort'] = 'high'

        if 'o1' not in GPT_MODEL:
            # Structured output: the response is parsed into CaptionResponse.
            response = client.beta.chat.completions.parse(**kwargs)
        else:
            response = client.chat.completions.create(**kwargs)

        total_cost = self.calculate_cost(response)

        ret = response.choices[0].message.parsed if 'o1' not in GPT_MODEL else response.choices[0].message

        return ret
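
    # For o1 models, predict_images returns the raw chat message rather than a
    # parsed CaptionResponse, so `.caption` is not available on the result.
    # A minimal sketch of a helper that could bridge that gap, assuming the
    # model followed the JSON format requested in format_prompt; the helper
    # name and the fallback behaviour are assumptions, not part of the original code.
    @staticmethod
    def parse_caption_from_message(message):
        import json
        try:
            parsed = json.loads(message.content)
            if isinstance(parsed, dict):
                return parsed.get('caption', '')
        except (json.JSONDecodeError, TypeError):
            pass
        # Fall back to the raw text if the model did not return valid JSON.
        return message.content or ''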

    def run(self, indices=None):
        # If no indices are given, process every prepared sample.
        if indices is None:
            indices = list(range(len(self.data)))
        data_batch = {i: self.data[i] for i in range(len(self.data)) if i in indices}
        ret = {}
        for k, v in tqdm(data_batch.items()):

            start_timestamp = v['start_second']
            end_timestamp = v['end_second']
            vid_path = v['vid_path']

            frames, time_meta = self.extract_frames(vid_path, start_timestamp, end_timestamp)
            try:
                parsed_answer = self.predict_images(frames, v)
            except Exception as e:
                # Print the full stack trace and skip this sample.
                traceback.print_exc()
                print("An exception occurred: ", e)
                continue

            caption = parsed_answer.caption
            print('caption:', caption)
            # Keep the caption so multi_process_run can merge per-sample results.
            ret[k] = caption

            if self.debug:
                break

        return ret
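

# A minimal sketch of the pair-collection step described in the comments at the
# top of this file, assuming both runs return an {index: predicted_answer} dict
# (as CaptionInference.run does) and that predictions can be compared to the
# ground truth by string equality. The function name and the chosen/rejected
# output format are assumptions, not part of the original pipeline.
def collect_dpo_pairs(one_frame_results, eight_frame_results, data):
    pairs = []
    for idx, item in data.items():
        gt = item['gt_answer']
        pred_1f = one_frame_results.get(idx)
        pred_8f = eight_frame_results.get(idx)
        if pred_1f is None or pred_8f is None:
            continue
        # Keep only the samples where 1 frame was wrong but 8 frames were right.
        if pred_1f != gt and pred_8f == gt:
            pairs.append({
                'vid_path': item['vid_path'],
                'start_second': item['start_second'],
                'end_second': item['end_second'],
                'chosen': pred_8f,    # correct 8-frame prediction
                'rejected': pred_1f,  # incorrect 1-frame prediction
            })
    return pairs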


if __name__ == '__main__':
    video_root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
    anno_root = '/data/shaokai/epic-kitchens-100-annotations/'
    clip_length = 8

    cap = CaptionInference(video_root, os.path.join(anno_root, 'EPIC_100_train.csv'), clip_length, debug=True)

    # cap.multi_process_run(n_samples=2)
    cap.run()