diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py
index 4538102861..5202588d9a 100644
--- a/ding/reward_model/__init__.py
+++ b/ding/reward_model/__init__.py
@@ -13,3 +13,7 @@ from .guided_cost_reward_model import GuidedCostRewardModel
 from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
 from .icm_reward_model import ICMRewardModel
+# LLM/VLM reward models and verifiers
+from .math_reward_model import MathRewardModel
+from .math_rule_reward_model import MathRuleRewardModel
+from .multi_modal_reward_model import MultiModalRewardModel
diff --git a/ding/reward_model/math_reward_model.py b/ding/reward_model/math_reward_model.py
new file mode 100644
index 0000000000..90a5e3a0d3
--- /dev/null
+++ b/ding/reward_model/math_reward_model.py
@@ -0,0 +1,144 @@
+from typing import List, Dict
+from easydict import EasyDict
+from torch.utils.tensorboard import SummaryWriter
+from transformers import AutoTokenizer, AutoModel
+import torch
+import torch.nn.functional as F
+from ding.utils import REWARD_MODEL_REGISTRY
+from .base_reward_model import BaseRewardModel
+
+
+@REWARD_MODEL_REGISTRY.register('math')
+class MathRewardModel(BaseRewardModel):
+    config = dict(
+        # (str) The type of the reward model.
+        type='math',
+        # (str) The name of the tokenizer and model
+        model_name='Qwen/Qwen2.5-Math-PRM-7B',
+    )
+
+    def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None:  # noqa
+        self.cfg = config
+        self.device = device
+        self.logger = logger
+        self.tb_logger = tb_logger
+
+        # Initialize the tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name, trust_remote_code=True)
+        self.model = AutoModel.from_pretrained(
+            self.cfg.model_name, device_map=self.device, torch_dtype=torch.bfloat16, trust_remote_code=True
+        )
+        self.model.eval()
+
+    def make_step_rewards(self, logits: torch.Tensor, token_masks: torch.Tensor) -> List[List[float]]:
+        """Calculate step-wise rewards from model outputs"""
+        probabilities = F.softmax(logits, dim=-1)
+        probabilities = probabilities * token_masks.unsqueeze(-1)  # bs, seq_len, num_labels
+
+        all_scores_res = []
+        for i in range(probabilities.size(0)):
+            sample = probabilities[i]  # seq_len, num_labels
+            positive_probs = sample[sample != 0].view(-1, 2)[:, 1]  # valid_tokens
+            non_zero_elements_list = positive_probs.cpu().tolist()
+            all_scores_res.append(non_zero_elements_list)
+        return all_scores_res
+
+    def estimate(self, data: List[Dict]) -> List[Dict]:
+        """
+        Overview:
+            Estimate rewards for mathematical reasoning steps using the Qwen2.5-Math-PRM-7B model.
+        Arguments:
+            - data (:obj:`List[Dict]`): List of dictionaries containing:
+                - system (:obj:`str`): System prompt for the model.
+                - query (:obj:`str`): The mathematical query to be evaluated.
+                - response (:obj:`List[str]`): List of reasoning steps.
+        Returns:
+            - reward (:obj:`List[Dict]`): List of dictionaries containing:
+                - reward (:obj:`float`): Final reward (last step reward).
+                - metadata (:obj:`Dict`): Additional information including:
+                    - query (:obj:`str`): Original query.
+                    - step_rewards (:obj:`List[float]`): Rewards for each reasoning step.
+                    - num_steps (:obj:`int`): Number of reasoning steps.
+        Shapes:
+            - input_ids (:obj:`torch.LongTensor`): :math:`(B, L)`, where B is batch size and L is sequence length.
+            - outputs (:obj:`torch.FloatTensor`): :math:`(B, L, H)`, where H is hidden size.
+            - token_masks (:obj:`torch.BoolTensor`): :math:`(B, L)`.
+            - step_rewards (:obj:`List[List[float]]`): List of length B, each containing S rewards, where S is the number of steps.
+        Examples:
+            >>> data = [{
+            >>>     "system": "Please reason step by step...",
+            >>>     "query": "What is 1 + 1?",
+            >>>     "response": ["First, we have 1", "Then add 1", "Therefore, 1 + 1 = 2"]
+            >>> }]
+            >>> results = model.estimate(data)
+            >>> print(results[0]["reward"])  # 1.0
+            >>> print(results[0]["metadata"]["step_rewards"])  # [0.8, 0.9, 1.0]
+        """
+        all_messages = []
+        for item in data:
+            messages = [
+                {
+                    "role": "system",
+                    "content": item['system']
+                },
+                {
+                    "role": "user",
+                    "content": item['query']
+                },
+                {
+                    "role": "assistant",
+                    "content": "<extra_0>".join(item['response']) + "<extra_0>"
+                },
+            ]
+            all_messages.append(messages)
+
+        conversation_strs = [
+            self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+            for messages in all_messages
+        ]
+
+        # Batch-encode the conversations
+        input_ids = self.tokenizer(
+            conversation_strs, return_tensors="pt", padding=True, truncation=True
+        )["input_ids"].to(self.model.device)
+
+        with torch.no_grad():
+            outputs = self.model(input_ids=input_ids)
+
+        step_sep_id = self.tokenizer.encode("<extra_0>")[0]
+        token_masks = (input_ids == step_sep_id)
+        batch_rewards = self.make_step_rewards(outputs[0], token_masks)
+
+        results = []
+        for item, step_rewards in zip(data, batch_rewards):
+            results.append(
+                {
+                    "reward": step_rewards[-1] if step_rewards else 0.0,
+                    "metadata": {
+                        "query": item['query'],
+                        "step_rewards": step_rewards,
+                        "num_steps": len(item['response']),
+                    }
+                }
+            )
+
+        return results
+
+    def train(self):
+        """
+        Training is not implemented for this reward model as it uses a pre-trained model
+        """
+        self.logger.warning("Training is not implemented for this reward model")
+
+    def collect_data(self, data: list) -> None:
+        """
+        Data collection is not needed for this reward model
+        """
+        pass
+
+    def clear_data(self) -> None:
+        """
+        Data clearing is not needed for this reward model
+        """
+        pass
diff --git a/ding/reward_model/math_rule_reward_model.py b/ding/reward_model/math_rule_reward_model.py
new file mode 100644
index 0000000000..ffaf720c54
--- /dev/null
+++ b/ding/reward_model/math_rule_reward_model.py
@@ -0,0 +1,718 @@
+from typing import Tuple, Optional, List, Dict
+from easydict import EasyDict
+from torch.utils.tensorboard import SummaryWriter
+from transformers import AutoTokenizer
+import re
+import math
+import json
+from ding.utils import REWARD_MODEL_REGISTRY
+from .base_reward_model import BaseRewardModel
+
+
+@REWARD_MODEL_REGISTRY.register("math_rule")
+class MathRuleRewardModel(BaseRewardModel):
+    """
+    Math rule-based reward model for evaluating mathematical answers.
+    Supports various mathematical expression formats including LaTeX, fractions, percentages, etc.
+    """
+
+    config = dict(
+        # (str) The type of the reward model.
+        type="math_rule",
+        # (str) The name of the dataset, usually the huggingface dataset name.
+        dataset_name="",
+        # (str) The name of the tokenizer, usually the huggingface tokenizer name.
+        tokenizer_name="",
+        # (float) The score of format error.
+        format_error_reward=-2,
+        # (float) The score of answer error.
+        answer_error_reward=-1,
+        # (float) The score of correct.
+ correct_reward=1, + # (float) Relative tolerance for numerical comparison + rel_tol=1e-5, + # (float) Absolute tolerance for numerical comparison + abs_tol=1e-8, + ) + + def __init__( + self, + config: EasyDict, + device: str = "cpu", + logger=None, + tb_logger: "SummaryWriter" = None, + ) -> None: # noqa + """Initialize the math rule reward model""" + self.cfg = config + self.device = device + self.logger = logger + self.tb_logger = tb_logger + + # Initialize tokenizer + if hasattr(config, "tokenizer_name") and config.tokenizer_name: + self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + self.pad_token = (self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]") + self.eos_token = (self.tokenizer.eos_token if self.tokenizer.eos_token else "[EOS]") + else: + self.tokenizer = None + self.pad_token = "[PAD]" + self.eos_token = "[EOS]" + + def _process_target_answer(self, text: str) -> Optional[float]: + """Process target answer text and convert to numerical value""" + if text is None or not text.strip(): + return None + # Clean and normalize text + if self.tokenizer: + text = strip_sequence(text, self.pad_token, self.eos_token) + text = normalize_text(text) + + # Try to process the mathematical expression + try: + result = self._process_math_expression(text) + if result is not None: + return result + except Exception as e: + if self.logger: + self.logger.warning(f"Error processing target answer: {e}") + return None + + def _process_response_answer(self, response: str) -> Tuple[Optional[float], Optional[str]]: + """Process response text, extract and convert to numerical value""" + if response is None or not response.strip(): + return None, None + + # Clean text + if self.tokenizer: + response = strip_sequence(response, self.pad_token, self.eos_token) + + # First try to extract the final answer + final_answer = self._extract_final_answer(response) + + # If a final answer is extracted, try to process it + if final_answer: + try: + value = self._process_math_expression(final_answer) + if value is not None: + return value, final_answer + except Exception as e: + if self.logger: + self.logger.debug(f"Error processing final answer: {e}") + + # If unable to get a valid value from the final answer, try to extract all possible expressions + expressions = self._extract_all_expressions(response) + + # Try to process each expression until a valid answer is found + for expr in expressions: + try: + value = self._process_math_expression(expr) + if value is not None: + return value, expr + except Exception as e: + if self.logger: + self.logger.debug(f"Error processing expression '{expr}': {e}") + + # If all attempts fail, return None + return None, None + + def _check_answer_match(self, pred: Optional[float], target: Optional[float]) -> bool: + """Check if two answers match within tolerance""" + if pred is None or target is None: + return False + try: + return math.isclose( + pred, + target, + rel_tol=self.cfg.get("rel_tol", 1e-5), + abs_tol=self.cfg.get("abs_tol", 1e-8), + ) + except Exception as e: + if self.logger: + self.logger.warning(f"Error comparing answers: {e}") + return False + + def _extract_final_answer(self, text: str) -> Optional[str]: + """ + Extract the final answer from text. + Supports various formats: + 1. "The answer is X". + 2. "Therefore, X is the answer". + 3. "X" (if only one number). + 4. "\\boxed{X}". + 5. "= X" (expression after equals sign). + 6. Last LaTeX expression like \\frac{a}{b}, \\sqrt{x}, etc. 
+ """ + # Try to extract boxed content + boxed_match = re.search(r"\\boxed\{([^}]+)\}", text) + if boxed_match: + return boxed_match.group(0) + + # Try to extract "the answer is X" format + answer_match = re.search(r"(?:the\s+)?answer\s+is\s+([^\.]+)", text, re.IGNORECASE) + if answer_match: + answer_text = answer_match.group(1).strip() + # Check if the extracted answer contains a LaTeX expression + latex_match = re.search(r"(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})", answer_text) + if latex_match: + return latex_match.group(0) + return answer_text + + # Try to extract "therefore, X is the answer" format + therefore_match = re.search(r"therefore,?\s+([^\.]+)\s+is\s+the\s+answer", text, re.IGNORECASE) + if therefore_match: + therefore_text = therefore_match.group(1).strip() + # Check if the extracted answer contains a LaTeX expression + latex_match = re.search(r"(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})", therefore_text) + if latex_match: + return latex_match.group(0) + return therefore_text + + # Try to extract expression after equals sign + equals_matches = re.findall(r"=\s*([^\.=]+?)(?:\.|$|=)", text) + if equals_matches: + last_eq = equals_matches[-1].strip() + # Check if there's a LaTeX expression after the equals sign + latex_match = re.search(r"(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})", last_eq) + if latex_match: + return latex_match.group(0) + return last_eq + + # Try to directly extract LaTeX fraction expression + frac_matches = re.findall(r"(\\frac\{[^}]+\}\{[^}]+\})", text) + if frac_matches: + return frac_matches[-1] + + # Try to directly extract LaTeX square root expression + sqrt_matches = re.findall(r"(\\sqrt\{[^}]+\})", text) + if sqrt_matches: + return sqrt_matches[-1] + + # Try to extract pi-related expressions + pi_expr = self._extract_pi_expressions(text) + if pi_expr: + return pi_expr + + # If there's only one number, return it directly + numbers = re.findall(r"-?\d*\.?\d+", text) + if len(numbers) == 1: + return numbers[0] + + # Try to extract the last number (as a fallback) + if numbers: + return numbers[-1] + + return None + + def _extract_pi_expressions(self, text: str) -> Optional[str]: + """Extract pi-related expressions from text""" + # Try to extract expressions like \frac{a\pi}{b} + pi_frac_matches = re.findall(r"(\\frac\{[^}]*\\pi[^}]*\}\{[^}]+\})", text) + if pi_frac_matches: + return pi_frac_matches[-1] + + # Try to extract expressions like \frac{a}{b}\pi + frac_pi_matches = re.findall(r"(\\frac\{[^}]+\}\{[^}]+\}\\pi)", text) + if frac_pi_matches: + return frac_pi_matches[-1] + + # Try to extract expressions like 11π/6 + text_with_pi = text.replace("\\pi", "π") + pi_div_matches = re.findall(r"(\d+π/\d+)", text_with_pi) + if pi_div_matches: + return pi_div_matches[-1] + + # Try to extract expressions like π/2 + pi_simple_div_matches = re.findall(r"(π/\d+)", text_with_pi) + if pi_simple_div_matches: + return pi_simple_div_matches[-1] + + # Try to extract expressions like 2π + pi_mult_matches = re.findall(r"(\d+π)", text_with_pi) + if pi_mult_matches: + return pi_mult_matches[-1] + + # Check for standalone π + if "π" in text_with_pi or "\\pi" in text: + pi_standalone = re.search(r"(^|[^a-zA-Z0-9])π($|[^a-zA-Z0-9])", text_with_pi) + if pi_standalone: + return "π" + + return None + + def _process_pi_expressions(self, text: str) -> Optional[float]: + """Process pi-related expressions and convert to numerical value""" + # Standardize pi notation + text = text.replace("\\pi", "π") + + # Process expressions like 11π/6 + pi_match = re.search(r"(\d+)π/(\d+)", text) 
+ if pi_match: + num, denom = map(int, pi_match.groups()) + return (num * math.pi) / denom + + # Process expressions like π/2 + pi_div_match = re.search(r"π/(\d+)", text) + if pi_div_match: + denom = int(pi_div_match.group(1)) + return math.pi / denom + + # Process expressions like 2π + pi_mult_match = re.search(r"(\d+)π", text) + if pi_mult_match: + num = int(pi_mult_match.group(1)) + return num * math.pi + + # If just π + if text == "π": + return math.pi + + return None + + def _process_math_expression(self, text: str) -> Optional[float]: + """ + Process special mathematical expressions, such as: + 1. Fractions: 1/2, \\frac{1}{2} + 2. Percentages: 50% + 3. Scientific notation: 1.2e-3 + 4. Mixed expressions: 1 + 2/3 + 5. Square roots: \\sqrt{2} + 6. Mixed fractions: 1\\frac{1}{2} + 7. Max/min functions: \\max(1,2,3), \\min(1,2,3) + 8. Pi-related expressions: 11π/6, \\frac{11\\pi}{6} + """ + if text is None or not text.strip(): + return None + + try: + # Remove all spaces and unnecessary LaTeX commands + text = text.replace(" ", "") + text = text.replace("\\left", "").replace("\\right", "") + + # Process pi-related expressions + if "π" in text or "\\pi" in text: + pi_value = self._process_pi_expressions(text) + if pi_value is not None: + return pi_value + + # Process percentages + if "%" in text: + return float(text.replace("%", "")) / 100 + + # Process LaTeX square roots \sqrt{...} + sqrt_match = re.search(r"\\sqrt\{([^}]+)\}", text) + if sqrt_match: + inner_expr = sqrt_match.group(1) + inner_value = self._process_math_expression(inner_expr) + if inner_value is not None: + return math.sqrt(inner_value) + + # Process LaTeX fractions \frac{...}{...} + frac_match = re.search(r"\\frac\{([^}]+)\}\{([^}]+)\}", text) + if frac_match: + num = frac_match.group(1) + denom = frac_match.group(2) + + # Recursively process numerator and denominator + num_value = self._process_math_expression(num) + denom_value = self._process_math_expression(denom) + + if (num_value is not None and denom_value is not None and denom_value != 0): + return num_value / denom_value + + # Process mixed fractions 1\frac{1}{2} + mixed_frac_match = re.search(r"(\d+)\\frac\{([^}]+)\}\{([^}]+)\}", text) + if mixed_frac_match: + whole = int(mixed_frac_match.group(1)) + num = mixed_frac_match.group(2) + denom = mixed_frac_match.group(3) + + # Recursively process numerator and denominator + num_value = self._process_math_expression(num) + denom_value = self._process_math_expression(denom) + + if (num_value is not None and denom_value is not None and denom_value != 0): + return whole + (num_value / denom_value) + + # Process max function \max(a,b,c) + max_match = re.search(r"\\max\(([^)]+)\)", text) + if max_match: + values_str = max_match.group(1) + values = values_str.split(",") + processed_values = [] + for val in values: + processed_val = self._process_math_expression(val) + if processed_val is not None: + processed_values.append(processed_val) + if processed_values: + return max(processed_values) + + # Process min function \min(a,b,c) + min_match = re.search(r"\\min\(([^)]+)\)", text) + if min_match: + values_str = min_match.group(1) + values = values_str.split(",") + processed_values = [] + for val in values: + processed_val = self._process_math_expression(val) + if processed_val is not None: + processed_values.append(processed_val) + if processed_values: + return min(processed_values) + + # Process simple arithmetic operations + if any(op in text for op in ["+", "-", "*", "/"]): + # Safe eval, only allowing basic operations + 
safe_dict = {"__builtins__": None} + return float(eval(text, safe_dict)) + + # Process scientific notation + if "e" in text.lower() and re.match(r"-?\d+\.?\d*e[+-]?\d+", text.lower()): + return float(text) + + # Process regular numbers + return float(text) + except Exception as e: + # Log exception information for debugging + if self.logger: + self.logger.debug(f"Error processing math expression '{text}': {str(e)}") + return None + + def estimate(self, data: List[Dict]) -> List[Dict]: + """ + Overview: + Estimate rewards for mathematical answers based on rule-based comparison. + Arguments: + - data (:obj:`List[Dict]`): The list of data queries used for estimation. + Format: [{"question": "...", "answer": "...", "response": "..."}, ...] + Each dictionary may contain: + - question: The mathematical question + - answer: The ground truth answer + - response: The model's response to evaluate + - system: Optional system prompt + - query: Optional alternative to question + Returns: + - rewards (:obj:`List[Dict]`): The estimated rewards. + Each dictionary contains: + - reward: The numerical reward value + - metadata: Additional information about the evaluation + Examples: + >>> data = [{ + >>> "question": "What is 2+2?", + >>> "answer": "4", + >>> "response": "The answer is 4." + >>> }] + >>> results = model.estimate(data) + >>> print(results[0]["reward"]) # 1.0 (correct) + >>> print(results[0]["metadata"]["reason"]) # "correct_answer" + """ + rewards = [] + + for item in data: + result = { + "reward": self.cfg.format_error_reward, + "metadata": { + "reason": "format_error", + "response_value": None, + "target_value": None, + "match_result": False, + "extracted_code": None, + "final_answer": None, + "extracted_expressions": [], + }, + } + + try: + # Extract question, answer and response from data item + item_data = self._extract_item_data(item) + if item_data is None: + rewards.append(result) + continue + + question, gt_answer, response = item_data + + # Process target answer + target_value = self._process_target_answer(gt_answer) + result["metadata"]["target_value"] = target_value + + # Process response answer + response_value, final_answer = self._process_response_answer(response) + result["metadata"]["response_value"] = response_value + result["metadata"]["final_answer"] = final_answer + + # Extract Python code (if any) + extracted_code = self._extract_python_code(response) + result["metadata"]["extracted_code"] = extracted_code + + # Extract all possible expressions (for debugging) + expressions = self._extract_all_expressions(response) + result["metadata"]["extracted_expressions"] = expressions + + # Determine reward based on answer comparison + result = self._determine_reward(result, target_value, response_value) + + except Exception as e: + result["metadata"]["reason"] = f"error: {str(e)}" + if self.logger: + self.logger.error(f"Error evaluating data: {str(e)}") + + rewards.append(result) + + return rewards + + def _extract_item_data(self, item) -> Optional[Tuple[str, str, str]]: + """Extract question, answer and response from data item""" + if isinstance(item, dict): + question = item.get("question", "") + gt_answer = item.get("answer", "") + response = item.get("response", "") + query = item.get("query", "") + elif isinstance(item, str): + # If input is a string, try to parse as JSON + try: + item_dict = json.loads(item) + question = item_dict.get("question", "") + gt_answer = item_dict.get("answer", "") + response = item_dict.get("response", "") + query = item_dict.get("query", "") + 
except: + # If parsing fails, assume the entire string is the response + question = "" + gt_answer = "" + response = item + query = "" + else: + # Unsupported input type + return None + + # If no question but query exists, use query as question + if not question and query: + question = query + + return question, gt_answer, response + + def _determine_reward( + self, + result: Dict, + target_value: Optional[float], + response_value: Optional[float], + ) -> Dict: + """Determine reward based on answer comparison""" + if target_value is None: + result["reward"] = self.cfg.format_error_reward + result["metadata"]["reason"] = "invalid_target_format" + elif response_value is None: + result["reward"] = self.cfg.format_error_reward + result["metadata"]["reason"] = "invalid_response_format" + else: + # Compare answers + is_match = self._check_answer_match(response_value, target_value) + result["metadata"]["match_result"] = is_match + + if is_match: + result["reward"] = self.cfg.correct_reward + result["metadata"]["reason"] = "correct_answer" + else: + result["reward"] = self.cfg.answer_error_reward + result["metadata"]["reason"] = "wrong_answer" + + return result + + def _extract_all_expressions(self, text: str) -> List[str]: + """Extract all possible mathematical expressions from text, sorted by priority""" + if text is None or not text.strip(): + return [] + + expressions = [] + + # Extract expressions from LaTeX math environments + self._extract_latex_environments(text, expressions) + + # Extract boxed content (highest priority) + self._extract_boxed_content(text, expressions) + + # Extract expressions after equals sign + self._extract_equals_expressions(text, expressions) + + # Extract expressions from answer phrases + self._extract_answer_phrases(text, expressions) + + # Extract LaTeX expressions + self._extract_latex_expressions(text, expressions) + + # Extract pi-related expressions + self._extract_pi_expressions_for_list(text, expressions) + + # Extract all numbers (lowest priority) + self._extract_numbers(text, expressions) + + # Remove duplicates while preserving order + unique_expressions = [] + for expr in expressions: + if expr not in unique_expressions: + unique_expressions.append(expr) + + return unique_expressions + + def _extract_latex_environments(self, text: str, expressions: List[str]) -> None: + """Extract expressions from LaTeX math environments""" + # Match \(...\) or $...$ format LaTeX expressions + latex_envs = re.findall(r"\\\\?\((.+?)\\\\?\)", text) + re.findall(r"\$(.+?)\$", text) + for latex_env in latex_envs: + expressions.append(latex_env.strip()) + + def _extract_boxed_content(self, text: str, expressions: List[str]) -> None: + """Extract boxed content""" + boxed_matches = re.findall(r"\\boxed\{([^}]+)\}", text) + for match in boxed_matches: + expressions.append(match.strip()) + + def _extract_equals_expressions(self, text: str, expressions: List[str]) -> None: + """Extract expressions after equals sign""" + equals_matches = re.findall(r"=\s*([^\.=]+?)(?:\.|$|=)", text) + for match in equals_matches: + expressions.append(match.strip()) + + def _extract_answer_phrases(self, text: str, expressions: List[str]) -> None: + """Extract expressions from answer phrases""" + # Extract "the answer is X" format + answer_match = re.search(r"(?:the\s+)?answer\s+is\s+([^\.]+)", text, re.IGNORECASE) + if answer_match: + expressions.append(answer_match.group(1).strip()) + + # Extract "therefore, X is the answer" format + therefore_match = 
re.search(r"therefore,?\s+([^\.]+)\s+is\s+the\s+answer", text, re.IGNORECASE) + if therefore_match: + expressions.append(therefore_match.group(1).strip()) + + def _extract_latex_expressions(self, text: str, expressions: List[str]) -> None: + """Extract LaTeX expressions""" + # Extract LaTeX fraction expressions + frac_matches = re.findall(r"\\frac\{([^}]+)\}\{([^}]+)\}", text) + for num, denom in frac_matches: + expressions.append(f"\\frac{{{num}}}{{{denom}}}") + + # Extract LaTeX square root expressions + sqrt_matches = re.findall(r"\\sqrt\{([^}]+)\}", text) + for inner in sqrt_matches: + expressions.append(f"\\sqrt{{{inner}}}") + + # Extract all LaTeX expressions + latex_expressions = re.findall(r"\\[a-zA-Z]+(?:\{[^}]*\})+", text) + for expr in latex_expressions: + if expr not in expressions: + expressions.append(expr) + + def _extract_pi_expressions_for_list(self, text: str, expressions: List[str]) -> None: + """Extract pi-related expressions for the expressions list""" + # Replace \pi with π for unified processing + text_with_pi = text.replace("\\pi", "π") + + # Extract expressions like 11π/6 + pi_div_matches = re.findall(r"(\d+)π/(\d+)", text_with_pi) + for num, denom in pi_div_matches: + expressions.append(f"{num}π/{denom}") + + # Extract expressions like π/2 + pi_simple_div_matches = re.findall(r"π/(\d+)", text_with_pi) + for denom in pi_simple_div_matches: + expressions.append(f"π/{denom}") + + # Extract expressions like 2π + pi_mult_matches = re.findall(r"(\d+)π", text_with_pi) + for num in pi_mult_matches: + expressions.append(f"{num}π") + + # Extract standalone π + if "π" in text_with_pi: + expressions.append("π") + + def _extract_numbers(self, text: str, expressions: List[str]) -> None: + """Extract all numbers""" + numbers = re.findall(r"-?\d*\.?\d+", text) + expressions.extend(numbers) + + # rule-based reward model does not need training, thus the following methods are empty + def train(self): + """Training method (not needed for rule-based reward model)""" + pass + + def collect_data(self, data: list) -> None: + """Data collection method (not needed for rule-based reward model)""" + pass + + def clear_data(self) -> None: + """Data clearing method (not needed for rule-based reward model)""" + pass + + def _extract_python_code(self, text: str) -> Optional[str]: + """Extract Python code blocks from text""" + if text is None or not text.strip(): + return None + # Match code between ```python and ``` + code_blocks = re.findall(r"```python\s*(.*?)\s*```", text, re.DOTALL) + if code_blocks: + return code_blocks[-1].strip() + # Match code between ``` and ``` (without specified language) + code_blocks = re.findall(r"```\s*(.*?)\s*```", text, re.DOTALL) + if code_blocks: + return code_blocks[-1].strip() + + return None + + +def strip_sequence(text: str, pad_token: str, eos_token: str) -> str: + """ + Overview: + Remove leading and trailing sequences of padding/eos tokens from a text. + .. note:: + This function uses regular expressions to strip all consecutive occurrences + of the specified padding and end-of-sequence tokens from both the beginning + and end of the input text. Tokens in the middle of the text are preserved. + Arguments: + - text (str): The input text to be processed. + - pad_token (str): The padding token to be stripped (e.g., ""). + - eos_token (str): The end-of-sequence token to be stripped (e.g., ""). + Returns: + - cleaned_text (str): The cleaned text with leading/trailing padding/eos tokens removed. 
+ Examples: + >>> strip_sequence("Hello", "", "") + 'Hello' + >>> strip_sequence("TestMiddleKeep", "", "") + 'TestMiddleKeep' + >>> strip_sequence("Full removal", "", "") + 'Full removal' + >>> strip_sequence("No tokens here", "", "") + 'No tokens here' + >>> strip_sequence("", "", "") + '' + """ + pad_token_escaped = re.escape(pad_token) + eos_token_escaped = re.escape(eos_token) + + # Remove leading tokens + pattern = f"^({eos_token_escaped}|{pad_token_escaped})+" + text = re.sub(pattern, "", text) + + # Remove trailing tokens + pattern = f"({eos_token_escaped}|{pad_token_escaped})+$" + text = re.sub(pattern, "", text) + return text + + +def normalize_text(text: str) -> str: + """ + Overview: + This function is designed to standardize text by: + - Converting all text to lowercase + - Replacing various punctuation marks and special characters with spaces + - Removing import statements + - Normalizing whitespace by replacing multiple spaces with a single space + - Stripping leading and trailing whitespace + Arguments: + - text (str): The input text to be processed. + Returns: + - normalized_text (str): The normalized text. + """ + text = re.sub(r"import\s[a-zA-Z\.]+(\sas\s[a-zA-Z\.]+)\n", " ", text) + text = re.sub(r"\s+", " ", text) + return text.strip() diff --git a/ding/reward_model/multi_modal_reward_model.py b/ding/reward_model/multi_modal_reward_model.py new file mode 100644 index 0000000000..9a12eb1cd4 --- /dev/null +++ b/ding/reward_model/multi_modal_reward_model.py @@ -0,0 +1,163 @@ +from typing import List, Dict +from easydict import EasyDict +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from ding.utils import REWARD_MODEL_REGISTRY +from .base_reward_model import BaseRewardModel + + +@REWARD_MODEL_REGISTRY.register('multi_modal') +class MultiModalRewardModel(BaseRewardModel): + config = dict( + type='multi_modal', + model_name='internlm/internlm-xcomposer2d5-7b-reward', + hd_num=9, # Number of high-definition patches for image processing + ) + + def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None: + self.cfg = config + self.device = device + self.logger = logger + self.tb_logger = tb_logger + + self.tokenizer = AutoTokenizer.from_pretrained( + self.cfg.model_name, trust_remote_code=True, local_files_only=True + ) + self.model = AutoModelForCausalLM.from_pretrained( + self.cfg.model_name, torch_dtype=torch.float16, trust_remote_code=True + ) + + self.model.tokenizer = self.tokenizer + self.model.cuda().eval() + + def estimate(self, data: List[Dict], image: List[str], output_mode: str = 'score') -> List[Dict]: + """ + Estimate rewards for multi-modal inputs using internlm-xcomposer model. + + Arguments: + data (List[Dict]): List of chat dictionaries, each containing: + - chat (List[Dict]): List of messages, each message is a dict with: + - role (str): Either "user" or "assistant" + - content (str): The message content + image (List[str]): List of image paths. If fewer images than chats, last image will be reused + output_mode (str, optional): Evaluation mode. Defaults to 'score'. 
+ - 'score': Return reward scores for each chat + - 'rank': Return ranking indices (0 is best) for all chats + - 'compare': Compare first two chats (returns 1.0 for better, 0.0 for worse) + + Returns: + List[Dict]: Results depending on output_mode: + - For 'score' mode: + [{ + 'reward': float, # Reward score + 'metadata': { + 'mode': 'score', + 'chat_idx': int, # Index of the chat + 'image_path': str # Path of the image used + } + }, ...] + - For 'rank' mode: + [{ + 'rank': int, # Ranking position (0 is best) + 'metadata': { + 'mode': 'rank', + 'chat_idx': int, + 'image_path': str + } + }, ...] + - For 'compare' mode: + [{ + 'reward': float, # 1.0 for better, 0.0 for worse + 'metadata': { + 'mode': 'compare', + 'chat_idx': int, + 'image_path': str, + 'compared_with': int # Index of the compared chat + } + }, ...] + """ + # Get chat data + chats = [item['chat'] for item in data] + + with torch.autocast(device_type='cuda', dtype=torch.float16): + if output_mode == 'score': + # Ensure each chat has a corresponding image, use the last image if not enough + if len(image) < len(chats): + image = image + [image[-1]] * (len(chats) - len(image)) + + # Get scores for each chat + scores = [] + for chat, img in zip(chats, image): + score = self.model.get_score(chat, [img], hd_num=self.cfg.hd_num) + scores.append(score) + + return [ + { + 'reward': float(score), + 'metadata': { + 'mode': 'score', + 'chat_idx': idx, + 'image_path': img + } + } for idx, (score, img) in enumerate(zip(scores, image)) + ] + + elif output_mode == 'rank': + # Use the first image for ranking + img = image[0] + ranks = self.model.rank(chats, [[img]] * len(chats), hd_num=self.cfg.hd_num) + + return [ + { + 'rank': int(rank), + 'metadata': { + 'mode': 'rank', + 'chat_idx': idx, + 'image_path': img + } + } for idx, rank in enumerate(ranks) + ] + + elif output_mode == 'compare': + if len(data) < 2: + raise ValueError("Compare mode requires at least 2 samples") + + # Use the first image for comparison + img = image[0] + is_better = self.model.compare(chats[0], [img], chats[1], [img], hd_num=self.cfg.hd_num) + + return [ + { + 'reward': 1.0 if is_better else 0.0, + 'metadata': { + 'mode': 'compare', + 'chat_idx': 0, + 'image_path': img, + 'compared_with': 1 + } + }, { + 'reward': 0.0 if is_better else 1.0, + 'metadata': { + 'mode': 'compare', + 'chat_idx': 1, + 'image_path': img, + 'compared_with': 0 + } + } + ] + else: + raise ValueError(f"Invalid output mode: {output_mode}") + + def train(self): + """Training is not implemented for this reward model""" + self.logger.warning("Training is not implemented for this reward model") + pass + + def collect_data(self, data: list) -> None: + """Data collection is not needed for this reward model""" + pass + + def clear_data(self) -> None: + """Data clearing is not needed for this reward model""" + pass diff --git a/ding/reward_model/tests/test_math_reward_model.py b/ding/reward_model/tests/test_math_reward_model.py new file mode 100644 index 0000000000..70a7ba4637 --- /dev/null +++ b/ding/reward_model/tests/test_math_reward_model.py @@ -0,0 +1,87 @@ +import pytest +from easydict import EasyDict +import torch +from unittest.mock import MagicMock + +from ding.reward_model import MathRewardModel + + +@pytest.mark.envtest +def test_math_reward_model(): + # Create configuration + cfg = EasyDict(dict( + type='math', + model_name='Qwen/Qwen2.5-Math-PRM-7B', + )) + + # Create mock logger and tb_logger + logger = MagicMock() + tb_logger = MagicMock() + + # Initialize reward model + model = 
MathRewardModel(cfg, "cuda" if torch.cuda.is_available() else "cpu", logger, tb_logger) + + # Simple math problem + data_simple = [ + { + "system": "Please reason step by step...", + "query": "What is 1 + 1?", + "response": ["First, we have 1", "Then add 1", "Therefore, 1 + 1 = 2"] + } + ] + + # Complex word problem + data_complex = [ + { + "system": "Please reason step by step, and put your final answer within \\boxed{}.", + "query": "Sue lives in a fun neighborhood...", + "response": [ + "To find out how many more pink plastic flamingos...", + "On Saturday, they take back one third of the flamingos...", + "On Sunday, the neighbors add another 18 pink plastic flamingos...", + "To find the difference, subtract the number of white flamingos..." + ] + } + ] + + # Test simple case + results_simple = model.estimate(data_simple) + + # Verify simple case results + assert len(results_simple) == 1, "Should return one result" + assert "reward" in results_simple[0], "Result should contain reward" + assert "metadata" in results_simple[0], "Result should contain metadata" + assert "step_rewards" in results_simple[0]["metadata"], "Metadata should contain step_rewards" + assert len(results_simple[0]["metadata"]["step_rewards"]) == 3, "Should have 3 step rewards" + assert results_simple[0]["metadata"]["num_steps"] == 3, "Should have 3 steps" + + # Test complex case + results_complex = model.estimate(data_complex) + + # Verify complex case results + assert len(results_complex) == 1, "Should return one result" + assert "reward" in results_complex[0], "Result should contain reward" + assert "metadata" in results_complex[0], "Result should contain metadata" + assert "step_rewards" in results_complex[0]["metadata"], "Metadata should contain step_rewards" + assert len(results_complex[0]["metadata"]["step_rewards"]) == 4, "Should have 4 step rewards" + assert results_complex[0]["metadata"]["num_steps"] == 4, "Should have 4 steps" + + # Verify reward value ranges + for result in results_simple + results_complex: + assert 0 <= result["reward"] <= 1, "Reward should be between 0 and 1" + for step_reward in result["metadata"]["step_rewards"]: + assert 0 <= step_reward <= 1, "Step rewards should be between 0 and 1" + + # Test batch processing functionality + batch_data = data_simple + data_complex + batch_results = model.estimate(batch_data) + assert len(batch_results) == 2, "Should return two results for batch processing" + + # Print detailed information for debugging + print("\nSimple problem results:") + print(f"Final reward: {results_simple[0]['reward']}") + print(f"Step rewards: {results_simple[0]['metadata']['step_rewards']}") + + print("\nComplex problem results:") + print(f"Final reward: {results_complex[0]['reward']}") + print(f"Step rewards: {results_complex[0]['metadata']['step_rewards']}") diff --git a/ding/reward_model/tests/test_math_rule_reward_model.py b/ding/reward_model/tests/test_math_rule_reward_model.py new file mode 100644 index 0000000000..d4a7600d91 --- /dev/null +++ b/ding/reward_model/tests/test_math_rule_reward_model.py @@ -0,0 +1,128 @@ +import os +import sys +import pytest +from easydict import EasyDict +from ding.reward_model.math_rule_reward_model import MathRuleRewardModel + + +@pytest.fixture +def reward_model(): + return MathRuleRewardModel( + config=EasyDict( + tokenizer_name='unsloth/Meta-Llama-3.1-8B', + type='math_rule', + format_error_reward=-2, + answer_error_reward=-1, + correct_reward=1, + ) + ) + + +@pytest.mark.envtest +def 
test_math_rule_reward_model_correct_answer(reward_model): + data_correct = [ + { + "system": "Please answer this math problem...", + "query": ( + "The school now introduces a new color, silver, for the flag design. " + "Crestview's school colors are now purple, gold, and silver. " + "The students are designing a flag using three solid-colored horizontal stripes. " + "Using one, two, or all three of the school colors, how many different flags " + "are possible if adjacent stripes may be the same color?" + ), + "response": ( + "Crestview's school colors—purple, gold, and silver—can be used to design " + "a flag with three horizontal stripes, where each stripe can be any of the " + "three colors and adjacent stripes may be the same. Since each of the three " + "stripes has three independent color choices, the total number of possible " + "flag designs is 27" + ), + "answer": r"27" + } + ] + + # Test the case with correct answer + rewards = reward_model.estimate(data_correct) + assert len(rewards) == len(data_correct) + assert rewards[0]['reward'] == reward_model.cfg.correct_reward + assert rewards[0]['metadata']['reason'] == 'correct_answer' + assert rewards[0]['metadata']['match_result'] + + +@pytest.mark.envtest +def test_math_rule_reward_model_wrong_answer(reward_model): + data_wrong = [ + { + "system": "Please answer this math problem...", + "query": ( + "The school now introduces a new color, silver, for the flag design. " + "Crestview's school colors are now purple, gold, and silver. " + "The students are designing a flag using three solid-colored horizontal stripes. " + "Using one, two, or all three of the school colors, how many different flags " + "are possible if adjacent stripes may be the same color?" + ), + "response": ( + r"The given point \(\left(\frac{\sqrt{3}}{2}, -\frac{1}{2}\right)\) lies on " + r"the unit circle, meaning its coordinates correspond to \((\cos \alpha, " + r"\sin \alpha)\). Since \(\cos \alpha = \frac{\sqrt{3}}{2}\) and " + r"\(\sin \alpha = -\frac{1}{2}\), the angle \(\alpha\) is in the " + r"**fourth quadrant**, where the reference angle is \(\frac{\pi}{6}\). " + r"Therefore, the smallest positive value of \(\alpha\) is " + r"\(2\pi - \frac{\pi}{6} = \frac{17\pi}{6}\)." 
+ ), + "answer": r"\frac{11\pi}{6}" + } + ] + + rewards = reward_model.estimate(data_wrong) + assert len(rewards) == len(data_wrong) + assert rewards[0]['reward'] == reward_model.cfg.answer_error_reward + assert rewards[0]['metadata']['reason'] == 'wrong_answer' + assert rewards[0]['metadata']['match_result'] is False + + +@pytest.mark.envtest +def test_math_rule_reward_model_format_error(reward_model): + data_format_error = [ + { + "system": "Please answer this math problem...", + "query": "What is 2+2?", + "response": "The answer is four.", + "answer": r"4" + } + ] + rewards_format = reward_model.estimate(data_format_error) + assert len(rewards_format) == len(data_format_error) + # This should be a format error because "four" cannot be processed as a numerical value + assert rewards_format[0]['reward'] == reward_model.cfg.format_error_reward + assert 'format' in rewards_format[0]['metadata']['reason'] + + +@pytest.mark.envtest +def test_math_rule_reward_model_special_expressions(reward_model): + data_edge_cases = [ + { + "query": "What is 1/2?", + "response": r"The answer is \frac{1}{2}.", + "answer": r"0.5" + }, { + "query": "What is 50%?", + "response": "The answer is 50%.", + "answer": r"0.5" + }, { + "query": "What is sqrt(4)?", + "response": r"The answer is \sqrt{4} = 2.", + "answer": r"2" + } + ] + rewards_edge = reward_model.estimate(data_edge_cases) + assert len(rewards_edge) == len(data_edge_cases) + # Check fraction processing + assert rewards_edge[0]['metadata']['match_result'] + assert rewards_edge[0]['reward'] == reward_model.cfg.correct_reward + # Check percentage processing + assert rewards_edge[1]['metadata']['match_result'] + assert rewards_edge[1]['reward'] == reward_model.cfg.correct_reward + # Check square root processing + assert rewards_edge[2]['metadata']['match_result'] + assert rewards_edge[2]['reward'] == reward_model.cfg.correct_reward diff --git a/ding/reward_model/tests/test_multi_modal_reward_model.py b/ding/reward_model/tests/test_multi_modal_reward_model.py new file mode 100644 index 0000000000..7e87aa6f8a --- /dev/null +++ b/ding/reward_model/tests/test_multi_modal_reward_model.py @@ -0,0 +1,121 @@ +import pytest +from easydict import EasyDict +import torch +from ding.reward_model import MultiModalRewardModel +from unittest.mock import MagicMock +import os + + +@pytest.fixture +def reward_model(): + # Create configuration + cfg = EasyDict(dict( + type='multi_modal', + model_name='internlm/internlm-xcomposer2d5-7b-reward', + hd_num=9, + )) + + # Create mock logger and tb_logger + logger = MagicMock() + tb_logger = MagicMock() + + # Initialize reward model + model = MultiModalRewardModel(cfg, "cuda" if torch.cuda.is_available() else "cpu", logger, tb_logger) + return model + + +@pytest.fixture +def test_data(): + # Shared test data + chats = [ + [ # chat_1 + {"role": "user", "content": 'I want to buy a car from the input image, ' + 'analyze the advantages and weaknesses.'}, + {"role": "assistant", "content": "The car in the image is a Mercedes-Benz G-Class..."} + ], + [ # chat_2 + {"role": "user", "content": 'I want to buy a car from the input image, ' + 'analyze the advantages and weaknesses.'}, + {"role": "assistant", "content": "Based on the image, it appears to be a Ferrari F8 Tributo..."} + ] + ] + + images = ['./examples/cars1.jpg'] + + return {'chats': chats, 'images': images, 'hd_num': 9} + + +@pytest.mark.envtest +def test_single_score(reward_model, test_data): + """Test single chat scoring""" + data = [{'chat': test_data['chats'][0]}] + + results 
= reward_model.estimate(data, test_data['images'], output_mode='score') + print(f"Single score results: {results}") + + assert len(results) == 1 + assert 'reward' in results[0] + assert isinstance(results[0]['reward'], float) + assert results[0]['metadata']['mode'] == 'score' + assert results[0]['metadata']['chat_idx'] == 0 + + +@pytest.mark.envtest +def test_multiple_scores(reward_model, test_data): + """Test multiple chats scoring""" + data = [{'chat': test_data['chats'][0]}, {'chat': test_data['chats'][1]}] + + results = reward_model.estimate(data, test_data['images'], output_mode='score') + print(f"Multiple scores results: {results}") + + assert len(results) == 2 + assert all('reward' in r for r in results) + assert all(isinstance(r['reward'], float) for r in results) + assert all(r['metadata']['mode'] == 'score' for r in results) + + +@pytest.mark.envtest +def test_rank(reward_model, test_data): + """Test ranking functionality""" + data = [{'chat': test_data['chats'][0]}, {'chat': test_data['chats'][1]}] + + results = reward_model.estimate(data, test_data['images'], output_mode='rank') + print(f"Rank results: {results}") + + assert len(results) == 2 + assert all('rank' in r for r in results) + assert set(r['rank'] for r in results) == {0, 1} + + +@pytest.mark.envtest +def test_compare(reward_model, test_data): + """Test comparison functionality""" + data = [{'chat': test_data['chats'][0]}, {'chat': test_data['chats'][1]}] + + results = reward_model.estimate(data, test_data['images'], output_mode='compare') + print(f"Compare results: {results}") + + assert len(results) == 2 + assert sum(r['reward'] for r in results) == 1.0 + assert all(r['metadata']['mode'] == 'compare' for r in results) + + +@pytest.mark.envtest +def test_default_parameters(reward_model, test_data): + """Test default parameters""" + data = [{'chat': test_data['chats'][0]}] + + # Test without specifying optional parameters + results = reward_model.estimate(data, test_data['images']) + + assert len(results) == 1 + assert 'reward' in results[0] + assert results[0]['metadata']['mode'] == 'score' + + +@pytest.mark.envtest +def test_error_handling(reward_model, test_data): + """Test error handling""" + with pytest.raises(Exception): + # Test invalid input format + reward_model.model.get_score(None, test_data['image'], hd_num=test_data['hd_num'])
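
The snippets below are not part of the diff; they are minimal usage sketches for reviewers who want to exercise the new models outside pytest. The first drives MathRewardModel end to end. It assumes the Qwen/Qwen2.5-Math-PRM-7B weights can be downloaded (several GB; a CUDA device is strongly recommended) and, as in the accompanying test, stubs the logger and TensorBoard writer with MagicMock.

import torch
from easydict import EasyDict
from unittest.mock import MagicMock
from ding.reward_model import MathRewardModel

cfg = EasyDict(type='math', model_name='Qwen/Qwen2.5-Math-PRM-7B')
device = "cuda" if torch.cuda.is_available() else "cpu"
# The logger and tb_logger are only used for warnings/statistics, so mocks suffice here.
model = MathRewardModel(cfg, device, MagicMock(), MagicMock())

data = [{
    "system": "Please reason step by step, and put your final answer within \\boxed{}.",
    "query": "What is 1 + 1?",
    "response": ["First, we have 1.", "Then we add 1.", "Therefore, 1 + 1 = \\boxed{2}."],
}]
results = model.estimate(data)
# One probability-like score per reasoning step; the last step's score is the overall reward.
print(results[0]["metadata"]["step_rewards"])
print(results[0]["reward"])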
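
Next, a sketch for the rule-based verifier. The config values mirror the defaults registered in this PR; tokenizer_name is left empty so nothing is downloaded, and the device/logger arguments fall back to their defaults.

from easydict import EasyDict
from ding.reward_model import MathRuleRewardModel

cfg = EasyDict(
    type="math_rule",
    dataset_name="",
    tokenizer_name="",  # empty -> no HuggingFace tokenizer is loaded
    format_error_reward=-2,
    answer_error_reward=-1,
    correct_reward=1,
    rel_tol=1e-5,
    abs_tol=1e-8,
)
model = MathRuleRewardModel(cfg)

data = [
    {"query": "What is 1/2?", "response": r"The answer is \frac{1}{2}.", "answer": "0.5"},
    {"query": "What is 50% as a decimal?", "response": "The answer is 50%.", "answer": "0.5"},
    {"query": "What is 2 + 2?", "response": "The answer is four.", "answer": "4"},
]
for item, result in zip(data, model.estimate(data)):
    # Expected rewards: +1, +1, -2 ("four" cannot be parsed into a number, so it counts as a format error).
    print(item["query"], result["reward"], result["metadata"]["reason"])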
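
The module-level helpers in math_rule_reward_model.py can also be checked in isolation. The pad/eos markers below are placeholders chosen for the example; the helper simply strips leading and trailing runs of whatever tokens it is given.

from ding.reward_model.math_rule_reward_model import strip_sequence, normalize_text

# Leading/trailing pad and eos runs are removed; occurrences in the middle are kept.
print(strip_sequence("[PAD][PAD]The answer is 4.[EOS]", "[PAD]", "[EOS]"))      # -> 'The answer is 4.'
print(strip_sequence("Keep [PAD] in the middle[EOS][EOS]", "[PAD]", "[EOS]"))   # -> 'Keep [PAD] in the middle'

# normalize_text drops "import x as y" statements and collapses whitespace.
print(normalize_text("import numpy as np\nThe   answer is   4"))  # -> 'The answer is 4'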
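
Finally, a sketch for the multi-modal reward model. This one is the most environment-dependent: it assumes a CUDA device (the model is moved to .cuda() in __init__), that the internlm/internlm-xcomposer2d5-7b-reward weights are already cached locally (the tokenizer is loaded with local_files_only=True), and an image path borrowed from the test file.

from easydict import EasyDict
from unittest.mock import MagicMock
from ding.reward_model import MultiModalRewardModel

cfg = EasyDict(type='multi_modal', model_name='internlm/internlm-xcomposer2d5-7b-reward', hd_num=9)
model = MultiModalRewardModel(cfg, 'cuda', MagicMock(), MagicMock())

chats = [
    [
        {"role": "user", "content": "I want to buy the car in the image, analyze its advantages and weaknesses."},
        {"role": "assistant", "content": "The car in the image is a Mercedes-Benz G-Class..."},
    ],
    [
        {"role": "user", "content": "I want to buy the car in the image, analyze its advantages and weaknesses."},
        {"role": "assistant", "content": "Based on the image, it appears to be a Ferrari F8 Tributo..."},
    ],
]
data = [{'chat': chat} for chat in chats]
images = ['./examples/cars1.jpg']  # path taken from the test; any local image works

# 'score' gives one reward per chat; 'rank' and 'compare' are also supported.
scores = model.estimate(data, images, output_mode='score')
ranking = model.estimate(data, images, output_mode='rank')
print([r['reward'] for r in scores], [r['rank'] for r in ranking])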