diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 6b1bfc709..15234f928 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -7,16 +7,18 @@
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
import asyncio
-import logging
+import time
import uuid
from dataclasses import dataclass
-from typing import Any, Callable, Optional
+from typing import Any, Callable
import torch
import torch.nn.functional as F
+import torchstore as ts
from datasets import load_dataset
from forge.actors.policy import Policy
from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
@@ -26,12 +28,10 @@
from omegaconf import DictConfig
from src.forge.data.utils import exclude_service
from torch import nn
+from torchstore.state_dict_utils import DELIM
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
@@ -50,25 +50,21 @@ def compute_logprobs(
class SimpleGRPOLoss(nn.Module):
"""Simplified GRPO Loss for simplified single step updates
- Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
+ Inspired by the Hugging Face TRL implementation:
+ https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
"""
- def __init__(self, epsilon=0.1, beta=0.1):
+ def __init__(self, beta: float = 0.1):
super().__init__()
- self.epsilon = epsilon
self.beta = beta
def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
- per_token_kl = (
- torch.exp(ref_logprobs.detach() - logprobs)
- - (ref_logprobs.detach() - logprobs)
- - 1
- )
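+        # k3 KL estimator: exp(r) - r - 1, with r = ref_logprobs - logprobs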
+ kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
- per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl)
+ per_token_loss = -(per_token_policy_loss - self.beta * kl)
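+        # Sequence-mean over non-pad tokens; clamp avoids division by zero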
loss = (
- (per_token_loss * padding_mask).sum(dim=1)
- / (padding_mask.sum(dim=1) + 1e-8)
+ ((per_token_loss * padding_mask).sum(dim=1))
+ / (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()
return loss
@@ -82,14 +78,14 @@ class Episode:
pad_id: int
request_len: int
response_len: int
- target: Optional[Any] = None
+ target: Any | None = None
# processed data
- response: Optional[str] = None
- request_tokens: Optional[list[int]] = None
- response_tokens: Optional[list[int]] = None
- ref_logprobs: Optional[torch.Tensor] = None
- reward: Optional[float] = None
- advantage: Optional[float] = None
+ response: str | None = None
+ request_tokens: list[int] | None = None
+ response_tokens: list[int] | None = None
+ ref_logprobs: torch.Tensor | None = None
+ reward: float | None = None
+ advantage: float | None = None
@property
def request_tensor(self):
@@ -126,7 +122,7 @@ def new_group(
target: Any = None,
):
episodes = []
- for i in range(group_size):
+ for _ in range(group_size):
episodes.append(
Episode(
episode_id=str(uuid.uuid4()),
@@ -148,17 +144,15 @@ class Trainer(ForgeActor):
model_name: str
learning_rate: float = 1e-5
beta: float = 0.1
- epsilon: float = 0.1
device: torch.device | None = None
+ state_dict_key: str = "model_state_dict"
dp_rank: int = 0 # TODO: support data parallelism, hard code it for now
@endpoint
- def setup(self):
- # Set device
+ async def setup(self):
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # Initialize model
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
dtype=torch.bfloat16,
@@ -166,60 +160,59 @@ def setup(self):
).to(self.device)
self.model.train()
- # Initialize optimizer
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=self.learning_rate
)
self.optimizer.zero_grad()
- # Initialize loss
- self.loss = SimpleGRPOLoss(self.epsilon, self.beta)
+ self.loss = SimpleGRPOLoss(self.beta)
- self.logger.info(f"Model initialized on {self.device}")
+ self.logger.info(f"Trainer model initialized on {self.device}")
@endpoint
- async def train_step(self, batch: list[Episode]):
- batch = batch[self.dp_rank]
- pad_id = batch[0].pad_id
+ async def train_step(self, batch: list[list[Episode]]):
+ microbatch = batch[self.dp_rank]
+ pad_id = microbatch[0].pad_id
# prepare batch
- request = [e.request_tensor for e in batch]
+ request = [e.request_tensor for e in microbatch]
request = torch.stack(request).to(self.device) # [b x s]
- response = [e.response_tensor for e in batch]
+ response = [e.response_tensor for e in microbatch]
response = torch.stack(response).to(self.device) # [b x s]
- ref_logprobs = [e.ref_logprobs for e in batch]
+ ref_logprobs = [e.ref_logprobs for e in microbatch]
ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze() # [b x s]
- advantages = [e.advantage for e in batch]
+ advantages = [e.advantage for e in microbatch]
advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1]
del batch
- # compute policy logprobs
input_ids = torch.cat([request, response], dim=1)
mask = input_ids != pad_id
logits = self.model(input_ids=input_ids, attention_mask=mask).logits
logprobs = compute_logprobs(logits, response)
del logits
- # compute loss
mask = response != pad_id
loss = self.loss(logprobs, ref_logprobs, advantages, mask)
-
- self.optimizer.zero_grad()
loss.backward()
-
- # # Gradient clipping (optional but recommended for stability)
- # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
self.optimizer.step()
+ self.optimizer.zero_grad(set_to_none=True)
- return {"loss": loss.item()}
+ return loss.item()
@endpoint
- async def push_weights(self):
- pass
+ async def push_weights(self, version: int):
+ """Update policy model weights with trainer's current weights."""
+ key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
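+        # NOTE: num_layers=28 assumes the Qwen3-1.7B architecture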
+ new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
+ start_time = time.time()
+ await ts.put_state_dict(new_sd, key)
+ end_time = time.time()
+ self.logger.debug(
+ f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
+ )
@dataclass
@@ -230,11 +223,11 @@ class RewardActor(ForgeActor):
@endpoint
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
- total_reward = 0.0
+ total_rewards = 0.0
for reward_fn in self.reward_functions:
reward = reward_fn(prompt, response, target)
- total_reward += reward
- return total_reward
+ total_rewards += reward
+ return total_rewards / len(self.reward_functions)
class ComputeAdvantages(ForgeActor):
@@ -243,18 +236,11 @@ class ComputeAdvantages(ForgeActor):
@endpoint
async def compute(self, group: Group) -> list[float]:
# TODO: add batch processing
- rewards = torch.Tensor([[e.reward for e in group.episodes]])
+ rewards = torch.tensor([[e.reward for e in group.episodes]])
mean = rewards.mean(1, keepdim=True)
std = rewards.std(1, keepdim=True)
-
- # if std is nan, return 0s. Remove this before shipping
- if std.isnan().any():
- advantages = torch.zeros_like(rewards)
- else:
- advantages = (rewards - mean) / (std + 1e-4)
-
- x = advantages.squeeze(0).tolist()
- return x
+ advantages = (rewards - mean) / (std + 1e-4)
+ return advantages.squeeze(0).tolist()
class RefModel(ForgeActor):
@@ -297,16 +283,24 @@ class DatasetActor(ForgeActor):
revision: str = "main"
data_split: str = "train"
streaming: bool = True
- model: str = "Qwen/Qwen3-1.7B-Base"
+ model: str = "Qwen/Qwen3-1.7B"
@endpoint
def setup(self):
- self.tokenizer = get_tokenizer(self.model)
+ self._tokenizer = get_tokenizer(self.model)
def gsm8k_transform(sample):
+ system_prompt = """
+            Put all your scratchpad work between <think> and </think> tags.
+            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
+ """
request: str = sample["question"]
- formatted_request = self.tokenizer.apply_chat_template(
- [{"role": "user", "content": request}],
+ as_chat = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": request},
+ ]
+ formatted_request = self._tokenizer.apply_chat_template(
+ as_chat,
tokenize=False,
add_generation_prompt=True,
)
@@ -330,7 +324,7 @@ async def sample(self) -> dict[str, str] | None:
@endpoint
async def pad_token(self):
- return self.tokenizer.pad_token_id
+ return self._tokenizer.pad_token_id
async def main(cfg: DictConfig):
@@ -340,15 +334,14 @@ async def main(cfg: DictConfig):
model = cfg.model
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens
-
- # ---- Setup WandB Logger ---- #
- logger = get_metric_logger(
+ mlogger = get_metric_logger(
"wandb",
freq=1,
project="grpo-training",
)
# ---- Setup services ---- #
+ await ts.initialize()
(
dataloader,
policy,
@@ -371,7 +364,6 @@ async def main(cfg: DictConfig):
spawn_service(
ServiceConfig(**cfg.trainer.service),
Trainer,
- model_name=model,
**exclude_service(cfg.trainer),
),
spawn_service(
@@ -407,7 +399,8 @@ async def continuous_rollouts():
print("Dataloader is empty, exiting continuous rollout")
return
prompt, target = sample["request"], sample["target"]
- version = 0 # await policy.get_current_version.choose()
+ responses = await policy.generate.choose(prompt)
+ version = await policy.get_version.choose()
group = Group.new_group(
group_id=rollout_count,
group_size=group_size,
@@ -419,12 +412,11 @@ async def continuous_rollouts():
target=target,
)
- responses = await policy.generate.choose(prompt)
-
+ # TODO: Parallelize the following calculation
for episode, response in zip(group.episodes, responses.outputs):
episode.request_tokens = responses.prompt_token_ids
episode.response_tokens = response.token_ids
- assert len(response.token_ids) <= max_res_tokens
+ episode.response = response.text
episode.ref_logprobs = await ref_model.forward.choose(episode)
episode.reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=response.text, target=target
@@ -434,30 +426,33 @@ async def continuous_rollouts():
episode.advantage = advantage
await replay_buffer.add.choose(episode)
+ avg_response_len = (
+ sum(len(e.response_tokens) for e in group.episodes) / group_size
+ )
+ mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
+ buffer_size = await replay_buffer._numel.choose()
+ mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
+ avg_reward = sum(e.reward for e in group.episodes) / group_size
+ mlogger.log("avg_reward/rollout", avg_reward, rollout_count)
+
rollout_count += 1
- if rollout_count % 10 == 0:
- avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes)
- print(
- f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
- )
- logger.log("reward_per_rollout", avg_reward, rollout_count)
async def continuous_training():
training_step = 0
+ policy_version = 0
while True:
- batch = await replay_buffer.sample.choose(curr_policy_version=0)
+ batch = await replay_buffer.sample.choose(
+ curr_policy_version=policy_version
+ )
if batch is None:
await asyncio.sleep(0.1)
else:
- training_result = await trainer.train_step.choose(batch)
+ loss = await trainer.train_step.choose(batch)
training_step += 1
- if training_step % 10 == 0:
- print(f"Completed {training_step} training steps")
- if training_result:
- loss_value = training_result.get("loss", 0.0)
- print(f"Latest loss: {loss_value}")
- logger.log("loss/training_step", loss_value, training_step)
- # await trainer.update_weights(policy)
+ mlogger.log("loss/training_step", loss, training_step)
+ await trainer.push_weights.call(policy_version)
+ policy_version += 1
+ await policy.update_weights.call()
print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all serivces support it
@@ -483,10 +478,10 @@ async def continuous_training():
)
-@parse
-def recipe_main(cfg: DictConfig) -> None:
- asyncio.run(main(cfg))
+if __name__ == "__main__":
+ @parse
+ def _main(cfg):
+ asyncio.run(main(cfg))
-if __name__ == "__main__":
- recipe_main()
+ _main() # @parse grabs the cfg from CLI
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index 8ba96a096..6fc60bf53 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -1,11 +1,11 @@
# GRPO Training Configuration
# Global configuration
-group_size: 4
-batch_size: 4
+group_size: 8
+batch_size: 16
max_req_tokens: 512
-max_res_tokens: 128
-model: "Qwen/Qwen3-1.7B-Base"
+max_res_tokens: 512
+model: "Qwen/Qwen3-1.7B"
# Dataset configuration
dataset:
@@ -13,6 +13,7 @@ dataset:
revision: "main"
data_split: "train"
streaming: true
+ model: ${model}
service:
procs_per_replica: 1
num_replicas: 1
@@ -24,10 +25,10 @@ policy:
model: ${model}
tensor_parallel_size: 1
pipeline_parallel_size: 1
- enforce_eager: true
+ enforce_eager: false
sampling_config:
- n: 4
- max_tokens: 128
+ n: ${group_size}
+ max_tokens: ${max_res_tokens}
temperature: 1.0
top_p: 1.0
service:
@@ -37,6 +38,7 @@ policy:
# Trainer configuration
trainer:
+ model_name: ${model}
learning_rate: 1e-5
service:
procs_per_replica: 1
@@ -46,7 +48,7 @@ trainer:
# Replay buffer configuration
replay_buffer:
batch_size: ${batch_size}
- max_policy_age: 0
+ max_policy_age: 1 # Async by 1
dp_size: 1
service:
procs_per_replica: 1
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 972fda18c..fd9f11482 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -5,24 +5,23 @@
# LICENSE file in the root directory of this source tree.
import asyncio
-import logging
import os
import sys
+import time
from collections.abc import Mapping
from copy import copy
from dataclasses import asdict, dataclass, field, fields
-from typing import Dict, List
import torch
+import torchstore as ts
from monarch.actor import current_rank, endpoint, ProcMesh
-from torchstore import MultiProcessStore
-from torchstore._state_dict_utils import DELIM
+from torchstore.state_dict_utils import DELIM
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
from vllm.lora.request import LoRARequest
-from vllm.outputs import CompletionOutput
+from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
@@ -45,9 +44,6 @@
from forge.types import ProcessConfig
-logger = logging.getLogger(__name__)
-
-
@dataclass
class SamplingConfig:
"""
@@ -111,13 +107,13 @@ class Policy(PolicyInterface):
lora_request: LoRARequest | None = None
tokenization_kwargs: dict = field(default_factory=dict)
policy_worker: "PolicyWorker" = None
- store: MultiProcessStore | None = None
def __post_init__(self):
self._run_task: asyncio.Task | None = None
self._policy_proc: ProcMesh | None = None
self._worker_procs: ProcMesh | None = None
self.weights_version: int = 0
+ self.running = False
if isinstance(self.engine_config, Mapping):
self.engine_config = EngineConfig.from_dict(self.engine_config)
if isinstance(self.sampling_config, Mapping):
@@ -131,7 +127,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
engine_config: EngineConfig | Mapping = EngineConfig(),
sampling_config: SamplingConfig | Mapping = SamplingConfig(),
available_devices: str | None = None,
- store: MultiProcessStore | None = None,
**kwargs,
) -> "Policy":
# Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
@@ -164,7 +159,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
sampling_config=sampling_config,
available_devices=available_devices,
policy_worker=workers,
- store=store,
)
policy._policy_proc = policy_proc
policy._worker_procs = worker_procs
@@ -192,10 +186,10 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
async def setup(self):
# Set up policy_worker
assert self.policy_worker is not None, "Policy worker should not be None"
- await self.policy_worker.setup.call(store=self.store)
+ await self.policy_worker.setup.call()
self.request_id = 0
- self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
+ self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
self.vllm_args = await self.policy_worker.get_vllm_args.choose()
# Setup sampling params
@@ -239,12 +233,21 @@ def start_processing(self):
self._run_task = asyncio.create_task(self.run())
@endpoint
- async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
+ async def generate(self, prompt: str, priority: int = 0) -> RequestOutput:
+ """Generate a response for the given prompt
+
+ Args:
+ prompt (str): The prompt to generate a response for.
+ priority (int, optional): The priority of the request. Defaults to 0.
+
+ Returns:
+ RequestOutput: vLLM class with the generated response.
+ """
self.request_id += 1 % sys.maxsize
request_id = str(self.request_id) # implement from a counter
# Wraps prompt into a dict
- prompt: Dict[str, str] = convert_input(prompt=prompt)
+ prompt_dict: dict[str, str] = convert_input(prompt=prompt)
# truncate prmpt
tokenization_kwargs = self.tokenization_kwargs or {}
@@ -259,7 +262,7 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu
# process and tokenize prompt
prompt_str, request = self.processor.process_inputs(
request_id=request_id,
- prompt=prompt,
+ prompt=prompt_dict,
params=self.sampling_params,
arrival_time=None,
lora_request=self.lora_request,
@@ -331,25 +334,27 @@ async def run(self):
engine_core_timestamp=outputs.timestamp,
iteration_stats=None,
)
+
for request_output in processed_outputs.request_outputs:
if request_output.finished:
_, fut = self.requests.pop(request_output.request_id)
fut.set_result(request_output)
@endpoint
- async def update_weights(self) -> int:
- """Update the policy weights."""
- # Wait for all current requests to finish, then publish model weights
- futures = [fut for _, fut in self.requests.values()]
- if futures:
- await asyncio.gather(*futures)
- new_version = self.weights_version + 1
- await self.policy_worker.update.call(version=new_version)
- self.weights_version = new_version
- return self.weights_version
+ async def update_weights(self):
+ # TODO: If generating long sequences, this might be long and will block policy weight updates
+ curr_requests = [fut for _, fut in self.requests.values()]
+ if curr_requests:
+ self.logger.debug(f"Waiting for {len(curr_requests)} pending requests")
+ await asyncio.gather(*curr_requests)
+
+ self.logger.debug(f"Starting weight update on {self.__class__.__name__}")
+ await self.policy_worker.update.call(version=self.weights_version)
+ self.weights_version += 1
+ self.logger.info(f"Weight update completed (now v{self.weights_version})")
@endpoint
- async def _get_model_params(self) -> Dict[str, torch.Tensor]:
+ async def _get_model_params(self) -> dict[str, torch.Tensor]:
"""Get the current model parameters. Only for testing purposes."""
model_params = await self.policy_worker._get_model_params.choose()
return model_params
@@ -391,8 +396,7 @@ def __post_init__(self):
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
@endpoint
- async def setup(self, store: MultiProcessStore = None):
- self.torchstore = store
+ async def setup(self):
# TODO: remove ["gpus"] when monarch implements a flat rank
self.rank = current_rank()["gpus"]
self.worker = self.setup_worker()
@@ -407,9 +411,6 @@ async def _load_tensor_parallel_state_dict(
"""
Load full state dict from torchstore into tensor parallel model with deterministic sharding.
"""
-
- updated_count = 0
- # setting explictly to llama3 for now as its our only use case
sharding = VLLMSharding(
self.vllm_args.parallel_config.tensor_parallel_size, self.rank
)
@@ -419,7 +420,7 @@ async def _load_tensor_parallel_state_dict(
# Load the full tensor from torchstore
# TODO: only get the part of the tensor that is needed
- stored_tensor = await self.torchstore.get(
+ stored_tensor = await ts.get(
f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
)
sharding.load_from_source_to_target(
@@ -428,23 +429,17 @@ async def _load_tensor_parallel_state_dict(
current_tensor,
)
- updated_count += 1
-
@endpoint
async def update(self, version: int):
"""Update model weights by reading state dict from torchstore"""
- if self.torchstore is None:
- raise Exception("No torchstore configured, skipping model update")
-
- logger.debug(
- f"Starting model update from torchstore with key: {self.state_dict_key}{DELIM}{version}"
- )
-
+ key = f"{self.state_dict_key}{DELIM}{version}"
model = self.worker.model_runner.model
current_state_dict = model.state_dict()
-
+ start = time.time()
await self._load_tensor_parallel_state_dict(current_state_dict, version)
- logger.debug("Successfully updated model weights from torchstore")
+ self.logger.debug(
+ f"Loaded state dict from {key} in {time.time() - start} seconds"
+ )
@endpoint
async def setup_kv_cache(self):
@@ -481,7 +476,7 @@ async def get_vllm_args(self):
return self.vllm_args
@endpoint
- async def _get_model_params(self) -> Dict[str, torch.Tensor]:
+ async def _get_model_params(self) -> dict[str, torch.Tensor]:
model = self.worker.model_runner.model
state_dict = {}
@@ -514,7 +509,7 @@ def setup_worker(self):
return worker
-def convert_input(prompt=None, prompt_token_ids=None) -> Dict:
+def convert_input(prompt=None, prompt_token_ids=None) -> dict:
assert (prompt is None) ^ (prompt_token_ids is None)
if prompt is not None:
return {"prompt": prompt}
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 4232ca5ca..062fabe8a 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -268,3 +268,49 @@ def push_weights(self) -> None:
async def cleanup(self) -> None:
if self.engine.checkpointer:
self.engine.checkpointer.close()
+
+
+def _qwen3_hf_to_vllm(
+ sd: dict[str, torch.Tensor], num_layers: int
+) -> dict[str, torch.Tensor]:
+ """Convert transformers state dict to vLLM format. Specifically, this fuses
+ QKV projection and MLP gate_up_proj layers.
+
+ Args:
+ sd (dict): State dict from HF model.
+ num_layers (int): Number of layers in the model.
+
+ Returns:
+ dict: State dict in vLLM format.
+ """
+ load_sd = {}
+
+ # Copy over directly mapped keys
+ for k in sd:
+ if any(
+ x in k
+ for x in [
+ "down_proj",
+ "input_layernorm",
+ "post_attention_layernorm",
+ "o_proj",
+ "norm.weight",
+ "embed_tokens.weight",
+ "lm_head.weight",
+ ]
+ ):
+ load_sd[k] = sd[k]
+
+ for i in range(num_layers):
+ prefix = f"model.layers.{i}."
+ # QKV fusion
+ q = sd[prefix + "self_attn.q_proj.weight"]
+ k = sd[prefix + "self_attn.k_proj.weight"]
+ v = sd[prefix + "self_attn.v_proj.weight"]
+ load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
+ # MLP gate_up_proj fusion
+ gate = sd[prefix + "mlp.gate_proj.weight"]
+ up = sd[prefix + "mlp.up_proj.weight"]
+ load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
+
+ return load_sd
diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py
index 644c69d1b..29a86fc3a 100644
--- a/src/forge/data/rewards.py
+++ b/src/forge/data/rewards.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import re
-from typing import Optional
from forge.interfaces import Reward
@@ -17,62 +16,69 @@ def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1):
self.tolerance = tolerance
self.partial_credit = partial_credit
- def _to_float(self, text) -> Optional[float]:
- """Safely parse a string into a float, or return None if invalid."""
- if text is None:
- return None
- try:
- return float(str(text).strip())
- except (ValueError, TypeError):
- return None
-
- def _extract_number(self, text: str) -> Optional[float]:
- """Try to extract a numeric answer from text."""
- number_pattern = r"([+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)"
- patterns = [
- r"####\s*" + number_pattern,
- r"(?:the\s+)?answer\s+is\s*" + number_pattern,
- r"(?:answer:|result:)\s*" + number_pattern,
- r"\$" + number_pattern, # currency
- number_pattern, # fallback
- r"=\s*" + number_pattern + r"\s*(?:\.|$)",
- r"\b" + number_pattern + r"\s*(?:\.|$)",
- ]
- text = text.lower().strip()
- for pattern in patterns:
- matches = re.findall(pattern, text)
- if matches:
- return self._to_float(matches[-1])
- return None
-
def __call__(self, prompt: str, response: str, target: str) -> float:
"""Compute math correctness reward."""
- # Parse expected
- expected_answer = self._to_float(target)
+ target_number = self._to_float(target)
+ if target_number is None:
+ return 0.0
- # Parse response
- model_answer = self._extract_number(response)
+        # Look for answer in <answer> tags
+        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
- # Scoring
- if expected_answer is None or model_answer is None:
- return self.partial_credit # Partial credit for attempting
+ if answer_match:
+ model_answer = self._to_float(answer_match.group(1).strip())
+ if (
+ model_answer is not None
+ and abs(target_number - model_answer) < self.tolerance
+ ):
+ return 1.0 # Correct answer
- if abs(expected_answer - model_answer) < self.tolerance:
- return 1.0 # Correct answer
- return 0.0 # Incorrect answer
+ # Check for partial credit: target number appears elsewhere in response
+ response_without_answer_tags = re.sub(
+            r"<answer>.*?</answer>", "", response, flags=re.DOTALL
+ )
+ # Convert to int if it's a whole number to avoid "117.0" vs "117" mismatch
+ target_str = (
+ str(int(target_number))
+ if target_number.is_integer()
+ else str(target_number)
+ )
+ if target_str in response_without_answer_tags:
+ return self.partial_credit
+
+ return 0.0 # No match
+
+ def _to_float(self, text: str) -> float | None:
+ """Convert text to float, return None if invalid."""
+ try:
+ # Remove common non-numeric characters like $, commas, etc.
+ cleaned_text = re.sub(r"[$,\s]", "", text.strip())
+ return float(cleaned_text)
+ except (ValueError, AttributeError):
+ return None
 class ThinkingReward(Reward):
     """Reward class for evaluating use of <think> tags in reasoning."""
- def __init__(self, reward_value: float = 0.5):
- self.reward_value = reward_value
+ def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
+ self.partial_reward = partial_reward
+ self.full_reward = full_reward
+ self._THINK_BLOCK_RE = re.compile(
+ r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL
+ )
+ self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE)
+
+ def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
+ """Compute thinking reward."""
+ if not response:
+ return 0.0
-    def __call__(
-        self, prompt: str, response: str, target: Optional[str] = None
-    ) -> float:
-        """Check if response contains <think>...</think> tags."""
- resp = response.lower()
-        if "<think>" in resp and "</think>" in resp:
- return self.reward_value
+ matches = self._THINK_BLOCK_RE.findall(response)
+ has_well_formed = any(len(re.sub(r"\s+", "", m)) >= 1 for m in matches)
+ has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response)) or bool(matches)
+ if has_well_formed:
+ return self.full_reward
+ elif has_attempt:
+ return self.partial_reward
return 0.0
diff --git a/tests/unit_tests/rl/test_math_reward.py b/tests/unit_tests/rl/test_math_reward.py
index 2f3521b4d..726b1173c 100644
--- a/tests/unit_tests/rl/test_math_reward.py
+++ b/tests/unit_tests/rl/test_math_reward.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import unittest
-from unittest import mock
from forge.data.rewards import MathReward
@@ -36,6 +35,13 @@ def test_to_float_valid_numbers(self):
self.assertEqual(self.reward._to_float("0"), 0.0)
self.assertEqual(self.reward._to_float(" 123.45 "), 123.45)
+ def test_to_float_with_currency_and_formatting(self):
+ """Test _to_float with currency symbols and commas."""
+ self.assertEqual(self.reward._to_float("$42"), 42.0)
+ self.assertEqual(self.reward._to_float("$1,000"), 1000.0)
+ self.assertEqual(self.reward._to_float("1,234.56"), 1234.56)
+ self.assertEqual(self.reward._to_float("$ 42.50 "), 42.5)
+
def test_to_float_invalid_inputs(self):
"""Test _to_float with invalid inputs."""
self.assertIsNone(self.reward._to_float("abc"))
@@ -48,154 +54,146 @@ def test_to_float_edge_cases(self):
"""Test _to_float with edge cases."""
self.assertEqual(self.reward._to_float("1e6"), 1000000.0)
self.assertEqual(self.reward._to_float("-1.5e-3"), -0.0015)
- self.assertEqual(self.reward._to_float("inf"), float("inf"))
- self.assertEqual(self.reward._to_float("-inf"), float("-inf"))
-
- def test_extract_number_gsm8k_format(self):
- """Test _extract_number with GSM8K style format."""
- self.assertEqual(self.reward._extract_number("#### 42"), 42.0)
- self.assertEqual(self.reward._extract_number("#### -3.14"), -3.14)
- self.assertEqual(self.reward._extract_number("Some text #### 123.45"), 123.45)
-
- def test_extract_number_answer_patterns(self):
- """Test _extract_number with various answer patterns."""
- self.assertEqual(self.reward._extract_number("The answer is 42"), 42.0)
- self.assertEqual(self.reward._extract_number("answer is 3.14"), 3.14)
- self.assertEqual(self.reward._extract_number("Answer: 123"), 123.0)
- self.assertEqual(self.reward._extract_number("Result: -5.5"), -5.5)
-
- def test_extract_number_equals_pattern(self):
- """Test _extract_number with equals sign patterns."""
- self.assertEqual(self.reward._extract_number("x = 42."), 42.0)
- self.assertEqual(self.reward._extract_number("The result = 3.14"), 3.14)
- self.assertEqual(self.reward._extract_number("calculation = -7.5."), -7.5)
-
- def test_extract_number_end_of_text(self):
- """Test _extract_number with numbers at end of text."""
- self.assertEqual(self.reward._extract_number("The final result is 42."), 42.0)
- self.assertEqual(self.reward._extract_number("We get 3.14"), 3.14)
- self.assertEqual(self.reward._extract_number("Answer: -5.5."), -5.5)
-
- def test_extract_number_fallback_pattern(self):
- """Test _extract_number with fallback pattern (any number)."""
- self.assertEqual(self.reward._extract_number("There are 42 items"), 42.0)
- self.assertEqual(self.reward._extract_number("Cost is $3.14 per item"), 3.14)
- self.assertEqual(self.reward._extract_number("Temperature: -5.5 degrees"), -5.5)
-
- def test_extract_number_multiple_matches(self):
- """Test _extract_number returns the last match when multiple numbers exist."""
- # Should return the last match from the pattern
- self.assertEqual(
- self.reward._extract_number("First 10, then 20, finally 30"), 30.0
- )
- self.assertEqual(
- self.reward._extract_number("#### 5 but actually #### 10"), 10.0
- )
- def test_extract_number_no_match(self):
- """Test _extract_number when no numbers are found."""
- self.assertIsNone(self.reward._extract_number("No numbers here"))
- self.assertIsNone(self.reward._extract_number(""))
- self.assertIsNone(self.reward._extract_number("Just text"))
+ def test_call_correct_answer_in_tags(self):
+ """Test __call__ with correct answers in tags."""
+        self.assertEqual(self.reward("prompt", "<answer>42</answer>", "42"), 1.0)
+        self.assertEqual(self.reward("prompt", "<answer>3.14</answer>", "3.14"), 1.0)
+        self.assertEqual(self.reward("prompt", "<answer>-5.5</answer>", "-5.5"), 1.0)
- def test_extract_number_case_insensitive(self):
- """Test _extract_number is case insensitive."""
- self.assertEqual(self.reward._extract_number("THE ANSWER IS 42"), 42.0)
- self.assertEqual(self.reward._extract_number("Answer: 3.14"), 3.14)
- self.assertEqual(self.reward._extract_number("RESULT: 123"), 123.0)
+ def test_call_answer_tags_with_whitespace(self):
+ """Test __call__ with answer tags containing whitespace."""
+        self.assertEqual(self.reward("prompt", "<answer> 42 </answer>", "42"), 1.0)
+        self.assertEqual(
+            self.reward("prompt", "<answer>\n3.14\n</answer>", "3.14"), 1.0
+ )
- def test_call_correct_answer(self):
- """Test __call__ with correct answers."""
- self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 1.0)
- self.assertEqual(self.reward("prompt", "#### 3.14", "3.14"), 1.0)
- self.assertEqual(self.reward("prompt", "Result: -5.5", "-5.5"), 1.0)
+ def test_call_answer_tags_with_complex_content(self):
+ """Test __call__ with complex content in answer tags."""
+ response = """
+ Let me solve this step by step:
+ First, I calculate 2 + 3 = 5
+ Then, I multiply by 4: 5 * 4 = 20
+ Finally, I subtract 8: 20 - 8 = 12
+        <answer>12</answer>
+ """
+ self.assertEqual(self.reward("prompt", response, "12"), 1.0)
def test_call_within_tolerance(self):
"""Test __call__ with answers within tolerance."""
# Default tolerance is 1e-6
- self.assertEqual(self.reward("prompt", "42.0000001", "42"), 1.0)
- self.assertEqual(self.reward("prompt", "3.1400001", "3.14"), 1.0)
-
- # Custom tolerance
- self.assertEqual(self.custom_reward("prompt", "42.0001", "42"), 1.0)
- self.assertEqual(self.custom_reward("prompt", "3.141", "3.14"), 1.0)
-
- def test_call_outside_tolerance(self):
- """Test __call__ with answers outside tolerance."""
- self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0)
- self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0)
- self.assertEqual(self.custom_reward("prompt", "42.01", "42"), 0.0)
-
- def test_call_invalid_target(self):
- """Test __call__ with invalid target values."""
self.assertEqual(
- self.reward("prompt", "42", "invalid"), self.reward.partial_credit
+            self.reward("prompt", "<answer>42.0000001</answer>", "42"), 1.0
)
- self.assertEqual(self.reward("prompt", "42", ""), self.reward.partial_credit)
self.assertEqual(
- self.reward("prompt", "42", "not a number"), self.reward.partial_credit
+            self.reward("prompt", "<answer>3.1400001</answer>", "3.14"), 1.0
)
- def test_call_invalid_response(self):
- """Test __call__ with invalid response values."""
+ # Custom tolerance
self.assertEqual(
- self.reward("prompt", "no number", "42"), self.reward.partial_credit
+            self.custom_reward("prompt", "<answer>42.0001</answer>", "42"), 1.0
)
- self.assertEqual(self.reward("prompt", "", "42"), self.reward.partial_credit)
self.assertEqual(
- self.reward("prompt", "just text", "42"), self.reward.partial_credit
+            self.custom_reward("prompt", "<answer>3.141</answer>", "3.14"), 1.0
+ )
+
+ def test_call_outside_tolerance(self):
+ """Test __call__ with answers outside tolerance."""
+        self.assertEqual(self.reward("prompt", "<answer>42.1</answer>", "42"), 0.0)
+        self.assertEqual(self.reward("prompt", "<answer>3.15</answer>", "3.14"), 0.0)
+        self.assertEqual(
+            self.custom_reward("prompt", "<answer>42.01</answer>", "42"), 0.0
)
- def test_call_both_invalid(self):
- """Test __call__ with both invalid target and response."""
+ def test_call_partial_credit_target_in_response(self):
+ """Test __call__ with partial credit when target appears in response."""
+        response = "The calculation shows 42 but I put <answer>43</answer>"
+ self.assertEqual(self.reward("prompt", response, "42"), 0.1)
+
+        response = "Let me work through this: 42 + 1 = 43. <answer>43</answer>"
+ self.assertEqual(self.reward("prompt", response, "42"), 0.1)
+
+ def test_call_partial_credit_custom_value(self):
+ """Test __call__ with custom partial credit value."""
+        response = "The calculation shows 42 but I put <answer>43</answer>"
+ self.assertEqual(self.custom_reward("prompt", response, "42"), 0.2)
+
+ def test_call_no_partial_credit_with_answer_tags(self):
+ """Test __call__ doesn't give partial credit if target is only in answer tags."""
+        response = "Let me solve this. <answer>42</answer>"
+ # Target 100 is not elsewhere in response, so no partial credit
+ self.assertEqual(self.reward("prompt", response, "100"), 0.0)
+
+ def test_call_integer_target_formatting(self):
+ """Test __call__ with integer targets formatted correctly."""
+ # Integer targets should be formatted without decimal point
+        response = "I calculated and got 117 as the answer. <answer>118</answer>"
+ self.assertEqual(self.reward("prompt", response, "117"), 0.1)
+
+ # Should work with 117.0 in target too
+ self.assertEqual(self.reward("prompt", response, "117.0"), 0.1)
+
+ def test_call_float_target_formatting(self):
+ """Test __call__ with float targets."""
+        response = "I calculated and got 3.14 as the answer. <answer>3.15</answer>"
+ self.assertEqual(self.reward("prompt", response, "3.14"), 0.1)
+
+ def test_call_invalid_target(self):
+ """Test __call__ with invalid target values."""
+        self.assertEqual(self.reward("prompt", "<answer>42</answer>", "invalid"), 0.0)
+        self.assertEqual(self.reward("prompt", "<answer>42</answer>", ""), 0.0)
self.assertEqual(
- self.reward("prompt", "no number", "invalid"), self.reward.partial_credit
+            self.reward("prompt", "<answer>42</answer>", "not a number"), 0.0
)
- self.assertEqual(self.reward("prompt", "", ""), self.reward.partial_credit)
- def test_call_custom_partial_credit(self):
- """Test __call__ uses custom partial credit value."""
- self.assertEqual(self.custom_reward("prompt", "no number", "42"), 0.2)
- self.assertEqual(self.custom_reward("prompt", "42", "invalid"), 0.2)
+ def test_call_no_answer_tags(self):
+ """Test __call__ with response that has no answer tags."""
+ # Should still check for partial credit
+ self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 0.1)
+ self.assertEqual(self.reward("prompt", "No matching number", "42"), 0.0)
+
+ def test_call_invalid_answer_in_tags(self):
+ """Test __call__ with invalid answer in tags."""
+        response = "<answer>not a number</answer> but 42 is correct"
+ self.assertEqual(self.reward("prompt", response, "42"), 0.1)
def test_call_zero_values(self):
"""Test __call__ with zero values."""
- self.assertEqual(self.reward("prompt", "0", "0"), 1.0)
- self.assertEqual(self.reward("prompt", "The answer is 0", "0.0"), 1.0)
+        self.assertEqual(self.reward("prompt", "<answer>0</answer>", "0"), 1.0)
+        self.assertEqual(self.reward("prompt", "<answer>0.0</answer>", "0"), 1.0)
def test_call_negative_values(self):
"""Test __call__ with negative values."""
- self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0)
- self.assertEqual(self.reward("prompt", "#### -3.14", "-3.14"), 1.0)
- self.assertEqual(self.reward("prompt", "-5", "-4.9"), 0.0)
+        self.assertEqual(self.reward("prompt", "<answer>-42</answer>", "-42"), 1.0)
+        self.assertEqual(self.reward("prompt", "<answer>-3.14</answer>", "-3.14"), 1.0)
def test_call_large_numbers(self):
"""Test __call__ with large numbers."""
- self.assertEqual(self.reward("prompt", "1000000", "1000000"), 1.0)
- self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0)
- self.assertEqual(self.reward("prompt", "1000001", "1000000"), 0.0)
+ self.assertEqual(
+            self.reward("prompt", "<answer>1000000</answer>", "1000000"), 1.0
+ )
+        self.assertEqual(self.reward("prompt", "<answer>1e6</answer>", "1000000"), 1.0)
def test_call_small_numbers(self):
"""Test __call__ with very small numbers."""
- self.assertEqual(self.reward("prompt", "0.000001", "0.000001"), 1.0)
- self.assertEqual(self.reward("prompt", "1e-6", "0.000001"), 1.0)
+ self.assertEqual(
+            self.reward("prompt", "<answer>0.000001</answer>", "0.000001"), 1.0
+ )
+ self.assertEqual(
+            self.reward("prompt", "<answer>1e-6</answer>", "0.000001"), 1.0
+ )
- def test_call_complex_response_text(self):
- """Test __call__ with complex response text containing multiple elements."""
- response = """
- Let me solve this step by step:
- First, I calculate 2 + 3 = 5
- Then, I multiply by 4: 5 * 4 = 20
- Finally, I subtract 8: 20 - 8 = 12
- #### 12
- """
- self.assertEqual(self.reward("prompt", response, "12"), 1.0)
+ def test_call_multiple_answer_tags(self):
+ """Test __call__ with multiple answer tags (should use first one)."""
+        response = "First answer: <answer>42</answer> Second: <answer>43</answer>"
+ self.assertEqual(self.reward("prompt", response, "42"), 1.0)
+ self.assertEqual(self.reward("prompt", response, "43"), 0.0)
- def test_call_with_units_and_formatting(self):
- """Test __call__ with responses containing units and formatting."""
- self.assertEqual(self.reward("prompt", "The cost is $42.50", "42.5"), 1.0)
- self.assertEqual(self.reward("prompt", "Distance: 3.14 meters", "3.14"), 1.0)
- self.assertEqual(self.reward("prompt", "Temperature is -5.5°C", "-5.5"), 1.0)
+ # Test case where target appears outside answer tags for partial credit
+        response_with_partial = (
+            "I think the answer is 43. <answer>42</answer> But 43 might be better."
+        )
+ self.assertEqual(self.reward("prompt", response_with_partial, "43"), 0.1)
if __name__ == "__main__":
diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/rl/test_thinking_reward.py
index 592ceb896..b95823e9a 100644
--- a/tests/unit_tests/rl/test_thinking_reward.py
+++ b/tests/unit_tests/rl/test_thinking_reward.py
@@ -13,29 +13,36 @@ class TestThinkingReward(unittest.TestCase):
def setUp(self):
"""Set up test fixtures before each test method."""
self.reward = ThinkingReward()
- self.custom_reward = ThinkingReward(reward_value=0.8)
+ self.custom_reward = ThinkingReward(partial_reward=0.3, full_reward=0.9)
def test_init_default_values(self):
"""Test ThinkingReward initialization with default values."""
reward = ThinkingReward()
- self.assertEqual(reward.reward_value, 0.5)
+ self.assertEqual(reward.partial_reward, 0.2)
+ self.assertEqual(reward.full_reward, 1.0)
def test_init_custom_values(self):
"""Test ThinkingReward initialization with custom values."""
- reward = ThinkingReward(reward_value=0.8)
- self.assertEqual(reward.reward_value, 0.8)
+ reward = ThinkingReward(partial_reward=0.3, full_reward=0.9)
+ self.assertEqual(reward.partial_reward, 0.3)
+ self.assertEqual(reward.full_reward, 0.9)
-    def test_call_with_both_tags(self):
-        """Test __call__ with response containing both <think> and </think> tags."""
-        response = "<think>This is my reasoning</think>"
- result = self.reward("prompt", response)
- self.assertEqual(result, 0.5)
+ def test_regex_patterns(self):
+ """Test that regex patterns are compiled correctly."""
+ reward = ThinkingReward()
+ self.assertIsNotNone(reward._THINK_BLOCK_RE)
+ self.assertIsNotNone(reward._THINK_TAG_ATTEMPT_RE)
+
+ def test_call_with_well_formed_thinking_block(self):
+ """Test __call__ with well-formed thinking blocks."""
+        result = self.reward("prompt", "<think>This is my reasoning</think>")
+ self.assertEqual(result, 1.0)
- result = self.custom_reward("prompt", response)
- self.assertEqual(result, 0.8)
+        result = self.custom_reward("prompt", "<think>This is my reasoning</think>")
+ self.assertEqual(result, 0.9)
- def test_call_with_both_tags_complex_content(self):
- """Test __call__ with complex content between thinking tags."""
+ def test_call_with_well_formed_thinking_block_complex_content(self):
+ """Test __call__ with complex content in thinking blocks."""
response = """
Let me solve this problem step by step.
@@ -47,40 +54,58 @@ def test_call_with_both_tags_complex_content(self):
The answer is 4.
"""
result = self.reward("prompt", response)
- self.assertEqual(result, 0.5)
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_minimal_content_thinking_block(self):
+ """Test __call__ with minimal content that still counts as well-formed."""
+        result = self.reward("prompt", "<think>x</think>")
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_empty_thinking_block(self):
+ """Test __call__ with empty thinking block."""
+        result = self.reward("prompt", "<think></think>")
+ self.assertEqual(result, 0.2) # Should give partial reward, not full
+
+ def test_call_with_whitespace_only_thinking_block(self):
+ """Test __call__ with whitespace-only thinking block."""
+        result = self.reward("prompt", "<think>  \n  \t  </think>")
+ self.assertEqual(result, 0.2) # Should give partial reward, not full
def test_call_with_only_opening_tag(self):
-        """Test __call__ with response containing only <think> tag."""
-        response = "<think>This is incomplete reasoning"
- result = self.reward("prompt", response)
- self.assertEqual(result, 0.0)
+ """Test __call__ with response containing only opening tag."""
+        result = self.reward("prompt", "<think>This is incomplete reasoning")
+ self.assertEqual(result, 0.2) # Should give partial reward for attempt
def test_call_with_only_closing_tag(self):
-        """Test __call__ with response containing only </think> tag."""
-        response = "This is incomplete reasoning</think>"
- result = self.reward("prompt", response)
- self.assertEqual(result, 0.0)
+ """Test __call__ with response containing only closing tag."""
+        result = self.reward("prompt", "This is incomplete reasoning</think>")
+ self.assertEqual(result, 0.2) # Should give partial reward for attempt
def test_call_with_no_tags(self):
"""Test __call__ with response containing no thinking tags."""
- response = "This is just a regular response without any thinking tags."
- result = self.reward("prompt", response)
+ result = self.reward(
+ "prompt", "This is just a regular response without any thinking tags."
+ )
self.assertEqual(result, 0.0)
def test_call_case_insensitive(self):
"""Test __call__ is case insensitive for thinking tags."""
- # Mixed case tags should work
-        response = "<THINK>This is my reasoning</THINK>"
- result = self.reward("prompt", response)
- self.assertEqual(result, 0.5)
+        result = self.reward("prompt", "<THINK>This is my reasoning</THINK>")
+ self.assertEqual(result, 1.0)
-        response = "<Think>This is my reasoning</Think>"
- result = self.reward("prompt", response)
- self.assertEqual(result, 0.5)
+        result = self.reward("prompt", "<Think>This is my reasoning</Think>")
+ self.assertEqual(result, 1.0)
-        response = "<think>This is my reasoning</THINK>"
- result = self.reward("prompt", response)
- self.assertEqual(result, 0.5)
+        result = self.reward("prompt", "<think>This is my reasoning</THINK>")
+ self.assertEqual(result, 1.0)
+
+ def test_call_with_whitespace_in_tags(self):
+ """Test __call__ with whitespace in thinking tags."""
+        result = self.reward("prompt", "< think >This is my reasoning</ think >")
+ self.assertEqual(result, 1.0)
+
+        result = self.reward("prompt", "<\tthink\n>Content</\tthink\n>")
+ self.assertEqual(result, 1.0)
def test_call_multiple_thinking_blocks(self):
"""Test __call__ with multiple thinking blocks."""
@@ -90,54 +115,93 @@ def test_call_multiple_thinking_blocks(self):
         <think>Second thought</think>
"""
result = self.reward("prompt", response)
- self.assertEqual(result, 0.5)
+ self.assertEqual(result, 1.0)
def test_call_nested_tags(self):
"""Test __call__ with nested or malformed tags."""
- # Nested tags - should still work as long as both tags exist
-        response = "<think>Outer <think>inner</think> thought</think>"
+        result = self.reward(
+            "prompt", "<think>Outer <think>inner</think> thought</think>"
+ )
+ self.assertEqual(result, 1.0)
+
+ def test_call_multiline_thinking_block(self):
+ """Test __call__ with multiline thinking blocks."""
+ response = """
+        <think>This is a multiline
+        thinking block with
+        lots of content</think>
+ """
result = self.reward("prompt", response)
- self.assertEqual(result, 0.5)
-
- def test_call_empty_thinking_block(self):
- """Test __call__ with empty thinking block."""
-        response = "<think></think>"
- result = self.reward("prompt", response)
- self.assertEqual(result, 0.5)
+ self.assertEqual(result, 1.0)
def test_call_empty_response(self):
"""Test __call__ with empty response."""
result = self.reward("prompt", "")
self.assertEqual(result, 0.0)
- def test_call_tags_with_extra_whitespace(self):
- """Test __call__ with thinking tags containing extra whitespace."""
- response = "< think >This has spaces< /think >"
- result = self.reward("prompt", response)
- self.assertEqual(result, 0.0) # Should not match due to spaces in tags
+ def test_call_none_response(self):
+ """Test __call__ with None response."""
+ result = self.reward("prompt", None)
+ self.assertEqual(result, 0.0)
def test_call_with_target_parameter(self):
"""Test __call__ with target parameter (should be ignored)."""
-        response = "<think>This is my reasoning</think>"
- result = self.reward("prompt", response, target="some target")
- self.assertEqual(result, 0.5)
+        result = self.reward(
+            "prompt", "<think>This is my reasoning</think>", target="some target"
+ )
+ self.assertEqual(result, 1.0)
result = self.reward("prompt", "no tags", target="some target")
self.assertEqual(result, 0.0)
- def test_call_zero_reward_value(self):
- """Test __call__ with zero reward value."""
- zero_reward = ThinkingReward(reward_value=0.0)
-        response = "<think>This is my reasoning</think>"
- result = zero_reward("prompt", response)
+        result = self.reward(
+            "prompt", "<think>This is my reasoning</think>", target=None
+ )
+ self.assertEqual(result, 1.0)
+
+ def test_call_custom_reward_values(self):
+ """Test __call__ with custom reward values."""
+        response_full = "<think>This is proper reasoning</think>"
+        response_partial = "<think>"
+ response_none = "no thinking tags"
+
+ # Test custom partial reward
+ self.assertEqual(self.custom_reward("prompt", response_full), 0.9)
+ self.assertEqual(self.custom_reward("prompt", response_partial), 0.3)
+ self.assertEqual(self.custom_reward("prompt", response_none), 0.0)
+
+ def test_call_zero_custom_values(self):
+ """Test __call__ with zero custom values."""
+ zero_reward = ThinkingReward(partial_reward=0.0, full_reward=0.0)
+        result = zero_reward("prompt", "<think>This is my reasoning</think>")
self.assertEqual(result, 0.0)
- def test_call_negative_reward_value(self):
- """Test __call__ with negative reward value."""
- negative_reward = ThinkingReward(reward_value=-0.5)
- response = "This is my reasoning"
- result = negative_reward("prompt", response)
- self.assertEqual(result, -0.5)
+ def test_call_negative_reward_values(self):
+ """Test __call__ with negative reward values."""
+ negative_reward = ThinkingReward(partial_reward=-0.1, full_reward=-0.5)
+
+ self.assertEqual(
+            negative_reward("prompt", "<think>This is proper reasoning</think>"), -0.5
+ )
+        self.assertEqual(negative_reward("prompt", "<think>"), -0.1)
+
+ def test_call_edge_case_characters(self):
+ """Test __call__ with edge case characters in thinking blocks."""
+        result = self.reward(
+            "prompt", "<think>Special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~</think>"
+ )
+ self.assertEqual(result, 1.0)
+
+ def test_call_unicode_characters(self):
+ """Test __call__ with unicode characters in thinking blocks."""
+        result = self.reward("prompt", "<think>Unicode: αβγδε 中文 🚀</think>")
+ self.assertEqual(result, 1.0)
+
+ def test_call_very_long_thinking_block(self):
+ """Test __call__ with very long thinking blocks."""
+ long_content = "A" * 10000
+        result = self.reward("prompt", f"<think>{long_content}</think>")
+ self.assertEqual(result, 1.0)
if __name__ == "__main__":