Commit c650950

WIP

1 parent 6016a41 commit c650950

llava/action/generate_comparison_dpo.py

Lines changed: 13 additions & 15 deletions
@@ -22,8 +22,6 @@
 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
 
 
-GPT_MODEL = 'gpt-4o'
-
 class CaptionResponse(BaseModel):
     """
     The GT was known. The response is to add more information to the GT
@@ -37,6 +35,7 @@ def datetime2sec(str):
 
 class CaptionInference(ChatGPT):
     def __init__(self,
+                 gpt_model,
                  root,
                  annotation_file,
                  clip_length = 4,
@@ -48,7 +47,8 @@ def __init__(self,
         self.clip_length = clip_length
         self.debug = debug
         self.question_type = 'gpt-gt-reason'
-        self.fraction = fraction
+        self.fraction = fraction
+        self.gpt_model = gpt_model
         self.data = self.init_data()
 
         print (len(self.data))
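
After this hunk, the model name lives on the instance (self.gpt_model) instead of the deleted module-level GPT_MODEL constant, so each CaptionInference chooses its model at construction time. A minimal usage sketch, assuming the remaining parameters of __init__ keep their defaults; the model string and both paths are placeholders, not values from this commit:

    # Hypothetical caller; 'o1' and both paths are placeholders, not from this commit.
    cap_inf = CaptionInference(
        'o1',                                    # gpt_model (previously hard-coded GPT_MODEL = 'gpt-4o')
        root='/path/to/frames',                  # placeholder
        annotation_file='/path/to/annotations',  # placeholder
        clip_length=4,
    )
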
@@ -147,16 +147,16 @@ def predict_images(self, images, parsed_item):
         - `"answer"`: the answer to the question.
         - `"caption"`: A detailed caption of the video. Used to support the answer.
         """
-
-        if 'o1' in GPT_MODEL:
+
+        if 'o1' in self.gpt_model:
             system_prompt += format_prompt
 
         print (system_prompt)
 
-        if 'o1-mini' == GPT_MODEL:
+        if 'o1-mini' == self.gpt_model:
             system_role = "user"
             temperature = 1
-        elif 'o1' == GPT_MODEL:
+        elif 'o1' == self.gpt_model:
             system_role = "developer"
         else:
             system_role = "system"
@@ -167,18 +167,18 @@ def predict_images(self, images, parsed_item):
         multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content
         user_message = [{"role": "user", "content": multi_modal_content}]
 
-        kwargs = {'model': GPT_MODEL,
+        kwargs = {'model': self.gpt_model,
                   'messages': system_message + user_message,
                   'response_format': CaptionResponse,
                   'temperature': temperature}
 
-        if 'o1' in GPT_MODEL:
+        if 'o1' in self.gpt_model:
             kwargs.pop('response_format')
-        if 'o1' == GPT_MODEL:
+        if 'o1' == self.gpt_model:
             kwargs.pop('temperature')
             pass
             #kwargs['reasoning_effort'] = 'high'
-        if 'o1' not in GPT_MODEL:
+        if 'o1' not in self.gpt_model:
             # structural output
             response = client.beta.chat.completions.parse(
                 **kwargs
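
Taken together with the previous hunk, the request construction now keys entirely off self.gpt_model. The new-side logic of the diff, restated as a stand-alone sketch (build_request_kwargs is a hypothetical helper for illustration only, not a function in this file; CaptionResponse is the Pydantic model defined earlier in the file):

    def build_request_kwargs(gpt_model, system_message, user_message, temperature):
        # Hypothetical helper; mirrors the new-side logic of this commit.
        kwargs = {'model': gpt_model,
                  'messages': system_message + user_message,
                  'response_format': CaptionResponse,
                  'temperature': temperature}
        if 'o1' in gpt_model:
            # o1-family requests are sent without the structured-output schema
            kwargs.pop('response_format')
        if 'o1' == gpt_model:
            # plain 'o1' additionally drops the explicit temperature
            kwargs.pop('temperature')
        return kwargs

Only the non-o1 path then goes through client.beta.chat.completions.parse with CaptionResponse; the next hunk selects .parsed or the raw message from the response accordingly.
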
@@ -190,7 +190,7 @@ def predict_images(self, images, parsed_item):
 
         total_cost = self.calculate_cost(response)
 
-        ret = response.choices[0].message.parsed if 'o1' not in GPT_MODEL else response.choices[0].message
+        ret = response.choices[0].message.parsed if 'o1' not in self.gpt_model else response.choices[0].message
 
         return ret
 
@@ -222,9 +222,7 @@ def run(self, indices = None):
 
             ret[k] = copy.deepcopy(v)
             ret[k]['caption'] = caption
-
-
-
+
             if self.debug:
                 break
 