2222client = openai .OpenAI (api_key = os .environ .get ("OPENAI_API_KEY" ))
2323
2424
25- GPT_MODEL = 'gpt-4o'
26-
2725class CaptionResponse (BaseModel ):
2826 """
2927 The GT was known. The response is to add more information to the GT
@@ -37,6 +35,7 @@ def datetime2sec(str):
3735
3836class CaptionInference (ChatGPT ):
3937 def __init__ (self ,
38+ gpt_model ,
4039 root ,
4140 annotation_file ,
4241 clip_length = 4 ,
@@ -48,7 +47,8 @@ def __init__(self,
4847 self .clip_length = clip_length
4948 self .debug = debug
5049 self .question_type = 'gpt-gt-reason'
51- self .fraction = fraction
50+ self .fraction = fraction
51+ self .gpt_model = gpt_model
5252 self .data = self .init_data ()
5353
5454 print (len (self .data ))
@@ -147,16 +147,16 @@ def predict_images(self, images, parsed_item):
147147- `"answer"`: the answer to the question.
148148- `"caption"`: A detailed caption of the video. Used to support the answer.
149149"""
150-
151- if 'o1' in GPT_MODEL :
150+
151+ if 'o1' in self . gpt_model :
152152 system_prompt += format_prompt
153153
154154 print (system_prompt )
155155
156- if 'o1-mini' == GPT_MODEL :
156+ if 'o1-mini' == self . gpt_model :
157157 system_role = "user"
158158 temperature = 1
159- elif 'o1' == GPT_MODEL :
159+ elif 'o1' == self . gpt_model :
160160 system_role = "developer"
161161 else :
162162 system_role = "system"
@@ -167,18 +167,18 @@ def predict_images(self, images, parsed_item):
167167 multi_modal_content = [{"type" : "text" , "text" : "" }] + multi_image_content
168168 user_message = [{"role" : "user" , "content" : multi_modal_content }]
169169
170- kwargs = {'model' : GPT_MODEL ,
170+ kwargs = {'model' : self . gpt_model ,
171171 'messages' : system_message + user_message ,
172172 'response_format' : CaptionResponse ,
173173 'temperature' : temperature }
174174
175- if 'o1' in GPT_MODEL :
175+ if 'o1' in self . gpt_model :
176176 kwargs .pop ('response_format' )
177- if 'o1' == GPT_MODEL :
177+ if 'o1' == self . gpt_model :
178178 kwargs .pop ('temperature' )
179179 pass
180180 #kwargs['reasoning_effort'] = 'high'
181- if 'o1' not in GPT_MODEL :
181+ if 'o1' not in self . gpt_model :
182182 # structural output
183183 response = client .beta .chat .completions .parse (
184184 ** kwargs
@@ -190,7 +190,7 @@ def predict_images(self, images, parsed_item):
190190
191191 total_cost = self .calculate_cost (response )
192192
193- ret = response .choices [0 ].message .parsed if 'o1' not in GPT_MODEL else response .choices [0 ].message
193+ ret = response .choices [0 ].message .parsed if 'o1' not in self . gpt_model else response .choices [0 ].message
194194
195195 return ret
196196
@@ -222,9 +222,7 @@ def run(self, indices = None):
222222
223223 ret [k ] = copy .deepcopy (v )
224224 ret [k ]['caption' ] = caption
225-
226-
227-
225+
228226 if self .debug :
229227 break
230228
0 commit comments