diff --git a/pyproject.toml b/pyproject.toml index 8460b5b78..bb66d0191 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dev = [ "anyio", "pytest-asyncio", "multiprocess", + "langid", ] docs = [ "sphinx==7.2.6", diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index 23a0002df..91ed7fea5 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -57,15 +57,28 @@ def _to_float(self, text: str) -> float | None: class ThinkingReward: - """Reward class for evaluating use of tags in reasoning.""" + """Reward class for evaluating use of thinking tags in reasoning. - def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0): + Args: + partial_reward: Reward for partial tag usage (incomplete/malformed) + full_reward: Reward for well-formed thinking blocks with content + tag: Tag name to use (default "think", can use "思考" for Japanese, etc.) + """ + + def __init__( + self, partial_reward: float = 0.2, full_reward: float = 1.0, tag: str = "think" + ): self.partial_reward = partial_reward self.full_reward = full_reward + self.tag = tag + # Build regex patterns for the specified tag self._THINK_BLOCK_RE = re.compile( - r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL + rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", + re.IGNORECASE | re.DOTALL, + ) + self._THINK_TAG_ATTEMPT_RE = re.compile( + rf"<\s*/?\s*{re.escape(tag)}\s*>", re.IGNORECASE ) - self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE) def __call__(self, prompt: str, response: str, target: str | None = None) -> float: """Compute thinking reward.""" @@ -80,3 +93,132 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> flo elif has_attempt: return self.partial_reward return 0.0 + + +class LanguageReward: + """Reward class for evaluating the language used in responses. + + This reward uses langid to detect the language and rewards responses that use + the target language. The detection strategy depends on the format: + - If exactly one thinking block: detect language of the block content + - Otherwise (no blocks or multiple blocks): detect language of whole response + + Note: Format enforcement (single vs multiple blocks) is handled by ThinkingReward. + This reward focuses purely on language detection. + + Args: + target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es') + match_reward: Reward when detected language matches target (default: 1.0) + no_match_reward: Reward when language doesn't match (default: 0.0) + tag: Tag name to use (default "思考" for multilingual, can use "think", etc.) + debug: If True, print debug samples showing model outputs and detected language + debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls) + + Note: Requires langid to be installed. 
Install with: pip install langid + """ + + def __init__( + self, + target_language: str = "en", + match_reward: float = 1.0, + no_match_reward: float = 0.0, + tag: str = "思考", + debug: bool = False, + debug_sample_rate: float = 0.1, + ): + self.target_language = target_language + self.match_reward = match_reward + self.no_match_reward = no_match_reward + self.tag = tag + self.debug = debug + self.debug_sample_rate = debug_sample_rate + self._debug_counter = 0 + # Build regex pattern for the specified tag + self._THINK_BLOCK_RE = re.compile( + rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL + ) + + # Lazy import langid with helpful error message + try: + import langid + + self._langid = langid + except ImportError: + raise ImportError( + "langid is required for LanguageReward but is not installed. " + "Please install it with: pip install langid" + ) from None + + def __call__(self, prompt: str, response: str, target: str | None = None) -> float: + """Compute language reward based on detected language. + + Detection strategy: + - If exactly one thinking block: detect language of block content + - Otherwise: detect language of whole response + + Args: + prompt: The input prompt (unused but kept for signature consistency) + response: The model response + target: Optional target string (unused but kept for signature consistency) + + Returns: + match_reward if detected language matches target, no_match_reward otherwise + """ + # Increment counter for sampling + self._debug_counter += 1 + should_debug = ( + self.debug + and self.debug_sample_rate > 0 + and (self._debug_counter % int(1 / self.debug_sample_rate)) == 0 + ) + + if not response: + if should_debug: + print( + f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}" + ) + return self.no_match_reward + + # Extract all thinking blocks + matches = self._THINK_BLOCK_RE.findall(response) + + # Determine what text to analyze + if len(matches) == 1: + # Single block: detect language of block content only + text_to_analyze = matches[0].strip() + detection_mode = "single block" + else: + # No blocks or multiple blocks: detect language of whole response + text_to_analyze = response.strip() + detection_mode = f"{len(matches)} blocks, using whole response" + + # Remove extra whitespace + text_to_analyze = re.sub(r"\s+", " ", text_to_analyze).strip() + + if not text_to_analyze: + if should_debug: + print(f"\n[LanguageReward] Empty text | Reward: {self.no_match_reward}") + return self.no_match_reward + + # Detect language using langid + detected_lang, confidence = self._langid.classify(text_to_analyze) + + # Check if language matches target + reward = ( + self.match_reward + if detected_lang == self.target_language + else self.no_match_reward + ) + + if should_debug: + sample = text_to_analyze[:150].replace("\n", " ") + match_symbol = "✓" if detected_lang == self.target_language else "✗" + print( + f"\n[LanguageReward] Detection mode: {detection_mode}" + f"\n Target: {self.target_language} | Detected: {detected_lang} | " + f"Confidence: {confidence:.2f}" + f"\n Sample: {sample}..." 
+ f"\n → Reward: {reward} {match_symbol}" + ) + + return reward diff --git a/tests/sandbox/grpo_language/README.md b/tests/sandbox/grpo_language/README.md new file mode 100644 index 000000000..f40b15c1a --- /dev/null +++ b/tests/sandbox/grpo_language/README.md @@ -0,0 +1,98 @@ +# GRPO with Language Reward + +This sandbox app demonstrates using GRPO training with a language reward that encourages the model to think in a specific target language. + +## Overview + +This app extends the standard GRPO training (from `apps/grpo/`) by adding a `LanguageReward` that evaluates whether the model's thinking (text within `<思考>` tags) is in the target language. + +**Key Insight**: Uses Japanese tags `<思考>` (shikō = "thinking") instead of English `` tags to break the model's association between thinking tags and English language. This helps encourage multilingual thinking. + +## Key Features + +- **Multi-objective training**: Combines three rewards: + - `MathReward`: Evaluates correctness of math answers + - `ThinkingReward`: Encourages use of `<思考>` tags + - `LanguageReward`: Rewards thinking in target language (Japanese by default) + +- **Japanese thinking tags**: Uses `<思考>` instead of `` to encourage non-English reasoning + +- **Language detection**: Uses `langid` to detect the language of thinking blocks + +- **Configurable target language**: While this app defaults to Japanese (`ja`), the `LanguageReward` can be configured for any ISO 639-1 language code + +- **Configurable tags**: Both rewards support custom tag names via the `tag` parameter + +## Requirements + +Before running this app, install the required language detection library: + +```bash +pip install langid +``` + +## Usage + +```bash +python -m sandbox.grpo_language.main --config apps/grpo/qwen3_1_7b.yaml +``` + +You can use any of the config files from `apps/grpo/` (e.g., `qwen3_1_7b.yaml`, `qwen3_8b.yaml`, `qwen3_32b.yaml`). + +## How It Works + +1. The model receives a math problem and is instructed to use `<思考>` tags for reasoning +2. During training, the model generates responses with thinking blocks +3. Three rewards are computed: + - **MathReward**: Did it get the right answer? + - **ThinkingReward**: Did it use `<思考>` tags properly? (single block = full reward, multiple blocks = partial reward) + - **LanguageReward**: Did it use the target language? Detection strategy: + - If exactly one thinking block: detect language of block content only + - Otherwise (no blocks or multiple blocks): detect language of whole response + - Returns match_reward (1.0) if detected language matches target, no_match_reward (0.0) otherwise +4. The model is trained to maximize all three rewards + +**Note**: ThinkingReward enforces format (single vs multiple blocks), while LanguageReward focuses purely on language detection. This separation of concerns allows each reward to specialize in one aspect of the desired behavior. + +## Configuration + +### Target Language + +The target language is configured as Japanese in `main.py`: + +```python +LanguageReward(target_language="ja", tag="思考") +ThinkingReward(tag="思考") +``` + +To use a different language: +1. Change `target_language` to the appropriate ISO 639-1 code: + - English: `"en"` + - Chinese: `"zh"` +- Spanish: `"es"` +- French: `"fr"` +- etc. + +## Expected Behavior + +Over the course of training, the model should learn to: +1. Solve math problems correctly +2. Use `<思考>` tags for its reasoning +3. 
+
+## Metrics
+
+The following metrics are logged to W&B:
+- `reward/evaluate_response/avg_LanguageReward_reward`: Average language reward
+- `reward/evaluate_response/avg_MathReward_reward`: Average math reward
+- `reward/evaluate_response/avg_ThinkingReward_reward`: Average thinking reward
+- `reward/evaluate_response/avg_total_reward`: Average of all rewards
+
+## Differences from Standard GRPO
+
+This is a modified version of `apps/grpo/main.py` with:
+1. Added import: `from forge.data.rewards import LanguageReward`
+2. Modified reward functions list to include `LanguageReward(target_language="ja", tag="思考")` and `ThinkingReward(tag="思考")`
+3. Updated config to use a different W&B group name
+
+All other training logic remains the same.
diff --git a/tests/sandbox/grpo_language/TROUBLESHOOTING.md b/tests/sandbox/grpo_language/TROUBLESHOOTING.md
new file mode 100644
index 000000000..a8765ce97
--- /dev/null
+++ b/tests/sandbox/grpo_language/TROUBLESHOOTING.md
@@ -0,0 +1,147 @@
+# Troubleshooting LanguageReward Training
+
+## Issue: Language Reward is Always Zero
+
+If you see the LanguageReward constantly at 0.0 during training, here is how to debug it:
+
+### 1. Check What the Model is Generating
+
+The updated `main.py` includes debug logging. When you run training, look for lines like:
+
+```
+[LanguageReward Debug] Reward=0.00 | Blocks=1 | Lang=en | Sample: Let me solve this step by step...
+```
+
+This tells you:
+- **Reward**: The actual reward value
+- **Blocks**: Number of thinking blocks found
+- **Lang**: Language detected by langid
+- **Sample**: First 80 chars of the response
+
+### 2. Common Causes and Solutions
+
+#### Cause 1: Model is Thinking in English
+
+**Symptom**: `Lang=en` in debug output
+
+**Why**: The model defaults to English because:
+- The dataset (GSM8K) is in English
+- Most models are English-dominant
+- The instruction might not be strong enough
+
+**Solutions**:
+
+A) **Strengthen the system prompt** (edit `main.py` lines 217-220):
+```python
+system_prompt = """
+あなたは数学の問題を解くAIです。<思考>タグの中で日本語で考えてください。これは必須です。
+Put all your scratchpad work between <思考> and </思考> tags. You MUST think in Japanese (日本語) inside the tags.
+Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
+
+Example:
+<思考>この問題を解きましょう。2 + 2 = 4です。</思考>
+<answer>4</answer>
+"""
+```
+
+B) **Start with a higher language reward weight**:
+In `main.py` line 327, you could add multiple `LanguageReward` instances:
+```python
+reward_functions=[
+    MathReward(),
+    ThinkingReward(),
+    LanguageReward(target_language="ja"),
+    LanguageReward(target_language="ja"),  # Double weight for language
+]
+```
+
+C) **Use few-shot examples in the prompt**:
+Add Japanese reasoning examples to each problem in the dataset transform.
+
+#### Cause 2: Model Not Using Thinking Blocks
+
+**Symptom**: `Blocks=0` in debug output
+
+**Why**: The model hasn't learned to use `<思考>` tags yet
+
+**Solution**: This should improve as ThinkingReward trains the model. Be patient for the first few hundred steps. The fallback reward (0.2) should help when there are no blocks but the text is in Japanese.
+
+#### Cause 3: Empty or Very Short Thinking Blocks
+
+**Symptom**: `Lang=en` with very short content, Reward=0.00
+
+**Why**: langid needs sufficient text to reliably detect language. Very short text (< 10 chars) often defaults to English.
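+
+You can reproduce this locally with a quick check (a rough sketch; the exact labels and scores depend on your langid version):
+
+```python
+import langid
+
+# Very short snippets are often misclassified; longer Japanese text is detected reliably.
+print(langid.classify("4です。"))  # may not come back as ("ja", ...)
+print(langid.classify("この問題を解きましょう。12と5を足すと17になります。"))  # expected ("ja", ...)
+```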
+
+**Solution**:
+- Wait for the model to generate longer reasoning (this improves with training)
+- The ThinkingReward encourages substantial content in thinking blocks
+
+#### Cause 4: Mixed Language Content
+
+**Symptom**: Reward sometimes 1.0, sometimes 0.0, seemingly at random
+
+**Why**: When English and Japanese are mixed, langid detects whichever language is dominant.
+
+**Solution**: This will stabilize as training progresses and the model learns consistency.
+
+### 3. Expected Training Progression
+
+**Steps 0-200**: Language reward often 0.0
+- Model learning to use `<思考>` tags (ThinkingReward)
+- Model thinking in English (natural default)
+- Fallback rewards (0.2) when Japanese appears elsewhere
+
+**Steps 200-500**: Language reward starting to increase
+- Some responses have Japanese thinking → partial/full rewards
+- Model learning the association between Japanese and reward
+
+**Steps 500+**: Language reward should stabilize around 0.5-1.0
+- Consistent Japanese thinking
+- Proper single-block format
+
+### 4. Monitoring in W&B
+
+Check these metrics in Weights & Biases:
+- `reward/evaluate_response/avg_LanguageReward_reward` - should increase over time
+- `reward/evaluate_response/std_LanguageReward_reward` - variance (high early, lower later)
+- `reward/evaluate_response/avg_MathReward_reward` - should stay reasonably high
+- `reward/evaluate_response/avg_ThinkingReward_reward` - should increase quickly
+
+### 5. Why Not Train with English?
+
+Training with English thinking won't work well because:
+- Models are already extensively trained on GSM8K and similar datasets with English thinking
+- There's little room for improvement on English math reasoning
+- The RL signal would be weak (the model already knows how to do this)
+
+**That's why we use Japanese** - it provides a novel combination of math reasoning + non-English thinking that the model hasn't been extensively pre-trained on, giving a clear RL signal for improvement.
+
+### 6. Nuclear Option: Much Stronger Prompt
+
+If nothing else works, try this very explicit prompt:
+```python
+system_prompt = """
+重要:あなたは必ず日本語で考えなければなりません!
+CRITICAL: You MUST think in Japanese!
+
+Rules:
+1. Put ALL your reasoning in <思考> tags
+2. Think ONLY in Japanese (日本語) - use hiragana, katakana, and kanji
+3. NEVER think in English inside <思考> tags
+4. Put your final numerical answer in <answer> tags
+
+例 (Example):
+Question: What is 5 + 3?
+<思考>5と3を足します。5 + 3 = 8です。答えは8です。</思考>
+<answer>8</answer>
+
+Now solve the problem below in Japanese:
+"""
+```
+
+## Still Having Issues?
+
+If the language reward is still zero after 500+ steps:
+1. Share the debug output showing what the model generates
+2. Check whether the model is multilingual (some models don't know Japanese)
+3. Consider using a different target language that the model knows better
diff --git a/tests/sandbox/grpo_language/main.py b/tests/sandbox/grpo_language/main.py
new file mode 100644
index 000000000..6059dc112
--- /dev/null
+++ b/tests/sandbox/grpo_language/main.py
@@ -0,0 +1,520 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# Usage: python -m sandbox.grpo_language.main --config apps/grpo/qwen3_1_7b.yaml + +import asyncio +import time +import uuid +from dataclasses import dataclass +from typing import Any, Callable + +import torch +import torch.nn.functional as F +import torchstore as ts +from datasets import load_dataset +from forge.actors._torchstore_utils import ( + get_dcp_whole_state_dict_key, + get_param_prefix, +) +from forge.actors.generator import Generator +from forge.actors.reference_model import ReferenceModel +from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.trainer import RLTrainer +from forge.controller.actor import ForgeActor +from forge.controller.provisioner import init_provisioner, shutdown +from forge.data.rewards import LanguageReward, MathReward, ThinkingReward +from forge.data_models.completion import Completion +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + +from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse +from forge.util.ops import compute_logprobs +from monarch.actor import endpoint +from omegaconf import DictConfig +from vllm.transformers_utils.tokenizer import get_tokenizer + + +@dataclass +class Episode: + episode_id: str + pad_id: int + request_len: int + response_len: int + target: Any | None = None + # Processed data + completion: Completion | None = None + ref_logprobs: torch.Tensor | None = None + reward: float | None = None + advantage: float | None = None + + @property + def policy_version(self) -> int | None: + return self.completion.generator_version + + @property + def request_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long) + if tensor.shape[0] < self.request_len: # left pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + return tensor + + @property + def response_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.token_ids.to(torch.long) + if tensor.shape[0] < self.response_len: # right pad + diff = self.response_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + return tensor + + +# Represents the group (G) of episodes in GRPO +Group = list[Episode] + +# Represents the Policy Model to collect data from +Policy = Generator + + +def collate( + batches: list[Group], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """ + Collates a list of batches into a single batch of inputs and targets. + Each batch is a list of episodes, and each episode is a dict of tensors. 
+ """ + inputs = [] + targets = [] + for batch in batches: + request = [e.request_tensor for e in batch] + request = torch.stack(request) # [b x s] + + response = [e.response_tensor for e in batch] + response = torch.stack(response) # [b x s] + + ref_logprobs = [e.ref_logprobs for e in batch] + ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s] + + advantages = [e.advantage for e in batch] + advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1] + + pad_id = batch[0].pad_id + mask = response != pad_id + + input = {"tokens": torch.cat([request, response], dim=1)} + target = { + "response": response, + "ref_logprobs": ref_logprobs, + "advantages": advantages, + "padding_mask": mask, + } + inputs.append(input) + targets.append(target) + return inputs, targets + + +# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss` +def simple_grpo_loss( + logits: torch.Tensor, + response: torch.Tensor, + ref_logprobs: torch.Tensor, + advantages: torch.Tensor, + padding_mask: torch.Tensor, + beta: float = 1e-4, +) -> torch.Tensor: + logprobs: torch.Tensor = compute_logprobs(logits, response) + kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss - beta * kl) + loss = ( + ((per_token_loss * padding_mask).sum(dim=1)) + / (padding_mask.sum(dim=1).clamp(min=1.0)) + ).mean() + return loss + + +@dataclass +class RewardActor(ForgeActor): + + reward_functions: list[Callable] + + @endpoint + async def evaluate_response(self, prompt: str, response: str, target: str) -> float: + total_rewards = 0.0 + for reward_fn in self.reward_functions: + reward = reward_fn(prompt, response, target) + total_rewards += reward + + # Get a name for the reward function (works for classes, functions, lambdas) + reward_fn_name = getattr( + reward_fn, "__name__", reward_fn.__class__.__name__ + ) + + # per function reward + record_metric( + f"reward/evaluate_response/sum_{reward_fn_name}_reward", + reward, + Reduce.SUM, + ) + record_metric( + f"reward/evaluate_response/avg_{reward_fn_name}_reward", + reward, + Reduce.MEAN, + ) + record_metric( + f"reward/evaluate_response/std_{reward_fn_name}_reward", + reward, + Reduce.STD, + ) + + record_metric( + "reward/evaluate_response/avg_total_reward", + reward, + Reduce.MEAN, + ) + + record_metric( + f"reward/evaluate_response/count_{reward_fn_name}_calls", + 1, + Reduce.SUM, + ) + + avg_reward = total_rewards / len(self.reward_functions) + return avg_reward + + +@dataclass +class ComputeAdvantages(ForgeActor): + @endpoint + async def compute(self, group: Group) -> list[float]: + # TODO: add batch processing + rewards = torch.tensor([[e.reward for e in group]]) + mean = rewards.mean(1, keepdim=True) + std = rewards.std(1, keepdim=True) + advantages = (rewards - mean) / (std + 1e-4) + return advantages.squeeze(0).tolist() + + +@dataclass +class DatasetActor(ForgeActor): + """Actor wrapper for HuggingFace dataset to provide async interface.""" + + path: str = "openai/gsm8k" + revision: str = "main" + data_split: str = "train" + streaming: bool = True + model: str = "Qwen/Qwen3-1.7B" + + @endpoint + def setup(self): + self._tokenizer = get_tokenizer(self.model) + self._epoch = 0 + + def gsm8k_transform(sample): + system_prompt = """ +You are a helpful AI assistant that solves math problems. + +Please show your reasoning inside <思考> tags, then provide your final numerical answer inside tags. 
+ +Example: +Question: What is 12 + 5? +<思考>12と5を足します。12 + 5 = 17です。 +17 + """ + request: str = sample["question"] + as_chat = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] + formatted_request = self._tokenizer.apply_chat_template( + as_chat, + tokenize=False, + add_generation_prompt=True, + ) + target: str = sample["answer"] + formatted_target = target.split("#### ")[1] + return {"request": formatted_request, "target": formatted_target} + + self._base_dataset = load_dataset( + self.path, self.revision, split=self.data_split, streaming=self.streaming + ) + self._base_dataset = self._base_dataset.map(gsm8k_transform) + self._base_dataset = self._base_dataset.shuffle() + self._iterator = iter(self._base_dataset) + + @endpoint + async def sample(self) -> dict[str, str] | None: + try: + sample = next(self._iterator) + + record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) + record_metric( + "dataset/sample/avg_sample_len", + len(sample["request"]), + Reduce.MEAN, + ) + + record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX) + return sample + except StopIteration: + # Restart iterator for next epoch + self._epoch += 1 + print( + f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" + ) + self._base_dataset.set_epoch(self._epoch) + self._iterator = iter(self._base_dataset) + return next(self._iterator) + + @endpoint + async def pad_token(self): + return self._tokenizer.pad_token_id + + +async def drop_weights(version: int): + print(f"Dropping weights @ version {version}") + start_time = time.perf_counter() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + # TODO: once we have something like `get_meta()` in torchstore, we can just + # query the type of the object instead of relying on keys. 
+ dcp_key = get_dcp_whole_state_dict_key(version) + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + dcp_handle.drop() + for key in matching_keys: + await ts.delete(key) + elapsed = time.perf_counter() - start_time + print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") + + +async def main(cfg: DictConfig): + """Main GRPO training loop with rollout and training processes.""" + group_size = cfg.group_size + max_req_tokens = cfg.max_req_tokens + max_res_tokens = cfg.max_res_tokens + + # ---- Global setups ---- # + provisioner = None + if cfg.get("provisioner", None) is not None: + provisioner = await init_provisioner( + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + ) + else: + provisioner = await init_provisioner() + + metric_logging_cfg = cfg.get("metric_logging", {}) + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one(metric_logging_cfg) + + # ---- Setup services ---- # + + ( + dataloader, + policy, + trainer, + replay_buffer, + compute_advantages, + ref_model, + reward_actor, + ) = await asyncio.gather( + DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), + Policy.options(**cfg.services.policy).as_service(**cfg.policy), + RLTrainer.options(**cfg.actors.trainer).as_actor( + **cfg.trainer, loss=simple_grpo_loss + ), + ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor( + **cfg.replay_buffer, collate=collate + ), + ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(), + ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model), + RewardActor.options(**cfg.services.reward_actor).as_service( + reward_functions=[ + MathReward(), + ThinkingReward(tag="思考"), # Use Japanese tag + LanguageReward( + target_language="ja", + tag="思考", + match_reward=2.0, + debug=True, + debug_sample_rate=0.1, + ), # Japanese language reward with debug + ] + ), + ) + + # Set max_steps to the configured value, or -1 if not specified or Null + max_steps = cfg.trainer.training.steps or -1 + + print("All services initialized successfully!") + shutdown_event = asyncio.Event() + # Here we spawn a torchstore storage volume per trainer process. + # We initialize after service initialization because torchstore currently + # requires access to the underlying proc meshes in the local rank strategy. + # We should be able to hide this in the future. 
+ # TODO: support multiple host meshes + trainer_num_procs = cfg.actors.trainer["procs"] + trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] + trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) + await ts.initialize( + mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), + strategy=ts.LocalRankStrategy(), + ) + print("Torchstore successfully initialized with local rank strategy") + + # ---- Core RL loops ---- # + async def continuous_rollouts(): + rollout_count = 0 + pad_id = await dataloader.pad_token.call_one() + while not shutdown_event.is_set(): + t = Tracer("main_perf/continuous_rollouts") + t.start() + sample = await dataloader.sample.call_one() + if sample is None: + print("Dataloader is empty, exiting continuous rollout") + return + + t.step("data_loading") + + prompt, target = sample["request"], sample["target"] + responses: list[Completion] = await policy.generate.route(prompt) + t.step("policy_generation") + + # Construct episodes and calculate rewards + episodes = [] + input_ids = torch.ones( + (group_size, max_req_tokens + max_res_tokens), + dtype=torch.long, + ) + for i, response in enumerate(responses): + episode = Episode( + episode_id=str(uuid.uuid4()), + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, + target=target, + completion=response, + ) + episode.reward = await reward_actor.evaluate_response.route( + prompt=prompt, response=response.text, target=target + ) + episodes.append(episode) + + # Build input_ids for reference logprobs + input_ids[i, :max_req_tokens] = episode.request_tensor + input_ids[i, max_req_tokens:] = episode.response_tensor + + t.step("reward_evaluation") + + ref_logprobs = await ref_model.forward.route( + input_ids, max_req_tokens, return_logprobs=True + ) + t.step("reference_model_calculate_logprobs") + + for i, episode in enumerate(episodes): + episode.ref_logprobs = ref_logprobs[i] + del ref_logprobs, input_ids + + advantages = await compute_advantages.compute.call_one(episodes) + for episode, advantage in zip(episodes, advantages): + episode.advantage = advantage + await replay_buffer.add.call_one(episode) + + rollout_count += 1 + record_metric( + "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM + ) + t.stop() + + async def continuous_training(): + training_step = 0 + restart_tracer = True # Flag to control when to restart tracer + + while max_steps == -1 or training_step < max_steps: + # Restart tracer when needed (initial start or after completing a training step) + # Otherwise, we cannot measure time waiting for buffer + if restart_tracer: + t = Tracer("main_perf/continuous_training") + t.start() + restart_tracer = False + + batch = await replay_buffer.sample.call_one( + curr_policy_version=training_step + ) + if batch is None: + await asyncio.sleep(0.1) + else: + t.step("waiting_for_buffer") + + inputs, targets = batch + await trainer.train_step.call(inputs, targets) + training_step += 1 + t.step("train_step") + + await trainer.push_weights.call(training_step) + t.step("push_weights") + + await policy.update_weights.fanout(training_step) + t.step("update_weights") + + if training_step >= 2: + await drop_weights(training_step - 1) + t.step("drop_weights") + + t.stop() + restart_tracer = True + + # Flush metrics every training step to WandB + await mlogger.flush.call_one(training_step) + + print( + f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." 
+ ) + + num_rollout_threads = cfg.get("rollout_threads", 1) + num_training_threads = cfg.get("training_threads", 1) + print( + f"Starting GRPO with {num_rollout_threads} rollout threads, {num_training_threads} training threads" + ) + rollout_tasks = [ + asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads) + ] + training_task = asyncio.create_task(continuous_training()) + + try: + await training_task + except KeyboardInterrupt: + print("Training interrupted by user") + finally: + print("Shutting down... (this may take a few seconds)") + shutdown_event.set() + + try: + # Give rollouts up to 5s to finish naturally + await asyncio.wait_for( + asyncio.gather(*rollout_tasks, return_exceptions=True), + timeout=5, + ) + except asyncio.TimeoutError: + print("Timeout waiting for rollouts; forcing cancellation...") + for t in rollout_tasks: + t.cancel() + await asyncio.gather(*rollout_tasks, return_exceptions=True) + + training_task.cancel() + + await shutdown() + + +if __name__ == "__main__": + + @parse + def _main(cfg): + asyncio.run(main(cfg)) + + _main() # @parse grabs the cfg from CLI diff --git a/tests/unit_tests/rl/test_language_reward.py b/tests/unit_tests/rl/test_language_reward.py new file mode 100644 index 000000000..423ba4829 --- /dev/null +++ b/tests/unit_tests/rl/test_language_reward.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import unittest +from unittest.mock import patch + + +class TestLanguageReward(unittest.TestCase): + def setUp(self): + """Set up test fixtures before each test method.""" + # Import after patching to avoid ImportError + from forge.data.rewards import LanguageReward + + self.LanguageReward = LanguageReward + self.reward_en = LanguageReward(target_language="en") + self.reward_ja = LanguageReward(target_language="ja") + self.custom_reward = LanguageReward( + target_language="ja", + match_reward=0.9, + no_match_reward=0.1, + ) + + def test_init_default_values(self): + """Test LanguageReward initialization with default values.""" + reward = self.LanguageReward() + self.assertEqual(reward.target_language, "en") + self.assertEqual(reward.match_reward, 1.0) + self.assertEqual(reward.no_match_reward, 0.0) + + def test_init_custom_values(self): + """Test LanguageReward initialization with custom values.""" + reward = self.LanguageReward( + target_language="ja", + match_reward=0.9, + no_match_reward=0.1, + ) + self.assertEqual(reward.target_language, "ja") + self.assertEqual(reward.match_reward, 0.9) + self.assertEqual(reward.no_match_reward, 0.1) + + def test_init_missing_langid(self): + """Test LanguageReward initialization without langid installed.""" + # Remove langid from modules if it exists + langid_module = sys.modules.get("langid") + if "langid" in sys.modules: + del sys.modules["langid"] + + with patch.dict("sys.modules", {"langid": None}): + with self.assertRaises(ImportError) as context: + # Re-import to trigger the ImportError + import importlib + + import forge.data.rewards + + importlib.reload(forge.data.rewards) + forge.data.rewards.LanguageReward() + + self.assertIn("langid is required", str(context.exception)) + self.assertIn("pip install langid", str(context.exception)) + + # Restore langid module if it existed + if langid_module is not None: + sys.modules["langid"] = langid_module + + def test_regex_pattern(self): + """Test 
that regex pattern is compiled correctly.""" + reward = self.LanguageReward() + self.assertIsNotNone(reward._THINK_BLOCK_RE) + + def test_call_with_english_thinking(self): + """Test __call__ with English text in thinking blocks.""" + response = "<思考>This is English reasoning about math problems." + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_with_japanese_thinking(self): + """Test __call__ with Japanese text in thinking blocks.""" + response = "<思考>これは日本語で考えています。数学の問題を解きます。" + result = self.reward_ja("prompt", response) + self.assertEqual(result, 1.0) + + # English reward should give no_match_reward for Japanese text + result = self.reward_en("prompt", response) + self.assertEqual(result, 0.0) + + def test_call_with_chinese_thinking(self): + """Test __call__ with Chinese text in thinking blocks.""" + response = "<思考>这是中文思考。我们需要解决这个数学问题。" + reward_zh = self.LanguageReward(target_language="zh") + result = reward_zh("prompt", response) + # langid should detect this as Chinese (zh) + self.assertEqual(result, 1.0) + + def test_call_with_spanish_thinking(self): + """Test __call__ with Spanish text in thinking blocks.""" + response = ( + "<思考>Este es un razonamiento en español sobre problemas matemáticos." + ) + reward_es = self.LanguageReward(target_language="es") + result = reward_es("prompt", response) + # langid should detect this as Spanish (es) + self.assertEqual(result, 1.0) + + def test_call_language_mismatch(self): + """Test __call__ when detected language doesn't match target.""" + # Japanese reward with English text + response = "<思考>This is English reasoning." + result = self.reward_ja("prompt", response) + self.assertEqual(result, 0.0) + + # English reward with Japanese text + response = "<思考>これは日本語です。" + result = self.reward_en("prompt", response) + self.assertEqual(result, 0.0) + + def test_call_with_no_thinking_tags(self): + """Test __call__ with response containing no thinking tags but correct language.""" + result = self.reward_en( + "prompt", "This is just a regular response without any thinking tags." + ) + # No thinking blocks -> detect whole response, English detected -> match_reward + self.assertEqual(result, 1.0) + + def test_call_with_no_thinking_tags_wrong_language(self): + """Test __call__ with response containing no thinking tags and wrong language.""" + result = self.reward_en("prompt", "これは日本語の応答です。タグはありません。") + # No thinking blocks -> detect whole response, Japanese detected -> no_match_reward + self.assertEqual(result, 0.0) + + def test_call_with_empty_thinking_block(self): + """Test __call__ with empty thinking block.""" + result = self.reward_en("prompt", "<思考>") + self.assertEqual(result, 0.0) + + def test_call_with_whitespace_only_thinking_block(self): + """Test __call__ with whitespace-only thinking block.""" + result = self.reward_en("prompt", "<思考> \n \t ") + self.assertEqual(result, 0.0) + + def test_call_with_proper_tags(self): + """Test __call__ with properly formatted thinking tags.""" + response = "<思考>This is English reasoning." + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + # Japanese content should also work + response = "<思考>これは日本語です。" + result = self.reward_ja("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_with_whitespace_in_tags(self): + """Test __call__ with whitespace in thinking tags.""" + response = "< 思考 >This is English reasoning." 
+ result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_multiple_thinking_blocks(self): + """Test __call__ with multiple thinking blocks - detects whole response language.""" + response = """ + <思考>First thought in English. + Some text in between. + <思考>Second thought also in English. + """ + result = self.reward_en("prompt", response) + # Multiple blocks -> detect whole response, English detected -> match_reward + self.assertEqual(result, 1.0) + + def test_call_multiple_thinking_blocks_mixed_languages(self): + """Test __call__ with multiple thinking blocks in different languages.""" + response = """ + <思考>First thought in English with lots of content here. + <思考>これは短い日本語。 + """ + result = self.reward_en("prompt", response) + # Multiple blocks -> detect whole response, langid will detect dominant language + # Should return match_reward (1.0) if English dominant, or no_match_reward (0.0) if not + self.assertIn(result, [0.0, 1.0]) + + def test_call_multiline_thinking_block(self): + """Test __call__ with multiline thinking blocks.""" + response = """<思考> + This is a multiline + thinking block with + lots of English content + about solving problems + """ + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_empty_response(self): + """Test __call__ with empty response.""" + result = self.reward_en("prompt", "") + self.assertEqual(result, 0.0) + + def test_call_none_response(self): + """Test __call__ with None response.""" + result = self.reward_en("prompt", None) + self.assertEqual(result, 0.0) + + def test_call_with_target_parameter(self): + """Test __call__ with target parameter (should be ignored).""" + response = "<思考>This is English reasoning." + result = self.reward_en("prompt", response, target="some target") + self.assertEqual(result, 1.0) + + # English text without tags -> detect whole response -> match_reward + result = self.reward_en( + "prompt", + "This is a response without thinking tags but in English language.", + target="some target", + ) + self.assertEqual(result, 1.0) + + def test_call_custom_reward_values(self): + """Test __call__ with custom reward values.""" + response_ja_single = "<思考>これは日本語です。" + response_ja_multiple = "<思考>最初の考え。<思考>次の考え。" + response_ja_no_tags = "これはタグなしの日本語です。" + response_en = "<思考>This is English." 
+ response_none = "" + + # Test custom match reward (single block, correct language) + self.assertEqual(self.custom_reward("prompt", response_ja_single), 0.9) + # Test custom match reward (multiple blocks -> whole response, correct language) + self.assertEqual(self.custom_reward("prompt", response_ja_multiple), 0.9) + # Test custom match reward (no blocks -> whole response, correct language) + self.assertEqual(self.custom_reward("prompt", response_ja_no_tags), 0.9) + # Test custom no_match reward (wrong language) + self.assertEqual(self.custom_reward("prompt", response_en), 0.1) + # Test empty response + self.assertEqual(self.custom_reward("prompt", response_none), 0.1) + + def test_call_zero_custom_values(self): + """Test __call__ with zero custom values.""" + zero_reward = self.LanguageReward( + target_language="en", match_reward=0.0, no_match_reward=0.0 + ) + result = zero_reward("prompt", "<思考>This is English.") + self.assertEqual(result, 0.0) + + def test_call_with_special_characters(self): + """Test __call__ with special characters in thinking blocks.""" + response = ( + "<思考>English with special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~" + ) + result = self.reward_en("prompt", response) + self.assertEqual(result, 1.0) + + def test_call_with_mixed_content_outside_tags(self): + """Test __call__ with mixed language content outside thinking tags.""" + # Content outside think tags should be ignored + response = """ + これは日本語のテキストです。 + <思考>But this is English reasoning inside the tags. + もっと日本語のテキスト。 + """ + result = self.reward_en("prompt", response) + # Should detect English from thinking block only + self.assertEqual(result, 1.0) + + def test_call_with_numbers_and_symbols(self): + """Test __call__ with thinking blocks containing mostly numbers.""" + response = "<思考>Calculate: 2 + 2 = 4, then 4 * 3 = 12" + result = self.reward_en("prompt", response) + # Should still detect as English due to words like "Calculate" and "then" + self.assertEqual(result, 1.0) + + def test_call_very_long_thinking_block(self): + """Test __call__ with very long thinking blocks.""" + long_content = "This is English content. " * 1000 + result = self.reward_en("prompt", f"<思考>{long_content}") + self.assertEqual(result, 1.0) + + def test_call_with_code_in_thinking(self): + """Test __call__ with code snippets in thinking blocks.""" + response = """<思考> + Let me write some Python code to solve this: + def calculate(x): + return x * 2 + The function doubles the input value. 
+ """ + result = self.reward_en("prompt", response) + # Should detect as English due to surrounding text + self.assertEqual(result, 1.0) + + def test_different_language_codes(self): + """Test __call__ with various ISO 639-1 language codes.""" + # Test a few common languages + languages = { + "fr": "Ceci est un texte en français avec beaucoup de contenu.", + "de": "Dies ist ein deutscher Text mit viel Inhalt.", + "it": "Questo è un testo italiano con molto contenuto.", + "pt": "Este é um texto em português com muito conteúdo.", + } + + for lang_code, text in languages.items(): + reward = self.LanguageReward(target_language=lang_code) + response = f"<思考>{text}" + result = reward("prompt", response) + # langid should detect these correctly + self.assertEqual( + result, + 1.0, + f"Failed to detect {lang_code} language: '{text[:50]}...'", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/rl/test_thinking_reward.py index b95823e9a..10b7bf38e 100644 --- a/tests/unit_tests/rl/test_thinking_reward.py +++ b/tests/unit_tests/rl/test_thinking_reward.py @@ -203,6 +203,19 @@ def test_call_very_long_thinking_block(self): result = self.reward("prompt", f"{long_content}") self.assertEqual(result, 1.0) + def test_custom_tag(self): + """Test that ThinkingReward uses the custom tag passed in.""" + # Create reward with custom Japanese tag + custom_tag_reward = ThinkingReward(tag="思考") + + # Response with custom tag should get full reward + result = custom_tag_reward("prompt", "<思考>This is my reasoning") + self.assertEqual(result, 1.0) + + # Response with default "think" tag should get no reward + result = custom_tag_reward("prompt", "This is my reasoning") + self.assertEqual(result, 0.0) + if __name__ == "__main__": unittest.main()