|
17 | 17 | import numpy as np |
18 | 18 | import base64 |
19 | 19 | from pathlib import Path |
| 20 | +import traceback |
20 | 21 |
|
21 | 22 |
|
22 | 23 | client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) |
23 | 24 |
|
24 | | -GPT_MODEL = "gpt-4o-2024-08-06" |
| 25 | +GPT_MODEL = "o1" |
25 | 26 |
|
26 | 27 | prices = { |
27 | | - "gpt-4o-2024-08-06": {"input": 2.5 / 10**6, "output": 10 / 10**6}, |
| 28 | + "gpt-4o": {"input": 2.5 / 10**6, "output": 10 / 10**6}, |
| 29 | + "o1": {"input": 15 / 10**6, "output": 60 / 10**6}, |
| 30 | + "o1-mini": {"input": 3 / 10**6, "output": 12 / 10**6}, |
| 31 | + "gpt-4o-mini": {"input": 0.15 / 10**6, "output": 0.6 / 10**6}, |
28 | 32 | } |
29 | 33 |
|
30 | | - |
31 | | - |
32 | 34 | class ExpandReasonMCPrompt: |
33 | 35 | """ |
34 | 36 | Given the reasoning + mc description, create multiple questions |
@@ -92,6 +94,27 @@ def generate_prompt(cls, start_second, end_second, option_text, gt_answer): |
92 | 94 | return prompt |
93 | 95 |
|
94 | 96 |
|
class InferenceAnswer:
    """Normalize a chat-completion response into `.answer` / `.caption`.

    Two response shapes are handled, keyed off the module-level GPT_MODEL:
    - non-o1 models: `answer` is an already-parsed structured-output object
      (pydantic `GT_Agnostic_Response`); its fields are copied directly.
    - o1 models: `answer` is a raw chat message whose `.content` holds a
      JSON payload, possibly wrapped in ```json ... ``` markdown fences,
      which must be stripped and decoded here.
    """

    # Class-wide counter of responses whose JSON payload failed to decode.
    json_errors = 0

    def __init__(self, answer):
        if 'o1' not in GPT_MODEL:
            # Structured-output path: the API client already parsed the fields.
            self.answer = answer.answer
            self.caption = answer.caption
            return

        # o1 path: strip markdown code fences, then decode the JSON body.
        payload = answer.content.replace('```json', '').replace('```', '').strip()
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError:
            # Fixes three defects in the original handler: it printed an
            # undefined name (`response_content`), incremented an unbound
            # local (`json_errors`) instead of the class attribute, and then
            # fell through to index the un-parsed message object, clobbering
            # the 'N/A' fallback with a TypeError. Record the failure and
            # bail out with the sentinel values instead.
            print(f"Failed to decode JSON response: {payload}")
            InferenceAnswer.json_errors += 1
            self.answer = 'N/A'
            self.caption = 'N/A'
            return

        self.answer = parsed['answer']
        self.caption = parsed['caption']
95 | 118 |
|
96 | 119 | class GPTStrongReasoningWithGTPrompt: |
97 | 120 | @classmethod |
@@ -133,6 +156,7 @@ class GT_Agnostic_Response(BaseModel): |
133 | 156 | The GT was known. The response is to add more information to the GT |
134 | 157 | """ |
135 | 158 | answer: str |
| 159 | + caption: str |
136 | 160 |
|
137 | 161 |
|
138 | 162 | class GPTHandObjectResponse(BaseModel): |
@@ -329,8 +353,8 @@ def __init__(self, |
329 | 353 | self.annotation_root = Path(annotation_file).parent |
330 | 354 | self.action_representation = action_representation |
331 | 355 | self.labels, self.mapping_vn2narration, self.mapping_vn2act, self.verb_maps, self.noun_maps = generate_label_map(self.annotation_root, |
332 | | - action_representation, |
333 | | - cache_file = os.path.join(self.annotation_root, 'nlp_cache.pkl')) |
| 356 | + action_representation) |
| 357 | + |
334 | 358 |
|
335 | 359 |
|
336 | 360 | self.mc_generator = AvionMultiChoiceGenerator(self.annotation_root) |
@@ -422,9 +446,14 @@ def run(self, indices=None): |
422 | 446 | try: |
423 | 447 | parsed_answer = self.predict_images(frames, v) |
424 | 448 | except Exception as e: |
| 449 | + # get full stack trace |
| 450 | + traceback.print_exc() |
| 451 | + |
425 | 452 | print ("An exception occurred: ", e) |
| 453 | + |
426 | 454 | predicted_answer = parsed_answer.answer |
427 | | - print (predicted_answer) |
| 455 | + caption = parsed_answer.caption |
| 456 | + print ('caption:', caption) |
428 | 457 | gt_name = v['gt_answer'] |
429 | 458 | ret[k] = { |
430 | 459 | 'gt_name': gt_name, |
@@ -458,28 +487,64 @@ def predict_images(self, images, parsed_item): |
458 | 487 | time_instruction = f"The provided video lasts for {video_duration:.3f} seconds, and {n_frames} frames are uniformly sampled from it. " |
459 | 488 |
|
460 | 489 | system_prompt = time_instruction + task_related_prompt |
| 490 | + |
| 491 | + format_prompt = """ |
| 492 | +**Return only a JSON object** with the following two properties: |
| 493 | +
|
| 494 | +- `"answer"`: the answer to the question. |
| 495 | +- `"caption"`: A detailed caption of the video. Used to support the answer. |
| 496 | +""" |
| 497 | + |
| 498 | + if 'o1' in GPT_MODEL: |
| 499 | + system_prompt += format_prompt |
461 | 500 |
|
462 | 501 | print (system_prompt) |
463 | 502 |
|
464 | 503 | if self.handobj_root is not None: |
465 | 504 | system_prompt += f"""To further assist you, we mark hands and object when they are visible. The left hand is marked with a bounding box that contains letter L and the right hand's bounding box contains letter R. The object is marked as 'O'.""" |
466 | 505 |
|
467 | | - |
468 | | - system_message = [{"role": "system", "content": system_prompt}] |
| 506 | + if 'o1-mini' == GPT_MODEL: |
| 507 | + system_role = "user" |
| 508 | + temperature = 1 |
| 509 | + elif 'o1' == GPT_MODEL: |
| 510 | + system_role = "developer" |
| 511 | + else: |
| 512 | + system_role = "system" |
| 513 | + |
| 514 | + system_message = [{"role": system_role, "content": system_prompt}] |
469 | 515 |
|
470 | 516 | multi_image_content = self.prepare_multiple_images(images) |
471 | 517 | multi_modal_content = [{"type": "text", "text": ""}] + multi_image_content |
472 | 518 | user_message = [{"role": "user", "content": multi_modal_content}] |
473 | 519 |
|
474 | | - response = client.beta.chat.completions.parse( |
475 | | - model=GPT_MODEL, |
476 | | - messages=system_message + user_message, |
477 | | - response_format = GT_Agnostic_Response, |
478 | | - temperature = temperature |
479 | | - ) |
| 520 | + kwargs = {'model': GPT_MODEL, |
| 521 | + 'messages': system_message + user_message, |
| 522 | + 'response_format': GT_Agnostic_Response, |
| 523 | + 'temperature': temperature} |
| 524 | + |
| 525 | + if 'o1' in GPT_MODEL: |
| 526 | + kwargs.pop('response_format') |
| 527 | + if 'o1' == GPT_MODEL: |
| 528 | + kwargs.pop('temperature') |
| 529 | + pass |
| 530 | + #kwargs['reasoning_effort'] = 'high' |
| 531 | + if 'o1' not in GPT_MODEL: |
| 532 | + # structural output |
| 533 | + response = client.beta.chat.completions.parse( |
| 534 | + **kwargs |
| 535 | + ) |
| 536 | + else: |
| 537 | + response = client.chat.completions.create( |
| 538 | + **kwargs |
| 539 | + ) |
| 540 | + |
480 | 541 | total_cost = self.calculate_cost(response) |
| 542 | + |
| 543 | + ret = response.choices[0].message.parsed if 'o1' not in GPT_MODEL else response.choices[0].message |
| 544 | + |
| 545 | + return InferenceAnswer(ret) |
| 546 | + |
481 | 547 |
|
482 | | - return response.choices[0].message.parsed |
483 | 548 |
|
484 | 549 | class GPTHandObjectAnnotator(ChatGPT): |
485 | 550 | """ |
@@ -526,7 +591,6 @@ def run(self, indices): |
526 | 591 | item['conversations'][1]['value'] = gpt_answer |
527 | 592 | item['question_type'] = self.anno_type |
528 | 593 | ret[index] = item |
529 | | - print (item) |
530 | 594 | if self.debug: |
531 | 595 | break |
532 | 596 |
|
@@ -622,7 +686,6 @@ def parse_conversation_from_train_convs(self, item): |
622 | 686 | option_text = ', '.join(eval(human_dict['value'])) |
623 | 687 | gpt_dict = conversations[1] |
624 | 688 | gt_answer = gpt_dict['value'] |
625 | | - print ('gt_answer', gt_answer) |
626 | 689 | assert human_dict['from'] == 'human' and gpt_dict['from'] =='gpt' |
627 | 690 |
|
628 | 691 | ret = {'options': option_text, |
@@ -683,7 +746,6 @@ def run(self, indices): |
683 | 746 | item['conversations'][1]['value'] = gpt_answer |
684 | 747 | item['question_type'] = self.anno_type |
685 | 748 | ret[index] = item |
686 | | - print (item) |
687 | 749 | if self.debug: |
688 | 750 |
|
689 | 751 | break |
@@ -895,26 +957,27 @@ def convert_instruct_json_to_jsonl(path, apply_filter = False): |
895 | 957 | # debug = False, |
896 | 958 | # clip_length = 4, |
897 | 959 | # n_samples = -1, |
898 | | - # anno_type = 'gpt-gt-strong-reason') |
| 960 | + # anno_type = 'gpt-gt-reason') |
899 | 961 |
|
900 | | - # root = '/data/EK100/EK100_320p_15sec_30fps_libx264' |
901 | | - # val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv' |
902 | | - # avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json' |
| 962 | + root = '/data/EK100/EK100_320p_15sec_30fps_libx264' |
| 963 | + val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv' |
| 964 | + avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json' |
903 | 965 |
|
904 | 966 |
|
905 | | - # annotator = GPTInferenceAnnotator(root, |
906 | | - # val_file, |
907 | | - # avion_prediction_file, |
908 | | - # clip_length = 8, |
909 | | - # debug = False, |
910 | | - # action_representation = "GT_random_narration", |
911 | | - # question_type = 'mc_GT_random_narration', |
912 | | - # topk = 5) |
| 967 | + annotator = GPTInferenceAnnotator(root, |
| 968 | + val_file, |
| 969 | + avion_prediction_file, |
| 970 | + clip_length = 8, |
| 971 | + debug = False, |
| 972 | + action_representation = "GT_random_narration", |
| 973 | + question_type = 'mc_GT_random_narration', |
| 974 | + topk = 5) |
913 | 975 |
|
914 | | - # annotator.multi_process_run(n_samples = 100) |
| 976 | + annotator.multi_process_run(n_samples = 100) |
| 977 | + print ('# json errors', InferenceAnswer.json_errors) |
915 | 978 |
|
916 | 979 |
|
917 | | - convert_json_to_jsonl('train_anno_gpt-gt-strong-reason_4_all.json') |
| 980 | + #convert_json_to_jsonl('train_anno_gpt-gt-strong-reason_4_all.json') |
918 | 981 |
|
919 | 982 | #convert_instruct_json_to_jsonl('train_anno_gpt-gt-instruct-reason_4_all.json') |
920 | 983 |
|
|
0 commit comments