Commit ac57c7a

Remove env variables

1 parent 4714111 commit ac57c7a

File tree: 1 file changed, +1 -252 lines

custom_rewards/vl_agent.py (1 addition, 252 deletions)
@@ -26,7 +26,7 @@
 
 openai_api_key = "EMPTY"
 openai_api_base_list = [
-    os.environ.get("LLM_AS_A_JUDGE_BASE", "https://sd285v869b9467c7sab70.apigateway-cn-shanghai.volceapi.com/v1"),
+    os.environ.get("LLM_AS_A_JUDGE_BASE", "YOUR_API_BASE"), # e.g. http://localhost:8000/v1
 ]
 
 client_list = []
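With the hardcoded endpoint removed, the judge URL now comes from the `LLM_AS_A_JUDGE_BASE` environment variable. Below is a minimal sketch of how the placeholder might be filled in at runtime, assuming an OpenAI-compatible server such as vLLM; the localhost URL is illustrative, and the client construction only mirrors what `vl_agent.py` appears to do around these lines, it is not a confirmed snippet from the repo:

```python
import os
from openai import OpenAI  # openai>=1.0 style client

# Export LLM_AS_A_JUDGE_BASE before launching training, e.g.:
#   export LLM_AS_A_JUDGE_BASE=http://localhost:8000/v1   (illustrative URL)
api_base = os.environ.get("LLM_AS_A_JUDGE_BASE", "YOUR_API_BASE")

# The script keeps a list of bases, presumably so several judge replicas
# can be load-balanced by random client choice.
openai_api_base_list = [api_base]
client_list = [OpenAI(api_key="EMPTY", base_url=base) for base in openai_api_base_list]
```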
@@ -453,94 +453,6 @@ def compute_score(predict_str: str, ground_truth: str, extra_info=None, **kwargs
     else:
         return (0.8 * acc_reward + 0.2 * format_reward, acc_reward, format_reward)
 
-
-# def rule_math_verify(ground_truth, model_answer):
-#     gold = parse(ground_truth)
-#     answer = parse(model_answer)
-#     return verify(gold, answer)
-
-
-# def generative_verify(query, ground_truth, model_answer):
-#     client_idx = random.randint(0, len(client_list) - 1)
-#     client = client_list[client_idx]
-#     model_name = model_name_list[client_idx]
-
-#     full_prompt = MATH_VERIFY_PROMPT.format(
-#         query=query,
-#         gold_ans=ground_truth,
-#         pred_ans=model_answer,
-#     )
-
-#     response = ""
-#     for it in range(8):
-#         try:
-#             chat_response = client.chat.completions.create(
-#                 model=model_name,
-#                 messages=[
-#                     {"role": "user", "content": full_prompt},
-#                 ],
-#                 seed=random.randint(0, 1000000),
-#                 temperature=0.0,
-#             )
-#             response = chat_response.choices[0].message.content.strip()
-#             break
-#         except Exception as e:
-#             print(f" [ERROR math] generative_verify error: {e}")
-#             continue
-
-#     judgement = response.split("## Equivalence Judgement")[-1].lower()
-#     if "true" in judgement and "false" not in judgement:
-#         return True
-#     elif "false" in judgement and "true" not in judgement:
-#         return False
-#     else:
-#         print(" [ERROR math] verify bug output: ")
-
-
-# def compute_score_math(predict_str: str, ground_truth: str, extra_info=None) -> float:
-#     is_format_error = False
-#     # predict_str = "<think>" + predict_str
-#     count_think_1 = predict_str.count("<think>")
-#     count_think_2 = predict_str.count("</think>")
-#     if count_think_1 != count_think_2 or count_think_1 == 0:  # reward hacking
-#         is_format_error = True
-
-#     predict_no_think = predict_str.split("</think>")[-1].strip()
-#     count_answer_1 = predict_no_think.count("<answer>")
-#     count_answer_2 = predict_no_think.count("</answer>")
-#     if count_answer_1 != count_answer_2 or count_answer_1 == 0:
-#         is_format_error = True
-
-#     # extract answer content from answer tag
-#     if count_answer_1 == 0 or count_answer_2 == 0:
-#         answer_content = ""
-#     else:
-#         answer_content = predict_str.split("<answer>")[-1].split("</answer>")[0].strip()
-
-#     model_answer = ""
-#     if answer_content == "":
-#         acc_reward = 0.0
-#     else:
-#         answer_pattern = r"\\boxed{([^}]+)}"
-#         answer_list = re.findall(answer_pattern, answer_content, flags=re.DOTALL)
-#         if len(answer_list) == 0:
-#             acc_reward = 0.0
-#             is_format_error = True
-#         else:
-#             if len(answer_list) > 1:
-#                 is_format_error = True
-
-#             model_answer = answer_list[-1]
-#             if rule_math_verify(ground_truth, model_answer):
-#                 acc_reward = 1.0
-#             else:
-#                 acc_reward = 1.0 if generative_verify(extra_info["question"], ground_truth, model_answer) else 0.0
-
-#     format_reward = 0.0 if is_format_error else 1.0
-
-#     return 0.8 * acc_reward + 0.2 * format_reward, acc_reward, format_reward
-
-
 def compute_score_time_r1(predict_str: str, ground_truth: str, extra_info=None, use_recall=False) -> float:
     is_format_error = False
     # predict_str = "<think>" + predict_str
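For reference, the deleted `rule_math_verify` above leaned on `parse`/`verify`; here is a standalone sketch of that rule-based check, assuming those helpers come from the math-verify package (an assumption, since the file's imports are not shown in this diff):

```python
# Standalone sketch of the removed rule-based equivalence check.
# Assumes the math-verify package (pip install math-verify) supplied
# the parse/verify helpers used by the deleted code.
from math_verify import parse, verify

def rule_math_verify(ground_truth: str, model_answer: str) -> bool:
    gold = parse(ground_truth)
    answer = parse(model_answer)
    return verify(gold, answer)  # True when the expressions are equivalent

print(rule_math_verify("\\boxed{\\frac{1}{2}}", "0.5"))  # expected: True
```

When this rule-based check failed, the deleted code fell back to `generative_verify`, which asked an LLM judge for an "## Equivalence Judgement" of true/false.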
@@ -633,166 +545,3 @@ def compute_score_time_r1(predict_str: str, ground_truth: str, extra_info=None,
     format_reward = 0.0 if is_format_error else 1.0
 
     return 1.0 * acc_reward + 1.0 * format_reward, acc_reward, format_reward
-
-
-# def compute_score_videor1(predict_str: str, ground_truth: str, extra_info=None, **kwargs) -> float:
-#     """
-#     Video-R1 style reward computation with accuracy and format rewards.
-
-#     Args:
-#         predict_str: Model prediction string
-#         ground_truth: Ground truth answer
-#         extra_info: Dictionary containing additional info like 'question' and 'problem_type'
-#         **kwargs: Additional arguments like temporal, len_control settings
-
-#     Returns:
-#         Tuple of (total_reward, accuracy_reward, format_reward) or just total_reward
-#     """
-
-#     def extract_answer_videor1(text):
-#         """Extract answer from <answer></answer> tags"""
-#         pattern = r"<answer>\s*(.*?)\s*</answer>"
-#         match = re.search(pattern, text, re.DOTALL)
-#         if match:
-#             return match.group(1).strip()
-#         return ""
-
-#     def normalize_number(num_str):
-#         """Normalize number string to float"""
-#         try:
-#             num_str = num_str.replace(",", "")
-#             return float(num_str)
-#         except Exception as e:
-#             print(f"Error converting '{num_str}' to float: {e}")
-#             return None
-
-#     def wer(reference, hypothesis):
-#         """Word Error Rate calculation"""
-#         ref_words = reference.split()
-#         hyp_words = hypothesis.split()
-#         m = len(ref_words)
-#         n = len(hyp_words)
-#         d = [[0] * (n + 1) for _ in range(m + 1)]
-#         for i in range(m + 1):
-#             d[i][0] = i
-#         for j in range(n + 1):
-#             d[0][j] = j
-#         for i in range(1, m + 1):
-#             for j in range(1, n + 1):
-#                 if ref_words[i - 1] == hyp_words[j - 1]:
-#                     d[i][j] = d[i - 1][j - 1]
-#                 else:
-#                     d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
-#         return d[m][n] / max(1, m)
-
-#     def compute_rouge_score(reference, hypothesis, use_stemmer=True):
-#         """Compute ROUGE score"""
-#         scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=use_stemmer)
-#         scores = scorer.score(reference, hypothesis)
-#         average_fmeasure = (scores["rouge1"].fmeasure + scores["rouge2"].fmeasure + scores["rougeL"].fmeasure) / 3
-#         return average_fmeasure
-
-#     # Initialize variables
-#     is_format_error = False
-
-#     # Format checking - Video-R1 style
-#     count_think_1 = predict_str.count("<think>")
-#     count_think_2 = predict_str.count("</think>")
-#     count_answer_1 = predict_str.count("<answer>")
-#     count_answer_2 = predict_str.count("</answer>")
-
-#     # Check format requirements
-#     if count_think_1 != count_think_2 or count_think_1 == 0:
-#         is_format_error = True
-#     if count_answer_1 != count_answer_2 or count_answer_1 == 0:
-#         is_format_error = True
-
-#     # Check basic format pattern
-#     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-#     if not re.search(pattern, predict_str, re.DOTALL):
-#         is_format_error = True
-
-#     # Extract answer
-#     if count_answer_1 == 0 or count_answer_2 == 0:
-#         answer_text = ""
-#     else:
-#         answer_text = extract_answer_videor1(predict_str)
-
-#     # Compute accuracy reward
-#     # If there's a format error, set accuracy reward to 0 regardless of answer content
-#     if answer_text == "" or is_format_error:
-#         acc_reward = 0.0
-#     else:
-#         try:
-#             # Get problem type from extra_info
-#             problem_type = extra_info.get("problem_type", "free-form") if extra_info else "free-form"
-
-#             output_ans = answer_text
-#             gt_ans = extract_answer_videor1(ground_truth) if "<answer>" in ground_truth else ground_truth
-
-#             if problem_type == "multiple choice":
-#                 acc_reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
-#             elif problem_type == "numerical":
-#                 gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
-#                 out_has_decimal = ("." in output_ans) or ("," in output_ans)
-#                 if gt_has_decimal != out_has_decimal:
-#                     acc_reward = 0.0
-#                 else:
-#                     gt_number = normalize_number(gt_ans)
-#                     out_number = normalize_number(output_ans)
-#                     if gt_number is None or out_number is None:
-#                         acc_reward = 0.0
-#                     else:
-#                         acc_reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
-#             elif problem_type == "OCR":
-#                 error_rate = wer(gt_ans, output_ans)
-#                 acc_reward = 1 - error_rate
-#                 acc_reward = max(0.0, min(1.0, acc_reward))
-#             elif problem_type == "free-form":
-#                 score = compute_rouge_score(gt_ans, output_ans)
-#                 acc_reward = max(0.0, min(1.0, score))
-#             elif problem_type == "regression":
-#                 gt_number = normalize_number(gt_ans)
-#                 out_number = normalize_number(output_ans)
-#                 if gt_number is None or out_number is None:
-#                     acc_reward = 0.0
-#                 else:
-#                     rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
-#                     rel_diff = min(1.0, max(0.0, rel_diff))
-#                     acc_reward = 1 - rel_diff
-#             else:
-#                 # Unknown problem type - return 0 (same as original Video-R1)
-#                 acc_reward = 0.0
-
-#         except Exception as e:
-#             print(f"Error in Video-R1 accuracy reward computation: {e}")
-#             acc_reward = 0.0
-
-#     # Penalize for overly long answers
-#     if len(answer_text) >= 1000:
-#         acc_reward = 0.0
-#         is_format_error = True
-
-#     # Format reward
-#     format_reward = 0.0 if is_format_error else 1.0
-
-#     # Debug logging
-#     if os.getenv("DEBUG_MODE") == "true":
-#         log_path = os.getenv("LOG_PATH", "./videor1_debug.log")
-#         current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
-#         with open(log_path, "a", encoding="utf-8") as f:
-#             f.write(f"------------- {current_time} Video-R1 Reward -------------\n")
-#             f.write(f"Prediction: {predict_str}\n")
-#             f.write(f"Ground Truth: {ground_truth}\n")
-#             f.write(f"Extracted Answer: {answer_text}\n")
-#             f.write(f"Accuracy Reward: {acc_reward}\n")
-#             f.write(f"Format Reward: {format_reward}\n")
-#             f.write(f"Format Error: {is_format_error}\n")
-#             f.write("=" * 50 + "\n")
-
-#     return (acc_reward + format_reward, acc_reward, format_reward)
-
-
-# if __name__ == "__main__":
-#     # Test the new Video-R1 reward function
-#     test_compute_score_videor1()
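The tag-balance check is the one pattern shared by the deleted functions and the surviving `compute_score`/`compute_score_time_r1`. A minimal standalone restatement of it, with the logic taken from the deleted `compute_score_math` and an illustrative function name:

```python
def has_valid_think_answer_format(predict_str: str) -> bool:
    """Illustrative restatement of the shared <think>/<answer> format check."""
    # Matched, non-zero <think> tag counts guard against reward hacking.
    if predict_str.count("<think>") != predict_str.count("</think>") or predict_str.count("<think>") == 0:
        return False
    # <answer> tags are only counted after the final </think>.
    tail = predict_str.split("</think>")[-1]
    if tail.count("<answer>") != tail.count("</answer>") or tail.count("<answer>") == 0:
        return False
    return True

print(has_valid_think_answer_format("<think>reasoning</think><answer>42</answer>"))  # True
print(has_valid_think_answer_format("<answer>42</answer>"))                          # False
```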
