|
26 | 26 |
|
27 | 27 | openai_api_key = "EMPTY" |
28 | 28 | openai_api_base_list = [ |
29 | | - os.environ.get("LLM_AS_A_JUDGE_BASE", "https://sd285v869b9467c7sab70.apigateway-cn-shanghai.volceapi.com/v1"), |
| 29 | + os.environ.get("LLM_AS_A_JUDGE_BASE", "YOUR_API_BASE"), # e.g. http://localhost:8000/v1 |
30 | 30 | ] |
31 | 31 |
|
32 | 32 | client_list = [] |
@@ -453,94 +453,6 @@ def compute_score(predict_str: str, ground_truth: str, extra_info=None, **kwargs |
453 | 453 | else: |
454 | 454 | return (0.8 * acc_reward + 0.2 * format_reward, acc_reward, format_reward) |
455 | 455 |
|
456 | | - |
457 | | -# def rule_math_verify(ground_truth, model_answer): |
458 | | -# gold = parse(ground_truth) |
459 | | -# answer = parse(model_answer) |
460 | | -# return verify(gold, answer) |
461 | | - |
462 | | - |
463 | | -# def generative_verify(query, ground_truth, model_answer): |
464 | | -# client_idx = random.randint(0, len(client_list) - 1) |
465 | | -# client = client_list[client_idx] |
466 | | -# model_name = model_name_list[client_idx] |
467 | | - |
468 | | -# full_prompt = MATH_VERIFY_PROMPT.format( |
469 | | -# query=query, |
470 | | -# gold_ans=ground_truth, |
471 | | -# pred_ans=model_answer, |
472 | | -# ) |
473 | | - |
474 | | -# response = "" |
475 | | -# for it in range(8): |
476 | | -# try: |
477 | | -# chat_response = client.chat.completions.create( |
478 | | -# model=model_name, |
479 | | -# messages=[ |
480 | | -# {"role": "user", "content": full_prompt}, |
481 | | -# ], |
482 | | -# seed=random.randint(0, 1000000), |
483 | | -# temperature=0.0, |
484 | | -# ) |
485 | | -# response = chat_response.choices[0].message.content.strip() |
486 | | -# break |
487 | | -# except Exception as e: |
488 | | -# print(f" [ERROR math] generative_verify error: {e}") |
489 | | -# continue |
490 | | - |
491 | | -# judgement = response.split("## Equivalence Judgement")[-1].lower() |
492 | | -# if "true" in judgement and "false" not in judgement: |
493 | | -# return True |
494 | | -# elif "false" in judgement and "true" not in judgement: |
495 | | -# return False |
496 | | -# else: |
497 | | -# print(" [ERROR math] verify bug output: ") |
498 | | - |
499 | | - |
500 | | -# def compute_score_math(predict_str: str, ground_truth: str, extra_info=None) -> float: |
501 | | -# is_format_error = False |
502 | | -# # predict_str = "<think>" + predict_str |
503 | | -# count_think_1 = predict_str.count("<think>") |
504 | | -# count_think_2 = predict_str.count("</think>") |
505 | | -# if count_think_1 != count_think_2 or count_think_1 == 0: # reward hacking |
506 | | -# is_format_error = True |
507 | | - |
508 | | -# predict_no_think = predict_str.split("</think>")[-1].strip() |
509 | | -# count_answer_1 = predict_no_think.count("<answer>") |
510 | | -# count_answer_2 = predict_no_think.count("</answer>") |
511 | | -# if count_answer_1 != count_answer_2 or count_answer_1 == 0: |
512 | | -# is_format_error = True |
513 | | - |
514 | | -# # extract answer content from answer tag |
515 | | -# if count_answer_1 == 0 or count_answer_2 == 0: |
516 | | -# answer_content = "" |
517 | | -# else: |
518 | | -# answer_content = predict_str.split("<answer>")[-1].split("</answer>")[0].strip() |
519 | | - |
520 | | -# model_answer = "" |
521 | | -# if answer_content == "": |
522 | | -# acc_reward = 0.0 |
523 | | -# else: |
524 | | -# answer_pattern = r"\\boxed{([^}]+)}" |
525 | | -# answer_list = re.findall(answer_pattern, answer_content, flags=re.DOTALL) |
526 | | -# if len(answer_list) == 0: |
527 | | -# acc_reward = 0.0 |
528 | | -# is_format_error = True |
529 | | -# else: |
530 | | -# if len(answer_list) > 1: |
531 | | -# is_format_error = True |
532 | | - |
533 | | -# model_answer = answer_list[-1] |
534 | | -# if rule_math_verify(ground_truth, model_answer): |
535 | | -# acc_reward = 1.0 |
536 | | -# else: |
537 | | -# acc_reward = 1.0 if generative_verify(extra_info["question"], ground_truth, model_answer) else 0.0 |
538 | | - |
539 | | -# format_reward = 0.0 if is_format_error else 1.0 |
540 | | - |
541 | | -# return 0.8 * acc_reward + 0.2 * format_reward, acc_reward, format_reward |
542 | | - |
543 | | - |
544 | 456 | def compute_score_time_r1(predict_str: str, ground_truth: str, extra_info=None, use_recall=False) -> float: |
545 | 457 | is_format_error = False |
546 | 458 | # predict_str = "<think>" + predict_str |
@@ -633,166 +545,3 @@ def compute_score_time_r1(predict_str: str, ground_truth: str, extra_info=None, |
633 | 545 | format_reward = 0.0 if is_format_error else 1.0 |
634 | 546 |
|
635 | 547 | return 1.0 * acc_reward + 1.0 * format_reward, acc_reward, format_reward |
636 | | - |
637 | | - |
638 | | -# def compute_score_videor1(predict_str: str, ground_truth: str, extra_info=None, **kwargs) -> float: |
639 | | -# """ |
640 | | -# Video-R1 style reward computation with accuracy and format rewards. |
641 | | - |
642 | | -# Args: |
643 | | -# predict_str: Model prediction string |
644 | | -# ground_truth: Ground truth answer |
645 | | -# extra_info: Dictionary containing additional info like 'question' and 'problem_type' |
646 | | -# **kwargs: Additional arguments like temporal, len_control settings |
647 | | - |
648 | | -# Returns: |
649 | | -# Tuple of (total_reward, accuracy_reward, format_reward) or just total_reward |
650 | | -# """ |
651 | | - |
652 | | -# def extract_answer_videor1(text): |
653 | | -# """Extract answer from <answer></answer> tags""" |
654 | | -# pattern = r"<answer>\s*(.*?)\s*</answer>" |
655 | | -# match = re.search(pattern, text, re.DOTALL) |
656 | | -# if match: |
657 | | -# return match.group(1).strip() |
658 | | -# return "" |
659 | | - |
660 | | -# def normalize_number(num_str): |
661 | | -# """Normalize number string to float""" |
662 | | -# try: |
663 | | -# num_str = num_str.replace(",", "") |
664 | | -# return float(num_str) |
665 | | -# except Exception as e: |
666 | | -# print(f"Error converting '{num_str}' to float: {e}") |
667 | | -# return None |
668 | | - |
669 | | -# def wer(reference, hypothesis): |
670 | | -# """Word Error Rate calculation""" |
671 | | -# ref_words = reference.split() |
672 | | -# hyp_words = hypothesis.split() |
673 | | -# m = len(ref_words) |
674 | | -# n = len(hyp_words) |
675 | | -# d = [[0] * (n + 1) for _ in range(m + 1)] |
676 | | -# for i in range(m + 1): |
677 | | -# d[i][0] = i |
678 | | -# for j in range(n + 1): |
679 | | -# d[0][j] = j |
680 | | -# for i in range(1, m + 1): |
681 | | -# for j in range(1, n + 1): |
682 | | -# if ref_words[i - 1] == hyp_words[j - 1]: |
683 | | -# d[i][j] = d[i - 1][j - 1] |
684 | | -# else: |
685 | | -# d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) |
686 | | -# return d[m][n] / max(1, m) |
687 | | - |
688 | | -# def compute_rouge_score(reference, hypothesis, use_stemmer=True): |
689 | | -# """Compute ROUGE score""" |
690 | | -# scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=use_stemmer) |
691 | | -# scores = scorer.score(reference, hypothesis) |
692 | | -# average_fmeasure = (scores["rouge1"].fmeasure + scores["rouge2"].fmeasure + scores["rougeL"].fmeasure) / 3 |
693 | | -# return average_fmeasure |
694 | | - |
695 | | -# # Initialize variables |
696 | | -# is_format_error = False |
697 | | - |
698 | | -# # Format checking - Video-R1 style |
699 | | -# count_think_1 = predict_str.count("<think>") |
700 | | -# count_think_2 = predict_str.count("</think>") |
701 | | -# count_answer_1 = predict_str.count("<answer>") |
702 | | -# count_answer_2 = predict_str.count("</answer>") |
703 | | - |
704 | | -# # Check format requirements |
705 | | -# if count_think_1 != count_think_2 or count_think_1 == 0: |
706 | | -# is_format_error = True |
707 | | -# if count_answer_1 != count_answer_2 or count_answer_1 == 0: |
708 | | -# is_format_error = True |
709 | | - |
710 | | -# # Check basic format pattern |
711 | | -# pattern = r"<think>.*?</think>\s*<answer>.*?</answer>" |
712 | | -# if not re.search(pattern, predict_str, re.DOTALL): |
713 | | -# is_format_error = True |
714 | | - |
715 | | -# # Extract answer |
716 | | -# if count_answer_1 == 0 or count_answer_2 == 0: |
717 | | -# answer_text = "" |
718 | | -# else: |
719 | | -# answer_text = extract_answer_videor1(predict_str) |
720 | | - |
721 | | -# # Compute accuracy reward |
722 | | -# # If there's a format error, set accuracy reward to 0 regardless of answer content |
723 | | -# if answer_text == "" or is_format_error: |
724 | | -# acc_reward = 0.0 |
725 | | -# else: |
726 | | -# try: |
727 | | -# # Get problem type from extra_info |
728 | | -# problem_type = extra_info.get("problem_type", "free-form") if extra_info else "free-form" |
729 | | - |
730 | | -# output_ans = answer_text |
731 | | -# gt_ans = extract_answer_videor1(ground_truth) if "<answer>" in ground_truth else ground_truth |
732 | | - |
733 | | -# if problem_type == "multiple choice": |
734 | | -# acc_reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0 |
735 | | -# elif problem_type == "numerical": |
736 | | -# gt_has_decimal = ("." in gt_ans) or ("," in gt_ans) |
737 | | -# out_has_decimal = ("." in output_ans) or ("," in output_ans) |
738 | | -# if gt_has_decimal != out_has_decimal: |
739 | | -# acc_reward = 0.0 |
740 | | -# else: |
741 | | -# gt_number = normalize_number(gt_ans) |
742 | | -# out_number = normalize_number(output_ans) |
743 | | -# if gt_number is None or out_number is None: |
744 | | -# acc_reward = 0.0 |
745 | | -# else: |
746 | | -# acc_reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0 |
747 | | -# elif problem_type == "OCR": |
748 | | -# error_rate = wer(gt_ans, output_ans) |
749 | | -# acc_reward = 1 - error_rate |
750 | | -# acc_reward = max(0.0, min(1.0, acc_reward)) |
751 | | -# elif problem_type == "free-form": |
752 | | -# score = compute_rouge_score(gt_ans, output_ans) |
753 | | -# acc_reward = max(0.0, min(1.0, score)) |
754 | | -# elif problem_type == "regression": |
755 | | -# gt_number = normalize_number(gt_ans) |
756 | | -# out_number = normalize_number(output_ans) |
757 | | -# if gt_number is None or out_number is None: |
758 | | -# acc_reward = 0.0 |
759 | | -# else: |
760 | | -# rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9) |
761 | | -# rel_diff = min(1.0, max(0.0, rel_diff)) |
762 | | -# acc_reward = 1 - rel_diff |
763 | | -# else: |
764 | | -# # Unknown problem type - return 0 (same as original Video-R1) |
765 | | -# acc_reward = 0.0 |
766 | | - |
767 | | -# except Exception as e: |
768 | | -# print(f"Error in Video-R1 accuracy reward computation: {e}") |
769 | | -# acc_reward = 0.0 |
770 | | - |
771 | | -# # Penalize for overly long answers |
772 | | -# if len(answer_text) >= 1000: |
773 | | -# acc_reward = 0.0 |
774 | | -# is_format_error = True |
775 | | - |
776 | | -# # Format reward |
777 | | -# format_reward = 0.0 if is_format_error else 1.0 |
778 | | - |
779 | | -# # Debug logging |
780 | | -# if os.getenv("DEBUG_MODE") == "true": |
781 | | -# log_path = os.getenv("LOG_PATH", "./videor1_debug.log") |
782 | | -# current_time = datetime.now().strftime("%d-%H-%M-%S-%f") |
783 | | -# with open(log_path, "a", encoding="utf-8") as f: |
784 | | -# f.write(f"------------- {current_time} Video-R1 Reward -------------\n") |
785 | | -# f.write(f"Prediction: {predict_str}\n") |
786 | | -# f.write(f"Ground Truth: {ground_truth}\n") |
787 | | -# f.write(f"Extracted Answer: {answer_text}\n") |
788 | | -# f.write(f"Accuracy Reward: {acc_reward}\n") |
789 | | -# f.write(f"Format Reward: {format_reward}\n") |
790 | | -# f.write(f"Format Error: {is_format_error}\n") |
791 | | -# f.write("=" * 50 + "\n") |
792 | | - |
793 | | -# return (acc_reward + format_reward, acc_reward, format_reward) |
794 | | - |
795 | | - |
796 | | -# if __name__ == "__main__": |
797 | | -# # 测试新的Video-R1 reward函数 |
798 | | -# test_compute_score_videor1() |
0 commit comments