diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 6b1bfc709..15234f928 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -7,16 +7,18 @@ # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml import asyncio -import logging +import time import uuid from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any, Callable import torch import torch.nn.functional as F +import torchstore as ts from datasets import load_dataset from forge.actors.policy import Policy from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.trainer import _qwen3_hf_to_vllm from forge.cli.config import parse from forge.controller.actor import ForgeActor from forge.controller.service import ServiceConfig, shutdown_service, spawn_service @@ -26,12 +28,10 @@ from omegaconf import DictConfig from src.forge.data.utils import exclude_service from torch import nn +from torchstore.state_dict_utils import DELIM from transformers import AutoModelForCausalLM from vllm.transformers_utils.tokenizer import get_tokenizer -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - def compute_logprobs( logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 @@ -50,25 +50,21 @@ def compute_logprobs( class SimpleGRPOLoss(nn.Module): """Simplified GRPO Loss for simplified single step updates - Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py. + Inspired by the Hugging Face TRL implementation: + https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624. """ - def __init__(self, epsilon=0.1, beta=0.1): + def __init__(self, beta: float = 0.1): super().__init__() - self.epsilon = epsilon self.beta = beta def forward(self, logprobs, ref_logprobs, advantages, padding_mask): - per_token_kl = ( - torch.exp(ref_logprobs.detach() - logprobs) - - (ref_logprobs.detach() - logprobs) - - 1 - ) + kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages - per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl) + per_token_loss = -(per_token_policy_loss - self.beta * kl) loss = ( - (per_token_loss * padding_mask).sum(dim=1) - / (padding_mask.sum(dim=1) + 1e-8) + ((per_token_loss * padding_mask).sum(dim=1)) + / (padding_mask.sum(dim=1).clamp(min=1.0)) ).mean() return loss @@ -82,14 +78,14 @@ class Episode: pad_id: int request_len: int response_len: int - target: Optional[Any] = None + target: Any | None = None # processed data - response: Optional[str] = None - request_tokens: Optional[list[int]] = None - response_tokens: Optional[list[int]] = None - ref_logprobs: Optional[torch.Tensor] = None - reward: Optional[float] = None - advantage: Optional[float] = None + response: str | None = None + request_tokens: list[int] | None = None + response_tokens: list[int] | None = None + ref_logprobs: torch.Tensor | None = None + reward: float | None = None + advantage: float | None = None @property def request_tensor(self): @@ -126,7 +122,7 @@ def new_group( target: Any = None, ): episodes = [] - for i in range(group_size): + for _ in range(group_size): episodes.append( Episode( episode_id=str(uuid.uuid4()), @@ -148,17 +144,15 @@ class Trainer(ForgeActor): model_name: str learning_rate: float = 1e-5 beta: float = 0.1 - epsilon: float = 0.1 device: torch.device | None = None + state_dict_key: str = "model_state_dict" dp_rank: int = 0 # TODO: support data 
parallelism, hard code it for now @endpoint - def setup(self): - # Set device + async def setup(self): if self.device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # Initialize model self.model = AutoModelForCausalLM.from_pretrained( self.model_name, dtype=torch.bfloat16, @@ -166,60 +160,59 @@ def setup(self): ).to(self.device) self.model.train() - # Initialize optimizer self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=self.learning_rate ) self.optimizer.zero_grad() - # Initialize loss - self.loss = SimpleGRPOLoss(self.epsilon, self.beta) + self.loss = SimpleGRPOLoss(self.beta) - self.logger.info(f"Model initialized on {self.device}") + self.logger.info(f"Trainer model initialized on {self.device}") @endpoint - async def train_step(self, batch: list[Episode]): - batch = batch[self.dp_rank] - pad_id = batch[0].pad_id + async def train_step(self, batch: list[list[Episode]]): + microbatch = batch[self.dp_rank] + pad_id = microbatch[0].pad_id # prepare batch - request = [e.request_tensor for e in batch] + request = [e.request_tensor for e in microbatch] request = torch.stack(request).to(self.device) # [b x s] - response = [e.response_tensor for e in batch] + response = [e.response_tensor for e in microbatch] response = torch.stack(response).to(self.device) # [b x s] - ref_logprobs = [e.ref_logprobs for e in batch] + ref_logprobs = [e.ref_logprobs for e in microbatch] ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze() # [b x s] - advantages = [e.advantage for e in batch] + advantages = [e.advantage for e in microbatch] advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1] del batch - # compute policy logprobs input_ids = torch.cat([request, response], dim=1) mask = input_ids != pad_id logits = self.model(input_ids=input_ids, attention_mask=mask).logits logprobs = compute_logprobs(logits, response) del logits - # compute loss mask = response != pad_id loss = self.loss(logprobs, ref_logprobs, advantages, mask) - - self.optimizer.zero_grad() loss.backward() - - # # Gradient clipping (optional but recommended for stability) - # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - self.optimizer.step() + self.optimizer.zero_grad(set_to_none=True) - return {"loss": loss.item()} + return loss.item() @endpoint - async def push_weights(self): - pass + async def push_weights(self, version: int): + """Update policy model weights with trainer's current weights.""" + key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id + new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28) + start_time = time.time() + await ts.put_state_dict(new_sd, key) + end_time = time.time() + self.logger.debug( + f"Pushed weights to {key} in {end_time - start_time:.2f} seconds" + ) @dataclass @@ -230,11 +223,11 @@ class RewardActor(ForgeActor): @endpoint async def evaluate_response(self, prompt: str, response: str, target: str) -> float: - total_reward = 0.0 + total_rewards = 0.0 for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) - total_reward += reward - return total_reward + total_rewards += reward + return total_rewards / len(self.reward_functions) class ComputeAdvantages(ForgeActor): @@ -243,18 +236,11 @@ class ComputeAdvantages(ForgeActor): @endpoint async def compute(self, group: Group) -> list[float]: # TODO: add batch processing - rewards = torch.Tensor([[e.reward for e in group.episodes]]) + rewards = torch.tensor([[e.reward for e in 
group.episodes]]) mean = rewards.mean(1, keepdim=True) std = rewards.std(1, keepdim=True) - - # if std is nan, return 0s. Remove this before shipping - if std.isnan().any(): - advantages = torch.zeros_like(rewards) - else: - advantages = (rewards - mean) / (std + 1e-4) - - x = advantages.squeeze(0).tolist() - return x + advantages = (rewards - mean) / (std + 1e-4) + return advantages.squeeze(0).tolist() class RefModel(ForgeActor): @@ -297,16 +283,24 @@ class DatasetActor(ForgeActor): revision: str = "main" data_split: str = "train" streaming: bool = True - model: str = "Qwen/Qwen3-1.7B-Base" + model: str = "Qwen/Qwen3-1.7B" @endpoint def setup(self): - self.tokenizer = get_tokenizer(self.model) + self._tokenizer = get_tokenizer(self.model) def gsm8k_transform(sample): + system_prompt = """ + Put all your scratchpad work between and tags. + Your final answer should be between and tags otherwise it will not be scored. + """ request: str = sample["question"] - formatted_request = self.tokenizer.apply_chat_template( - [{"role": "user", "content": request}], + as_chat = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] + formatted_request = self._tokenizer.apply_chat_template( + as_chat, tokenize=False, add_generation_prompt=True, ) @@ -330,7 +324,7 @@ async def sample(self) -> dict[str, str] | None: @endpoint async def pad_token(self): - return self.tokenizer.pad_token_id + return self._tokenizer.pad_token_id async def main(cfg: DictConfig): @@ -340,15 +334,14 @@ async def main(cfg: DictConfig): model = cfg.model max_req_tokens = cfg.max_req_tokens max_res_tokens = cfg.max_res_tokens - - # ---- Setup WandB Logger ---- # - logger = get_metric_logger( + mlogger = get_metric_logger( "wandb", freq=1, project="grpo-training", ) # ---- Setup services ---- # + await ts.initialize() ( dataloader, policy, @@ -371,7 +364,6 @@ async def main(cfg: DictConfig): spawn_service( ServiceConfig(**cfg.trainer.service), Trainer, - model_name=model, **exclude_service(cfg.trainer), ), spawn_service( @@ -407,7 +399,8 @@ async def continuous_rollouts(): print("Dataloader is empty, exiting continuous rollout") return prompt, target = sample["request"], sample["target"] - version = 0 # await policy.get_current_version.choose() + responses = await policy.generate.choose(prompt) + version = await policy.get_version.choose() group = Group.new_group( group_id=rollout_count, group_size=group_size, @@ -419,12 +412,11 @@ async def continuous_rollouts(): target=target, ) - responses = await policy.generate.choose(prompt) - + # TODO: Parallelize the following calculation for episode, response in zip(group.episodes, responses.outputs): episode.request_tokens = responses.prompt_token_ids episode.response_tokens = response.token_ids - assert len(response.token_ids) <= max_res_tokens + episode.response = response.text episode.ref_logprobs = await ref_model.forward.choose(episode) episode.reward = await reward_actor.evaluate_response.choose( prompt=prompt, response=response.text, target=target @@ -434,30 +426,33 @@ async def continuous_rollouts(): episode.advantage = advantage await replay_buffer.add.choose(episode) + avg_response_len = ( + sum(len(e.response_tokens) for e in group.episodes) / group_size + ) + mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count) + buffer_size = await replay_buffer._numel.choose() + mlogger.log("buffer_size/rollout", buffer_size, rollout_count) + avg_reward = sum(e.reward for e in group.episodes) / group_size + 
mlogger.log("avg_reward/rollout", avg_reward, rollout_count) + rollout_count += 1 - if rollout_count % 10 == 0: - avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes) - print( - f"Generated {rollout_count} rollouts w/ average reward {avg_reward}" - ) - logger.log("reward_per_rollout", avg_reward, rollout_count) async def continuous_training(): training_step = 0 + policy_version = 0 while True: - batch = await replay_buffer.sample.choose(curr_policy_version=0) + batch = await replay_buffer.sample.choose( + curr_policy_version=policy_version + ) if batch is None: await asyncio.sleep(0.1) else: - training_result = await trainer.train_step.choose(batch) + loss = await trainer.train_step.choose(batch) training_step += 1 - if training_step % 10 == 0: - print(f"Completed {training_step} training steps") - if training_result: - loss_value = training_result.get("loss", 0.0) - print(f"Latest loss: {loss_value}") - logger.log("loss/training_step", loss_value, training_step) - # await trainer.update_weights(policy) + mlogger.log("loss/training_step", loss, training_step) + await trainer.push_weights.call(policy_version) + policy_version += 1 + await policy.update_weights.call() print("Starting GRPO training loops...") # TODO: Start multiple rollouts once all serivces support it @@ -483,10 +478,10 @@ async def continuous_training(): ) -@parse -def recipe_main(cfg: DictConfig) -> None: - asyncio.run(main(cfg)) +if __name__ == "__main__": + @parse + def _main(cfg): + asyncio.run(main(cfg)) -if __name__ == "__main__": - recipe_main() + _main() # @parse grabs the cfg from CLI diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 8ba96a096..6fc60bf53 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -1,11 +1,11 @@ # GRPO Training Configuration # Global configuration -group_size: 4 -batch_size: 4 +group_size: 8 +batch_size: 16 max_req_tokens: 512 -max_res_tokens: 128 -model: "Qwen/Qwen3-1.7B-Base" +max_res_tokens: 512 +model: "Qwen/Qwen3-1.7B" # Dataset configuration dataset: @@ -13,6 +13,7 @@ dataset: revision: "main" data_split: "train" streaming: true + model: ${model} service: procs_per_replica: 1 num_replicas: 1 @@ -24,10 +25,10 @@ policy: model: ${model} tensor_parallel_size: 1 pipeline_parallel_size: 1 - enforce_eager: true + enforce_eager: false sampling_config: - n: 4 - max_tokens: 128 + n: ${group_size} + max_tokens: ${max_res_tokens} temperature: 1.0 top_p: 1.0 service: @@ -37,6 +38,7 @@ policy: # Trainer configuration trainer: + model_name: ${model} learning_rate: 1e-5 service: procs_per_replica: 1 @@ -46,7 +48,7 @@ trainer: # Replay buffer configuration replay_buffer: batch_size: ${batch_size} - max_policy_age: 0 + max_policy_age: 1 # Async by 1 dp_size: 1 service: procs_per_replica: 1 diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 972fda18c..fd9f11482 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -5,24 +5,23 @@ # LICENSE file in the root directory of this source tree. 
import asyncio -import logging import os import sys +import time from collections.abc import Mapping from copy import copy from dataclasses import asdict, dataclass, field, fields -from typing import Dict, List import torch +import torchstore as ts from monarch.actor import current_rank, endpoint, ProcMesh -from torchstore import MultiProcessStore -from torchstore._state_dict_utils import DELIM +from torchstore.state_dict_utils import DELIM from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.utils import _validate_truncation_size from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput +from vllm.outputs import RequestOutput from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext @@ -45,9 +44,6 @@ from forge.types import ProcessConfig -logger = logging.getLogger(__name__) - - @dataclass class SamplingConfig: """ @@ -111,13 +107,13 @@ class Policy(PolicyInterface): lora_request: LoRARequest | None = None tokenization_kwargs: dict = field(default_factory=dict) policy_worker: "PolicyWorker" = None - store: MultiProcessStore | None = None def __post_init__(self): self._run_task: asyncio.Task | None = None self._policy_proc: ProcMesh | None = None self._worker_procs: ProcMesh | None = None self.weights_version: int = 0 + self.running = False if isinstance(self.engine_config, Mapping): self.engine_config = EngineConfig.from_dict(self.engine_config) if isinstance(self.sampling_config, Mapping): @@ -131,7 +127,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] engine_config: EngineConfig | Mapping = EngineConfig(), sampling_config: SamplingConfig | Mapping = SamplingConfig(), available_devices: str | None = None, - store: MultiProcessStore | None = None, **kwargs, ) -> "Policy": # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES @@ -164,7 +159,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] sampling_config=sampling_config, available_devices=available_devices, policy_worker=workers, - store=store, ) policy._policy_proc = policy_proc policy._worker_procs = worker_procs @@ -192,10 +186,10 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] async def setup(self): # Set up policy_worker assert self.policy_worker is not None, "Policy worker should not be None" - await self.policy_worker.setup.call(store=self.store) + await self.policy_worker.setup.call() self.request_id = 0 - self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} + self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} self.vllm_args = await self.policy_worker.get_vllm_args.choose() # Setup sampling params @@ -239,12 +233,21 @@ def start_processing(self): self._run_task = asyncio.create_task(self.run()) @endpoint - async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]: + async def generate(self, prompt: str, priority: int = 0) -> RequestOutput: + """Generate a response for the given prompt + + Args: + prompt (str): The prompt to generate a response for. + priority (int, optional): The priority of the request. Defaults to 0. + + Returns: + RequestOutput: vLLM class with the generated response. 
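+
+ Example (see apps/grpo/main.py): `responses = await policy.generate.choose(prompt)`;
+ the sampled completions are then available as `responses.outputs`.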
+ """ self.request_id += 1 % sys.maxsize request_id = str(self.request_id) # implement from a counter # Wraps prompt into a dict - prompt: Dict[str, str] = convert_input(prompt=prompt) + prompt_dict: dict[str, str] = convert_input(prompt=prompt) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} @@ -259,7 +262,7 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu # process and tokenize prompt prompt_str, request = self.processor.process_inputs( request_id=request_id, - prompt=prompt, + prompt=prompt_dict, params=self.sampling_params, arrival_time=None, lora_request=self.lora_request, @@ -331,25 +334,27 @@ async def run(self): engine_core_timestamp=outputs.timestamp, iteration_stats=None, ) + for request_output in processed_outputs.request_outputs: if request_output.finished: _, fut = self.requests.pop(request_output.request_id) fut.set_result(request_output) @endpoint - async def update_weights(self) -> int: - """Update the policy weights.""" - # Wait for all current requests to finish, then publish model weights - futures = [fut for _, fut in self.requests.values()] - if futures: - await asyncio.gather(*futures) - new_version = self.weights_version + 1 - await self.policy_worker.update.call(version=new_version) - self.weights_version = new_version - return self.weights_version + async def update_weights(self): + # TODO: If generating long sequences, this might be long and will block policy weight updates + curr_requests = [fut for _, fut in self.requests.values()] + if curr_requests: + self.logger.debug(f"Waiting for {len(curr_requests)} pending requests") + await asyncio.gather(*curr_requests) + + self.logger.debug(f"Starting weight update on {self.__class__.__name__}") + await self.policy_worker.update.call(version=self.weights_version) + self.weights_version += 1 + self.logger.info(f"Weight update completed (now v{self.weights_version})") @endpoint - async def _get_model_params(self) -> Dict[str, torch.Tensor]: + async def _get_model_params(self) -> dict[str, torch.Tensor]: """Get the current model parameters. Only for testing purposes.""" model_params = await self.policy_worker._get_model_params.choose() return model_params @@ -391,8 +396,7 @@ def __post_init__(self): self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS) @endpoint - async def setup(self, store: MultiProcessStore = None): - self.torchstore = store + async def setup(self): # TODO: remove ["gpus"] when monarch implements a flat rank self.rank = current_rank()["gpus"] self.worker = self.setup_worker() @@ -407,9 +411,6 @@ async def _load_tensor_parallel_state_dict( """ Load full state dict from torchstore into tensor parallel model with deterministic sharding. 
""" - - updated_count = 0 - # setting explictly to llama3 for now as its our only use case sharding = VLLMSharding( self.vllm_args.parallel_config.tensor_parallel_size, self.rank ) @@ -419,7 +420,7 @@ async def _load_tensor_parallel_state_dict( # Load the full tensor from torchstore # TODO: only get the part of the tensor that is needed - stored_tensor = await self.torchstore.get( + stored_tensor = await ts.get( f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}" ) sharding.load_from_source_to_target( @@ -428,23 +429,17 @@ async def _load_tensor_parallel_state_dict( current_tensor, ) - updated_count += 1 - @endpoint async def update(self, version: int): """Update model weights by reading state dict from torchstore""" - if self.torchstore is None: - raise Exception("No torchstore configured, skipping model update") - - logger.debug( - f"Starting model update from torchstore with key: {self.state_dict_key}{DELIM}{version}" - ) - + key = f"{self.state_dict_key}{DELIM}{version}" model = self.worker.model_runner.model current_state_dict = model.state_dict() - + start = time.time() await self._load_tensor_parallel_state_dict(current_state_dict, version) - logger.debug("Successfully updated model weights from torchstore") + self.logger.debug( + f"Loaded state dict from {key} in {time.time() - start} seconds" + ) @endpoint async def setup_kv_cache(self): @@ -481,7 +476,7 @@ async def get_vllm_args(self): return self.vllm_args @endpoint - async def _get_model_params(self) -> Dict[str, torch.Tensor]: + async def _get_model_params(self) -> dict[str, torch.Tensor]: model = self.worker.model_runner.model state_dict = {} @@ -514,7 +509,7 @@ def setup_worker(self): return worker -def convert_input(prompt=None, prompt_token_ids=None) -> Dict: +def convert_input(prompt=None, prompt_token_ids=None) -> dict: assert (prompt is None) ^ (prompt_token_ids is None) if prompt is not None: return {"prompt": prompt} diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 4232ca5ca..062fabe8a 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -268,3 +268,49 @@ def push_weights(self) -> None: async def cleanup(self) -> None: if self.engine.checkpointer: self.engine.checkpointer.close() + + +def _qwen3_hf_to_vllm( + sd: dict[str, torch.Tensor], num_layers: int +) -> dict[str, torch.Tensor]: + """Convert transformers state dict to vLLM format. Specifically, this fuses + QKV projection and MLP gate_up_proj layers. + + Args: + sd (dict): State dict from HF model. + num_layers (int): Number of layers in the model. + + Returns: + dict: State dict in vLLM format. + """ + load_sd = {} + + # Copy over directly mapped keys + for k in sd: + if any( + x in k + for x in [ + "down_proj", + "input_layernorm", + "post_attention_layernorm", + "o_proj", + "norm.weight", + "embed_tokens.weight", + "lm_head.weight", + ] + ): + load_sd[k] = sd[k] + + for i in range(num_layers): + prefix = f"model.layers.{i}." 
+ # QKV fusion + q = sd[prefix + "self_attn.q_proj.weight"] + k = sd[prefix + "self_attn.k_proj.weight"] + v = sd[prefix + "self_attn.v_proj.weight"] + load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0) + # MLP gate_up_proj fusion + gate = sd[prefix + "mlp.gate_proj.weight"] + up = sd[prefix + "mlp.up_proj.weight"] + load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0) + + return load_sd
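For reference, a minimal sketch of how the fusion in _qwen3_hf_to_vllm can be sanity-checked. The layer count and tensor shapes below are toy values chosen for illustration (not the real Qwen3-1.7B dimensions); the import path matches the one added to apps/grpo/main.py in this diff.

import torch

from forge.actors.trainer import _qwen3_hf_to_vllm

# Toy HF-style state dict for a single decoder layer (made-up sizes: hidden dim 4,
# smaller k/v projections to mimic grouped-query attention, MLP dim 6).
hf_sd = {
    "model.embed_tokens.weight": torch.randn(8, 4),
    "model.norm.weight": torch.randn(4),
    "lm_head.weight": torch.randn(8, 4),
    "model.layers.0.input_layernorm.weight": torch.randn(4),
    "model.layers.0.post_attention_layernorm.weight": torch.randn(4),
    "model.layers.0.self_attn.q_proj.weight": torch.randn(4, 4),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(2, 4),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(2, 4),
    "model.layers.0.self_attn.o_proj.weight": torch.randn(4, 4),
    "model.layers.0.mlp.gate_proj.weight": torch.randn(6, 4),
    "model.layers.0.mlp.up_proj.weight": torch.randn(6, 4),
    "model.layers.0.mlp.down_proj.weight": torch.randn(4, 6),
}

vllm_sd = _qwen3_hf_to_vllm(hf_sd, num_layers=1)

# q, k and v are concatenated along the output (row) dimension: 4 + 2 + 2 rows.
assert vllm_sd["model.layers.0.self_attn.qkv_proj.weight"].shape == (8, 4)
# gate_proj and up_proj are likewise fused into gate_up_proj: 6 + 6 rows.
assert vllm_sd["model.layers.0.mlp.gate_up_proj.weight"].shape == (12, 4)

vLLM's Qwen3 module keeps attention as a single fused qkv_proj and the MLP input as a fused gate_up_proj, which is why the trainer concatenates the separate Hugging Face projections before publishing them via ts.put_state_dict.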
diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index 644c69d1b..29a86fc3a 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import re -from typing import Optional from forge.interfaces import Reward @@ -17,62 +16,69 @@ def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1): self.tolerance = tolerance self.partial_credit = partial_credit - def _to_float(self, text) -> Optional[float]: - """Safely parse a string into a float, or return None if invalid.""" - if text is None: - return None - try: - return float(str(text).strip()) - except (ValueError, TypeError): - return None - - def _extract_number(self, text: str) -> Optional[float]: - """Try to extract a numeric answer from text.""" - number_pattern = r"([+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)" - patterns = [ - r"####\s*" + number_pattern, - r"(?:the\s+)?answer\s+is\s*" + number_pattern, - r"(?:answer:|result:)\s*" + number_pattern, - r"\$" + number_pattern, # currency - number_pattern, # fallback - r"=\s*" + number_pattern + r"\s*(?:\.|$)", - r"\b" + number_pattern + r"\s*(?:\.|$)", - ] - text = text.lower().strip() - for pattern in patterns: - matches = re.findall(pattern, text) - if matches: - return self._to_float(matches[-1]) - return None - def __call__(self, prompt: str, response: str, target: str) -> float: """Compute math correctness reward.""" - # Parse expected - expected_answer = self._to_float(target) + target_number = self._to_float(target) + if target_number is None: + return 0.0 - # Parse response - model_answer = self._extract_number(response) + # Look for answer in <answer> tags + answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL) - # Scoring - if expected_answer is None or model_answer is None: - return self.partial_credit # Partial credit for attempting + if answer_match: + model_answer = self._to_float(answer_match.group(1).strip()) + if ( + model_answer is not None + and abs(target_number - model_answer) < self.tolerance + ): + return 1.0 # Correct answer - if abs(expected_answer - model_answer) < self.tolerance: - return 1.0 # Correct answer - return 0.0 # Incorrect answer + # Check for partial credit: target number appears elsewhere in response + response_without_answer_tags = re.sub( + r"<answer>.*?</answer>", "", response, flags=re.DOTALL + ) + # Convert to int if it's a whole number to avoid "117.0" vs "117" mismatch + target_str = ( + str(int(target_number)) + if target_number.is_integer() + else str(target_number) + ) + if target_str in response_without_answer_tags: + return self.partial_credit + + return 0.0 # No match + + def _to_float(self, text: str) -> float | None: + """Convert text to float, return None if invalid.""" + try: + # Remove common non-numeric characters like $, commas, etc.
+ cleaned_text = re.sub(r"[$,\s]", "", text.strip()) + return float(cleaned_text) + except (ValueError, AttributeError): + return None class ThinkingReward(Reward): """Reward class for evaluating use of <think> tags in reasoning.""" - def __init__(self, reward_value: float = 0.5): - self.reward_value = reward_value + def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0): + self.partial_reward = partial_reward + self.full_reward = full_reward + self._THINK_BLOCK_RE = re.compile( + r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL ) + self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE) + + def __call__(self, prompt: str, response: str, target: str | None = None) -> float: + """Compute thinking reward.""" + if not response: + return 0.0 - def __call__( - self, prompt: str, response: str, target: Optional[str] = None - ) -> float: - """Check if response contains <think>...</think> tags.""" - resp = response.lower() - if "<think>" in resp and "</think>" in resp: - return self.reward_value + matches = self._THINK_BLOCK_RE.findall(response) + has_well_formed = any(len(re.sub(r"\s+", "", m)) >= 1 for m in matches) + has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response)) or bool(matches) + if has_well_formed: + return self.full_reward + elif has_attempt: + return self.partial_reward return 0.0
diff --git a/tests/unit_tests/rl/test_math_reward.py b/tests/unit_tests/rl/test_math_reward.py index 2f3521b4d..726b1173c 100644 --- a/tests/unit_tests/rl/test_math_reward.py +++ b/tests/unit_tests/rl/test_math_reward.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import unittest -from unittest import mock from forge.data.rewards import MathReward @@ -36,6 +35,13 @@ def test_to_float_valid_numbers(self): self.assertEqual(self.reward._to_float("0"), 0.0) self.assertEqual(self.reward._to_float(" 123.45 "), 123.45) + def test_to_float_with_currency_and_formatting(self): + """Test _to_float with currency symbols and commas.""" + self.assertEqual(self.reward._to_float("$42"), 42.0) + self.assertEqual(self.reward._to_float("$1,000"), 1000.0) + self.assertEqual(self.reward._to_float("1,234.56"), 1234.56) + self.assertEqual(self.reward._to_float("$ 42.50 "), 42.5) + def test_to_float_invalid_inputs(self): """Test _to_float with invalid inputs.""" self.assertIsNone(self.reward._to_float("abc")) @@ -48,154 +54,146 @@ def test_to_float_edge_cases(self): """Test _to_float with edge cases.""" self.assertEqual(self.reward._to_float("1e6"), 1000000.0) self.assertEqual(self.reward._to_float("-1.5e-3"), -0.0015) - self.assertEqual(self.reward._to_float("inf"), float("inf")) - self.assertEqual(self.reward._to_float("-inf"), float("-inf")) - - def test_extract_number_gsm8k_format(self): - """Test _extract_number with GSM8K style format.""" - self.assertEqual(self.reward._extract_number("#### 42"), 42.0) - self.assertEqual(self.reward._extract_number("#### -3.14"), -3.14) - self.assertEqual(self.reward._extract_number("Some text #### 123.45"), 123.45) - - def test_extract_number_answer_patterns(self): - """Test _extract_number with various answer patterns.""" - self.assertEqual(self.reward._extract_number("The answer is 42"), 42.0) - self.assertEqual(self.reward._extract_number("answer is 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("Answer: 123"), 123.0) - self.assertEqual(self.reward._extract_number("Result: -5.5"), -5.5) - - def test_extract_number_equals_pattern(self): - """Test _extract_number with equals sign patterns.""" -
self.assertEqual(self.reward._extract_number("x = 42."), 42.0) - self.assertEqual(self.reward._extract_number("The result = 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("calculation = -7.5."), -7.5) - - def test_extract_number_end_of_text(self): - """Test _extract_number with numbers at end of text.""" - self.assertEqual(self.reward._extract_number("The final result is 42."), 42.0) - self.assertEqual(self.reward._extract_number("We get 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("Answer: -5.5."), -5.5) - - def test_extract_number_fallback_pattern(self): - """Test _extract_number with fallback pattern (any number).""" - self.assertEqual(self.reward._extract_number("There are 42 items"), 42.0) - self.assertEqual(self.reward._extract_number("Cost is $3.14 per item"), 3.14) - self.assertEqual(self.reward._extract_number("Temperature: -5.5 degrees"), -5.5) - - def test_extract_number_multiple_matches(self): - """Test _extract_number returns the last match when multiple numbers exist.""" - # Should return the last match from the pattern - self.assertEqual( - self.reward._extract_number("First 10, then 20, finally 30"), 30.0 - ) - self.assertEqual( - self.reward._extract_number("#### 5 but actually #### 10"), 10.0 - ) - def test_extract_number_no_match(self): - """Test _extract_number when no numbers are found.""" - self.assertIsNone(self.reward._extract_number("No numbers here")) - self.assertIsNone(self.reward._extract_number("")) - self.assertIsNone(self.reward._extract_number("Just text")) + def test_call_correct_answer_in_tags(self): + """Test __call__ with correct answers in tags.""" + self.assertEqual(self.reward("prompt", "42", "42"), 1.0) + self.assertEqual(self.reward("prompt", "3.14", "3.14"), 1.0) + self.assertEqual(self.reward("prompt", "-5.5", "-5.5"), 1.0) - def test_extract_number_case_insensitive(self): - """Test _extract_number is case insensitive.""" - self.assertEqual(self.reward._extract_number("THE ANSWER IS 42"), 42.0) - self.assertEqual(self.reward._extract_number("Answer: 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("RESULT: 123"), 123.0) + def test_call_answer_tags_with_whitespace(self): + """Test __call__ with answer tags containing whitespace.""" + self.assertEqual(self.reward("prompt", " 42 ", "42"), 1.0) + self.assertEqual( + self.reward("prompt", "\n3.14\n", "3.14"), 1.0 + ) - def test_call_correct_answer(self): - """Test __call__ with correct answers.""" - self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 1.0) - self.assertEqual(self.reward("prompt", "#### 3.14", "3.14"), 1.0) - self.assertEqual(self.reward("prompt", "Result: -5.5", "-5.5"), 1.0) + def test_call_answer_tags_with_complex_content(self): + """Test __call__ with complex content in answer tags.""" + response = """ + Let me solve this step by step: + First, I calculate 2 + 3 = 5 + Then, I multiply by 4: 5 * 4 = 20 + Finally, I subtract 8: 20 - 8 = 12 + 12 + """ + self.assertEqual(self.reward("prompt", response, "12"), 1.0) def test_call_within_tolerance(self): """Test __call__ with answers within tolerance.""" # Default tolerance is 1e-6 - self.assertEqual(self.reward("prompt", "42.0000001", "42"), 1.0) - self.assertEqual(self.reward("prompt", "3.1400001", "3.14"), 1.0) - - # Custom tolerance - self.assertEqual(self.custom_reward("prompt", "42.0001", "42"), 1.0) - self.assertEqual(self.custom_reward("prompt", "3.141", "3.14"), 1.0) - - def test_call_outside_tolerance(self): - """Test __call__ with answers outside tolerance.""" - 
self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0) - self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0) - self.assertEqual(self.custom_reward("prompt", "42.01", "42"), 0.0) - - def test_call_invalid_target(self): - """Test __call__ with invalid target values.""" self.assertEqual( - self.reward("prompt", "42", "invalid"), self.reward.partial_credit + self.reward("prompt", "42.0000001", "42"), 1.0 ) - self.assertEqual(self.reward("prompt", "42", ""), self.reward.partial_credit) self.assertEqual( - self.reward("prompt", "42", "not a number"), self.reward.partial_credit + self.reward("prompt", "3.1400001", "3.14"), 1.0 ) - def test_call_invalid_response(self): - """Test __call__ with invalid response values.""" + # Custom tolerance self.assertEqual( - self.reward("prompt", "no number", "42"), self.reward.partial_credit + self.custom_reward("prompt", "42.0001", "42"), 1.0 ) - self.assertEqual(self.reward("prompt", "", "42"), self.reward.partial_credit) self.assertEqual( - self.reward("prompt", "just text", "42"), self.reward.partial_credit + self.custom_reward("prompt", "3.141", "3.14"), 1.0 + ) + + def test_call_outside_tolerance(self): + """Test __call__ with answers outside tolerance.""" + self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0) + self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0) + self.assertEqual( + self.custom_reward("prompt", "42.01", "42"), 0.0 ) - def test_call_both_invalid(self): - """Test __call__ with both invalid target and response.""" + def test_call_partial_credit_target_in_response(self): + """Test __call__ with partial credit when target appears in response.""" + response = "The calculation shows 42 but I put 43" + self.assertEqual(self.reward("prompt", response, "42"), 0.1) + + response = "Let me work through this: 42 + 1 = 43. 43" + self.assertEqual(self.reward("prompt", response, "42"), 0.1) + + def test_call_partial_credit_custom_value(self): + """Test __call__ with custom partial credit value.""" + response = "The calculation shows 42 but I put 43" + self.assertEqual(self.custom_reward("prompt", response, "42"), 0.2) + + def test_call_no_partial_credit_with_answer_tags(self): + """Test __call__ doesn't give partial credit if target is only in answer tags.""" + response = "Let me solve this. 42" + # Target 100 is not elsewhere in response, so no partial credit + self.assertEqual(self.reward("prompt", response, "100"), 0.0) + + def test_call_integer_target_formatting(self): + """Test __call__ with integer targets formatted correctly.""" + # Integer targets should be formatted without decimal point + response = "I calculated and got 117 as the answer. 118" + self.assertEqual(self.reward("prompt", response, "117"), 0.1) + + # Should work with 117.0 in target too + self.assertEqual(self.reward("prompt", response, "117.0"), 0.1) + + def test_call_float_target_formatting(self): + """Test __call__ with float targets.""" + response = "I calculated and got 3.14 as the answer. 
3.15" + self.assertEqual(self.reward("prompt", response, "3.14"), 0.1) + + def test_call_invalid_target(self): + """Test __call__ with invalid target values.""" + self.assertEqual(self.reward("prompt", "42", "invalid"), 0.0) + self.assertEqual(self.reward("prompt", "42", ""), 0.0) self.assertEqual( - self.reward("prompt", "no number", "invalid"), self.reward.partial_credit + self.reward("prompt", "42", "not a number"), 0.0 ) - self.assertEqual(self.reward("prompt", "", ""), self.reward.partial_credit) - def test_call_custom_partial_credit(self): - """Test __call__ uses custom partial credit value.""" - self.assertEqual(self.custom_reward("prompt", "no number", "42"), 0.2) - self.assertEqual(self.custom_reward("prompt", "42", "invalid"), 0.2) + def test_call_no_answer_tags(self): + """Test __call__ with response that has no answer tags.""" + # Should still check for partial credit + self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 0.1) + self.assertEqual(self.reward("prompt", "No matching number", "42"), 0.0) + + def test_call_invalid_answer_in_tags(self): + """Test __call__ with invalid answer in tags.""" + response = "not a number but 42 is correct" + self.assertEqual(self.reward("prompt", response, "42"), 0.1) def test_call_zero_values(self): """Test __call__ with zero values.""" - self.assertEqual(self.reward("prompt", "0", "0"), 1.0) - self.assertEqual(self.reward("prompt", "The answer is 0", "0.0"), 1.0) + self.assertEqual(self.reward("prompt", "0", "0"), 1.0) + self.assertEqual(self.reward("prompt", "0.0", "0"), 1.0) def test_call_negative_values(self): """Test __call__ with negative values.""" - self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0) - self.assertEqual(self.reward("prompt", "#### -3.14", "-3.14"), 1.0) - self.assertEqual(self.reward("prompt", "-5", "-4.9"), 0.0) + self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0) + self.assertEqual(self.reward("prompt", "-3.14", "-3.14"), 1.0) def test_call_large_numbers(self): """Test __call__ with large numbers.""" - self.assertEqual(self.reward("prompt", "1000000", "1000000"), 1.0) - self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0) - self.assertEqual(self.reward("prompt", "1000001", "1000000"), 0.0) + self.assertEqual( + self.reward("prompt", "1000000", "1000000"), 1.0 + ) + self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0) def test_call_small_numbers(self): """Test __call__ with very small numbers.""" - self.assertEqual(self.reward("prompt", "0.000001", "0.000001"), 1.0) - self.assertEqual(self.reward("prompt", "1e-6", "0.000001"), 1.0) + self.assertEqual( + self.reward("prompt", "0.000001", "0.000001"), 1.0 + ) + self.assertEqual( + self.reward("prompt", "1e-6", "0.000001"), 1.0 + ) - def test_call_complex_response_text(self): - """Test __call__ with complex response text containing multiple elements.""" - response = """ - Let me solve this step by step: - First, I calculate 2 + 3 = 5 - Then, I multiply by 4: 5 * 4 = 20 - Finally, I subtract 8: 20 - 8 = 12 - #### 12 - """ - self.assertEqual(self.reward("prompt", response, "12"), 1.0) + def test_call_multiple_answer_tags(self): + """Test __call__ with multiple answer tags (should use first one).""" + response = "First answer: 42 Second: 43" + self.assertEqual(self.reward("prompt", response, "42"), 1.0) + self.assertEqual(self.reward("prompt", response, "43"), 0.0) - def test_call_with_units_and_formatting(self): - """Test __call__ with responses containing units and formatting.""" - self.assertEqual(self.reward("prompt", 
"The cost is $42.50", "42.5"), 1.0) - self.assertEqual(self.reward("prompt", "Distance: 3.14 meters", "3.14"), 1.0) - self.assertEqual(self.reward("prompt", "Temperature is -5.5°C", "-5.5"), 1.0) + # Test case where target appears outside answer tags for partial credit + response_with_partial = ( + "I think the answer is 43. 42 But 43 might be better." + ) + self.assertEqual(self.reward("prompt", response_with_partial, "43"), 0.1) if __name__ == "__main__": diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/rl/test_thinking_reward.py index 592ceb896..b95823e9a 100644 --- a/tests/unit_tests/rl/test_thinking_reward.py +++ b/tests/unit_tests/rl/test_thinking_reward.py @@ -13,29 +13,36 @@ class TestThinkingReward(unittest.TestCase): def setUp(self): """Set up test fixtures before each test method.""" self.reward = ThinkingReward() - self.custom_reward = ThinkingReward(reward_value=0.8) + self.custom_reward = ThinkingReward(partial_reward=0.3, full_reward=0.9) def test_init_default_values(self): """Test ThinkingReward initialization with default values.""" reward = ThinkingReward() - self.assertEqual(reward.reward_value, 0.5) + self.assertEqual(reward.partial_reward, 0.2) + self.assertEqual(reward.full_reward, 1.0) def test_init_custom_values(self): """Test ThinkingReward initialization with custom values.""" - reward = ThinkingReward(reward_value=0.8) - self.assertEqual(reward.reward_value, 0.8) + reward = ThinkingReward(partial_reward=0.3, full_reward=0.9) + self.assertEqual(reward.partial_reward, 0.3) + self.assertEqual(reward.full_reward, 0.9) - def test_call_with_both_tags(self): - """Test __call__ with response containing both and tags.""" - response = "This is my reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + def test_regex_patterns(self): + """Test that regex patterns are compiled correctly.""" + reward = ThinkingReward() + self.assertIsNotNone(reward._THINK_BLOCK_RE) + self.assertIsNotNone(reward._THINK_TAG_ATTEMPT_RE) + + def test_call_with_well_formed_thinking_block(self): + """Test __call__ with well-formed thinking blocks.""" + result = self.reward("prompt", "This is my reasoning") + self.assertEqual(result, 1.0) - result = self.custom_reward("prompt", response) - self.assertEqual(result, 0.8) + result = self.custom_reward("prompt", "This is my reasoning") + self.assertEqual(result, 0.9) - def test_call_with_both_tags_complex_content(self): - """Test __call__ with complex content between thinking tags.""" + def test_call_with_well_formed_thinking_block_complex_content(self): + """Test __call__ with complex content in thinking blocks.""" response = """ Let me solve this problem step by step. @@ -47,40 +54,58 @@ def test_call_with_both_tags_complex_content(self): The answer is 4. 
""" result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + self.assertEqual(result, 1.0) + + def test_call_with_minimal_content_thinking_block(self): + """Test __call__ with minimal content that still counts as well-formed.""" + result = self.reward("prompt", "x") + self.assertEqual(result, 1.0) + + def test_call_with_empty_thinking_block(self): + """Test __call__ with empty thinking block.""" + result = self.reward("prompt", "") + self.assertEqual(result, 0.2) # Should give partial reward, not full + + def test_call_with_whitespace_only_thinking_block(self): + """Test __call__ with whitespace-only thinking block.""" + result = self.reward("prompt", " \n \t ") + self.assertEqual(result, 0.2) # Should give partial reward, not full def test_call_with_only_opening_tag(self): - """Test __call__ with response containing only tag.""" - response = "This is incomplete reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.0) + """Test __call__ with response containing only opening tag.""" + result = self.reward("prompt", "This is incomplete reasoning") + self.assertEqual(result, 0.2) # Should give partial reward for attempt def test_call_with_only_closing_tag(self): - """Test __call__ with response containing only tag.""" - response = "This is incomplete reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.0) + """Test __call__ with response containing only closing tag.""" + result = self.reward("prompt", "This is incomplete reasoning") + self.assertEqual(result, 0.2) # Should give partial reward for attempt def test_call_with_no_tags(self): """Test __call__ with response containing no thinking tags.""" - response = "This is just a regular response without any thinking tags." - result = self.reward("prompt", response) + result = self.reward( + "prompt", "This is just a regular response without any thinking tags." 
+ ) self.assertEqual(result, 0.0) def test_call_case_insensitive(self): """Test __call__ is case insensitive for thinking tags.""" - # Mixed case tags should work - response = "This is my reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + result = self.reward("prompt", "This is my reasoning") + self.assertEqual(result, 1.0) - response = "This is my reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + result = self.reward("prompt", "This is my reasoning") + self.assertEqual(result, 1.0) - response = "This is my reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + result = self.reward("prompt", "This is my reasoning") + self.assertEqual(result, 1.0) + + def test_call_with_whitespace_in_tags(self): + """Test __call__ with whitespace in thinking tags.""" + result = self.reward("prompt", "< think >This is my reasoning") + self.assertEqual(result, 1.0) + + result = self.reward("prompt", "<\tthink\n>Content") + self.assertEqual(result, 1.0) def test_call_multiple_thinking_blocks(self): """Test __call__ with multiple thinking blocks.""" @@ -90,54 +115,93 @@ def test_call_multiple_thinking_blocks(self): Second thought """ result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + self.assertEqual(result, 1.0) def test_call_nested_tags(self): """Test __call__ with nested or malformed tags.""" - # Nested tags - should still work as long as both tags exist - response = "Outer inner thought" + result = self.reward( + "prompt", "Outer inner thought" + ) + self.assertEqual(result, 1.0) + + def test_call_multiline_thinking_block(self): + """Test __call__ with multiline thinking blocks.""" + response = """ + This is a multiline + thinking block with + lots of content + """ result = self.reward("prompt", response) - self.assertEqual(result, 0.5) - - def test_call_empty_thinking_block(self): - """Test __call__ with empty thinking block.""" - response = "" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + self.assertEqual(result, 1.0) def test_call_empty_response(self): """Test __call__ with empty response.""" result = self.reward("prompt", "") self.assertEqual(result, 0.0) - def test_call_tags_with_extra_whitespace(self): - """Test __call__ with thinking tags containing extra whitespace.""" - response = "< think >This has spaces< /think >" - result = self.reward("prompt", response) - self.assertEqual(result, 0.0) # Should not match due to spaces in tags + def test_call_none_response(self): + """Test __call__ with None response.""" + result = self.reward("prompt", None) + self.assertEqual(result, 0.0) def test_call_with_target_parameter(self): """Test __call__ with target parameter (should be ignored).""" - response = "This is my reasoning" - result = self.reward("prompt", response, target="some target") - self.assertEqual(result, 0.5) + result = self.reward( + "prompt", "This is my reasoning", target="some target" + ) + self.assertEqual(result, 1.0) result = self.reward("prompt", "no tags", target="some target") self.assertEqual(result, 0.0) - def test_call_zero_reward_value(self): - """Test __call__ with zero reward value.""" - zero_reward = ThinkingReward(reward_value=0.0) - response = "This is my reasoning" - result = zero_reward("prompt", response) + result = self.reward( + "prompt", "This is my reasoning", target=None + ) + self.assertEqual(result, 1.0) + + def test_call_custom_reward_values(self): + """Test __call__ with custom reward values.""" + 
response_full = "<think>This is proper reasoning</think>" + response_partial = "<think></think>" + response_none = "no thinking tags" + + # Test custom partial reward + self.assertEqual(self.custom_reward("prompt", response_full), 0.9) + self.assertEqual(self.custom_reward("prompt", response_partial), 0.3) + self.assertEqual(self.custom_reward("prompt", response_none), 0.0) + + def test_call_zero_custom_values(self): + """Test __call__ with zero custom values.""" + zero_reward = ThinkingReward(partial_reward=0.0, full_reward=0.0) + result = zero_reward("prompt", "<think>This is my reasoning</think>") self.assertEqual(result, 0.0) - def test_call_negative_reward_value(self): - """Test __call__ with negative reward value.""" - negative_reward = ThinkingReward(reward_value=-0.5) - response = "<think>This is my reasoning</think>" - result = negative_reward("prompt", response) - self.assertEqual(result, -0.5) + def test_call_negative_reward_values(self): + """Test __call__ with negative reward values.""" + negative_reward = ThinkingReward(partial_reward=-0.1, full_reward=-0.5) + + self.assertEqual( + negative_reward("prompt", "<think>This is proper reasoning</think>"), -0.5 + ) + self.assertEqual(negative_reward("prompt", "<think></think>"), -0.1) + + def test_call_edge_case_characters(self): + """Test __call__ with edge case characters in thinking blocks.""" + result = self.reward( + "prompt", "<think>Special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~</think>" + ) + self.assertEqual(result, 1.0) + + def test_call_unicode_characters(self): + """Test __call__ with unicode characters in thinking blocks.""" + result = self.reward("prompt", "<think>Unicode: αβγδε 中文 🚀</think>") + self.assertEqual(result, 1.0) + + def test_call_very_long_thinking_block(self): + """Test __call__ with very long thinking blocks.""" + long_content = "A" * 10000 + result = self.reward("prompt", f"<think>{long_content}</think>") + self.assertEqual(result, 1.0) if __name__ == "__main__":
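For reference, a minimal sketch of how the two rewards exercised by these tests are combined at rollout time: RewardActor.evaluate_response in apps/grpo/main.py sums the scores of its reward functions and returns the average. The prompt and response strings below are made up for illustration.

from forge.data.rewards import MathReward, ThinkingReward

reward_fns = [MathReward(), ThinkingReward()]
prompt = "What is 6 * 7?"
response = "<think>6 * 7 = 42</think> <answer>42</answer>"

# Mirrors RewardActor.evaluate_response: sum the individual rewards, then average.
scores = [fn(prompt, response, target="42") for fn in reward_fns]
print(sum(scores) / len(scores))  # 1.0 - correct answer plus a well-formed think block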