Skip to content

Commit 654569d

Browse files
committed
able to inference o1
1 parent 93d5c43 commit 654569d

File tree

1 file changed

+96
-33
lines changed

1 file changed

+96
-33
lines changed

llava/action/chatgpt_utils.py

Lines changed: 96 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,20 @@
1717
import numpy as np
1818
import base64
1919
from pathlib import Path
20+
import traceback
2021

2122

2223
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
2324

24-
GPT_MODEL = "gpt-4o-2024-08-06"
25+
GPT_MODEL = "o1"
2526

2627
prices = {
27-
"gpt-4o-2024-08-06": {"input": 2.5 / 10**6, "output": 10 / 10**6},
28+
"gpt-4o": {"input": 2.5 / 10**6, "output": 10 / 10**6},
29+
"o1": {"input": 15 / 10**6, "output": 60 / 10**6},
30+
"o1-mini": {"input": 3 / 10**6, "output": 12 / 10**6},
31+
"gpt-4o-mini": {"input": 0.15 / 10**6, "output": 0.6 / 10**6},
2832
}
2933

30-
31-
3234
class ExpandReasonMCPrompt:
3335
"""
3436
Given the reasoning + mc description, create multiple questions
@@ -92,6 +94,27 @@ def generate_prompt(cls, start_second, end_second, option_text, gt_answer):
9294
return prompt
9395

9496

97+
class InferenceAnswer:
    """Normalize a chat-completion response into ``.answer`` / ``.caption``.

    Structured-output models (non-'o1') return an object whose ``answer`` and
    ``caption`` attributes were already parsed by the API client. 'o1'-family
    models return a raw message whose ``content`` is a (possibly markdown-
    fenced) JSON object with ``"answer"`` and ``"caption"`` keys.
    """

    # Class-wide count of responses whose JSON payload failed to parse.
    json_errors = 0

    def __init__(self, answer):
        if 'o1' not in GPT_MODEL:
            # Structured output: fields already parsed by the client.
            self.answer = answer.answer
            self.caption = answer.caption
        else:
            # o1 returns plain text; strip markdown code fences before parsing.
            content = answer.content
            stripped = content.replace('```json', '').replace('```', '').strip()
            try:
                parsed = json.loads(stripped)
            except json.JSONDecodeError:
                # Fix of three defects in the original handler:
                #  - it printed `response_content`, an undefined name (NameError);
                #  - `json_errors += 1` incremented an unbound local instead of
                #    the class counter;
                #  - control then fell through and indexed the *unparsed*
                #    response object, raising a second exception and clobbering
                #    the 'N/A' fallback.
                print(f"Failed to decode JSON response: {content}")
                self.answer = 'N/A'
                self.caption = 'N/A'
                InferenceAnswer.json_errors += 1
            else:
                # KeyError here is deliberate: a well-formed JSON reply missing
                # either key is a prompt-contract violation worth surfacing.
                self.answer = parsed['answer']
                self.caption = parsed['caption']
117+
95118

96119
class GPTStrongReasoningWithGTPrompt:
97120
@classmethod
@@ -133,6 +156,7 @@ class GT_Agnostic_Response(BaseModel):
133156
The GT was known. The response is to add more information to the GT
134157
"""
135158
answer: str
159+
caption: str
136160

137161

138162
class GPTHandObjectResponse(BaseModel):
@@ -329,8 +353,8 @@ def __init__(self,
329353
self.annotation_root = Path(annotation_file).parent
330354
self.action_representation = action_representation
331355
self.labels, self.mapping_vn2narration, self.mapping_vn2act, self.verb_maps, self.noun_maps = generate_label_map(self.annotation_root,
332-
action_representation,
333-
cache_file = os.path.join(self.annotation_root, 'nlp_cache.pkl'))
356+
action_representation)
357+
334358

335359

336360
self.mc_generator = AvionMultiChoiceGenerator(self.annotation_root)
@@ -422,9 +446,14 @@ def run(self, indices=None):
422446
try:
423447
parsed_answer = self.predict_images(frames, v)
424448
except Exception as e:
449+
# get full stack trace
450+
traceback.print_exc()
451+
425452
print ("An exception occurred: ", e)
453+
426454
predicted_answer = parsed_answer.answer
427-
print (predicted_answer)
455+
caption = parsed_answer.caption
456+
print ('caption:', caption)
428457
gt_name = v['gt_answer']
429458
ret[k] = {
430459
'gt_name': gt_name,
@@ -458,28 +487,64 @@ def predict_images(self, images, parsed_item):
458487
time_instruction = f"The provided video lasts for {video_duration:.3f} seconds, and {n_frames} frames are uniformly sampled from it. "
459488

460489
system_prompt = time_instruction + task_related_prompt
490+
491+
format_prompt = """
492+
**Return only a JSON object** with the following two properties:
493+
494+
- `"answer"`: the answer to the question.
495+
- `"caption"`: A detailed caption of the video. Used to support the answer.
496+
"""
497+
498+
if 'o1' in GPT_MODEL:
499+
system_prompt += format_prompt
461500

462501
print (system_prompt)
463502

464503
if self.handobj_root is not None:
465504
system_prompt += f"""To further assist you, we mark hands and object when they are visible. The left hand is marked with a bounding box that contains letter L and the right hand's bounding box contains letter R. The object is marked as 'O'."""
466505

467-
468-
system_message = [{"role": "system", "content": system_prompt}]
506+
if 'o1-mini' == GPT_MODEL:
507+
system_role = "user"
508+
temperature = 1
509+
elif 'o1' == GPT_MODEL:
510+
system_role = "developer"
511+
else:
512+
system_role = "system"
513+
514+
system_message = [{"role": system_role, "content": system_prompt}]
469515

470516
multi_image_content = self.prepare_multiple_images(images)
471517
multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content
472518
user_message = [{"role": "user", "content": multi_modal_content}]
473519

474-
response = client.beta.chat.completions.parse(
475-
model=GPT_MODEL,
476-
messages=system_message + user_message,
477-
response_format = GT_Agnostic_Response,
478-
temperature = temperature
479-
)
520+
kwargs = {'model': GPT_MODEL,
521+
'messages': system_message + user_message,
522+
'response_format': GT_Agnostic_Response,
523+
'temperature': temperature}
524+
525+
if 'o1' in GPT_MODEL:
526+
kwargs.pop('response_format')
527+
if 'o1' == GPT_MODEL:
528+
kwargs.pop('temperature')
529+
pass
530+
#kwargs['reasoning_effort'] = 'high'
531+
if 'o1' not in GPT_MODEL:
532+
# structural output
533+
response = client.beta.chat.completions.parse(
534+
**kwargs
535+
)
536+
else:
537+
response = client.chat.completions.create(
538+
**kwargs
539+
)
540+
480541
total_cost = self.calculate_cost(response)
542+
543+
ret = response.choices[0].message.parsed if 'o1' not in GPT_MODEL else response.choices[0].message
544+
545+
return InferenceAnswer(ret)
546+
481547

482-
return response.choices[0].message.parsed
483548

484549
class GPTHandObjectAnnotator(ChatGPT):
485550
"""
@@ -526,7 +591,6 @@ def run(self, indices):
526591
item['conversations'][1]['value'] = gpt_answer
527592
item['question_type'] = self.anno_type
528593
ret[index] = item
529-
print (item)
530594
if self.debug:
531595
break
532596

@@ -622,7 +686,6 @@ def parse_conversation_from_train_convs(self, item):
622686
option_text = ', '.join(eval(human_dict['value']))
623687
gpt_dict = conversations[1]
624688
gt_answer = gpt_dict['value']
625-
print ('gt_answer', gt_answer)
626689
assert human_dict['from'] == 'human' and gpt_dict['from'] =='gpt'
627690

628691
ret = {'options': option_text,
@@ -683,7 +746,6 @@ def run(self, indices):
683746
item['conversations'][1]['value'] = gpt_answer
684747
item['question_type'] = self.anno_type
685748
ret[index] = item
686-
print (item)
687749
if self.debug:
688750

689751
break
@@ -895,26 +957,27 @@ def convert_instruct_json_to_jsonl(path, apply_filter = False):
895957
# debug = False,
896958
# clip_length = 4,
897959
# n_samples = -1,
898-
# anno_type = 'gpt-gt-strong-reason')
960+
# anno_type = 'gpt-gt-reason')
899961

900-
# root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
901-
# val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
902-
# avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
962+
root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
963+
val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
964+
avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
903965

904966

905-
# annotator = GPTInferenceAnnotator(root,
906-
# val_file,
907-
# avion_prediction_file,
908-
# clip_length = 8,
909-
# debug = False,
910-
# action_representation = "GT_random_narration",
911-
# question_type = 'mc_GT_random_narration',
912-
# topk = 5)
967+
annotator = GPTInferenceAnnotator(root,
968+
val_file,
969+
avion_prediction_file,
970+
clip_length = 8,
971+
debug = False,
972+
action_representation = "GT_random_narration",
973+
question_type = 'mc_GT_random_narration',
974+
topk = 5)
913975

914-
# annotator.multi_process_run(n_samples = 100)
976+
annotator.multi_process_run(n_samples = 100)
977+
print ('# json errors', InferenceAnswer.json_errors)
915978

916979

917-
convert_json_to_jsonl('train_anno_gpt-gt-strong-reason_4_all.json')
980+
#convert_json_to_jsonl('train_anno_gpt-gt-strong-reason_4_all.json')
918981

919982
#convert_instruct_json_to_jsonl('train_anno_gpt-gt-instruct-reason_4_all.json')
920983

0 commit comments

Comments
 (0)