diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py
index 6b1d8d763..3a38c9ff0 100644
--- a/apps/toy_rl/sumdigits.py
+++ b/apps/toy_rl/sumdigits.py
@@ -10,7 +10,6 @@
 import random
 import time
 import uuid
-from collections.abc import Iterator
 from dataclasses import dataclass
 from typing import Any
 
@@ -23,18 +22,27 @@
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import shutdown
-
-from forge.losses.reinforce_loss import ReinforceLoss
+from forge.losses.grpo_loss import SimpleGRPOLoss
 from forge.util.metric_logging import get_metric_logger
+
+from forge.util.ops import selective_log_softmax
 from monarch.actor import endpoint
 from omegaconf import DictConfig
-from torch.utils.data import IterableDataset
 from torchstore.state_dict_utils import DELIM
 from transformers import AutoModelForCausalLM
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 
+def pad_sequence(
+    tensor: torch.Tensor, target_len: int, pad_value: float = 0.0
+) -> torch.Tensor:
+    diff = target_len - tensor.size(0)
+    if diff > 0:
+        return F.pad(tensor, (0, diff), value=pad_value)
+    return tensor
+
+
 # TODO: Episode and Group and duplicated and needs clean up.
 @dataclass
 class Episode:
@@ -199,6 +207,45 @@ def new_group(
         return cls(str(group_id), episodes)
 
 
+class RefModel(ForgeActor):
+    def __init__(self, model_name, device: torch.device | None = None):
+        super().__init__()
+        self.model_name = model_name
+
+        if device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = device
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            dtype=torch.bfloat16,
+            trust_remote_code=True,
+        ).to(self.device)
+        self.model.eval()
+
+        self.logger.info(f"Model initialized on {self.device}")
+
+    @endpoint
+    async def forward(self, episode: Episode) -> torch.Tensor:
+        input_ids = (
+            pad_sequence(episode.input_ids, episode.max_seq_len - 1, episode.pad_id)
+            .to(self.device)
+            .unsqueeze(0)
+        )
+        target_ids = (
+            pad_sequence(episode.target_ids, episode.max_seq_len - 1, episode.pad_id)
+            .to(self.device)
+            .unsqueeze(0)
+        )
+        mask = input_ids != episode.pad_id
+
+        with torch.inference_mode():
+            logits = self.model(input_ids=input_ids, attention_mask=mask).logits
+
+        return selective_log_softmax(logits, target_ids).squeeze(0)
+
+
 @dataclass
 class Trainer(ForgeActor):
     """Reinforce Loss Trainer implementation for policy optimization."""
@@ -224,7 +271,9 @@ async def setup(self):
             self.model.parameters(), lr=self.learning_rate
         )
         self.optimizer.zero_grad()
-        self.loss = ReinforceLoss()
+
+        # beta = 0.01 for quicker convergence
+        self.loss = SimpleGRPOLoss(0.01)
         self.logger.info(f"Trainer model initialized on {self.device}")
 
     @endpoint
@@ -238,14 +287,16 @@ def train_step(self, episodes: list[Episode]) -> float:
         batch_loss_masks = []
         batch_weights = []
         batch_sampling_log_probs = []
+        batch_ref_logprobs = []
 
         for episode in episodes:
-            input_ids = self.pad_sequence(episode.input_ids, max_seq_len, pad_id)
-            target_ids = self.pad_sequence(episode.target_ids, max_seq_len, pad_id)
-            loss_mask = self.pad_sequence(episode.loss_mask, max_seq_len, 0.0)
-            sampling_log_probs = self.pad_sequence(
+            input_ids = pad_sequence(episode.input_ids, max_seq_len, pad_id)
+            target_ids = pad_sequence(episode.target_ids, max_seq_len, pad_id)
+            loss_mask = pad_sequence(episode.loss_mask, max_seq_len, 0.0)
+            sampling_log_probs = pad_sequence(
                 episode.sampling_log_probs, max_seq_len, 0.0
             )
-            weights = self.pad_sequence(episode.weighted_advantages, max_seq_len, 0.0)
+            weights = pad_sequence(episode.weighted_advantages, max_seq_len, 0.0)
+            ref_logprobs = episode.ref_logprobs
 
             # Exclude padded response tokens from loss
             valid_mask = target_ids != pad_id
@@ -258,6 +309,7 @@ def train_step(self, episodes: list[Episode]) -> float:
             batch_loss_masks.append(loss_mask)
             batch_weights.append(weights)
             batch_sampling_log_probs.append(sampling_log_probs)
+            batch_ref_logprobs.append(ref_logprobs)
 
         # Stack into batched tensors
         input_ids = torch.stack(batch_input_ids).to(self.device)
@@ -265,6 +317,7 @@ def train_step(self, episodes: list[Episode]) -> float:
         loss_masks = torch.stack(batch_loss_masks).to(self.device)
         weights = torch.stack(batch_weights).to(self.device)
         sampling_log_probs = torch.stack(batch_sampling_log_probs).to(self.device)
+        ref_logprobs = torch.stack(batch_ref_logprobs).to(self.device)
 
         # Create attention mask
         attention_mask = input_ids != pad_id
@@ -272,8 +325,10 @@ def train_step(self, episodes: list[Episode]) -> float:
 
         # Forward pass
         logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
+        trainer_log_probs = selective_log_softmax(logits, target_ids)
 
         # Compute loss only on response tokens
-        loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs)
+        # loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs)
+        loss = self.loss(trainer_log_probs, ref_logprobs, weights, loss_masks)
         loss.backward()
 
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
@@ -296,14 +351,6 @@ async def push_weights(self, version: int):
             f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
         )
 
-    def pad_sequence(
-        self, tensor: torch.Tensor, target_len: int, pad_value: float = 0.0
-    ) -> torch.Tensor:
-        diff = target_len - tensor.size(0)
-        if diff > 0:
-            return F.pad(tensor, (0, diff), value=pad_value)
-        return tensor
-
 
 @dataclass
 class RewardActor(ForgeActor):
@@ -311,53 +358,59 @@ class RewardActor(ForgeActor):
 
     @endpoint
    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
-        if response == target:
-            return 1.0
-        return 0.0
+        # Compare against the target; strip whitespace from the model output
+        reward = 1.0 if response.strip() == target.strip() else 0.0
+        return reward
 
 
 @dataclass
-class SumDigitsDataset(IterableDataset):
+class SumDigitsDataset:
     def __init__(self, tokenizer, max_samples=1000):
-        self.min_digit_length = 2
-        self.max_digit_length = 3
         self.max_numbers = max_samples
-        self.data = self.generate_random_number()
         self._tokenizer = tokenizer
 
-    def __iter__(self) -> Iterator[Any]:
-        for data in self.data:
-            answer = str(sum(int(x) for x in data))
-            system_prompt = """
-            A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-            The assistant only gives very concise answers.
- """ - request: str = f"What is the sum of the digits of {data}" - as_chat = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": request}, - ] - formatted_request = self._tokenizer.apply_chat_template( - as_chat, - tokenize=False, - add_generation_prompt=True, - ) - yield { - "question": formatted_request, - "request": formatted_request, - "answer": answer, - "target": answer, - } - - def generate_random_number(self) -> Iterator[str]: - while True: - yield self.generate_one() + def generate_sample(self, step: int) -> dict[str, str]: + """Generate a single sample based on training step for progressive difficulty.""" + data = self.generate_one(step) + answer = str(sum(int(x) for x in data)) - def generate_one(self) -> str: - return "".join( - str(random.randint(0, 9)) - for _ in range(random.randint(self.min_digit_length, self.max_digit_length)) + system_prompt = """ + A conversation between User and Assistant. The user asks a question, and the Assistant solves it. + The assistant only gives very concise answers (just the number, no explanation). + """ + request: str = f"What is the sum of the digits of {data}" + as_chat = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] + formatted_request = self._tokenizer.apply_chat_template( + as_chat, + tokenize=False, + add_generation_prompt=True, ) + return { + "question": formatted_request, + "request": formatted_request, + "answer": answer, + "target": answer, + } + + def generate_one(self, step: int) -> str: + """Generate number based on training step for curriculum learning.""" + if step < 200: + # Early training: 2-digit numbers (10-99) + min_val, max_val = 10, 99 + elif step < 1000: + # Later training: 1-4 digit numbers (0-1000) + min_val, max_val = 0, 1000 + elif step < 3000: + # Later training: 1-6 digit numbers (0-100000) + min_val, max_val = 0, 100000 + else: + # Later training: 1-8 digit numbers (0-10000000) + min_val, max_val = 0, 10000000 + + number = random.randint(min_val, max_val) + return str(number) @dataclass @@ -369,13 +422,15 @@ class DatasetActor(ForgeActor): @endpoint def setup(self): self._tokenizer = get_tokenizer(self.model) - self._iterator = iter(SumDigitsDataset(self._tokenizer)) + self._dataset = SumDigitsDataset(self._tokenizer) @endpoint - async def sample(self) -> dict[str, str] | None: + async def sample(self, step: int = 0) -> dict[str, str] | None: + """Sample with progressive difficulty based on training step.""" try: - return next(self._iterator) - except StopIteration: + return self._dataset.generate_sample(step) + except Exception as e: + self.logger.error(f"Error generating sample: {e}") return None @endpoint @@ -397,7 +452,14 @@ async def main(cfg: DictConfig): # ---- Setup services ---- # await ts.initialize() - (dataloader, policy, trainer, replay_buffer, reward_actor,) = await asyncio.gather( + ( + dataloader, + policy, + trainer, + replay_buffer, + reward_actor, + ref_model, + ) = await asyncio.gather( DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset), Policy.options(**cfg.services.policy).as_service(**cfg.policy), Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer), @@ -405,6 +467,7 @@ async def main(cfg: DictConfig): **cfg.replay_buffer ), RewardActor.options(**cfg.services.reward_actor).as_service(), + RefModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model), ) print("All services initialized successfully!") @@ -414,7 +477,8 @@ async def continuous_rollouts(): rollout_count = 0 
diff --git a/apps/toy_rl/sumdigits.yaml b/apps/toy_rl/sumdigits.yaml
index f97cb7e75..1389daeb7 100644
--- a/apps/toy_rl/sumdigits.yaml
+++ b/apps/toy_rl/sumdigits.yaml
@@ -1,10 +1,10 @@
 # Toy app Training Configuration
 
 # Global configuration
-group_size: 8
-batch_size: 16
-max_req_tokens: 512
-max_res_tokens: 512
+group_size: 6
+batch_size: 12
+max_req_tokens: 64
+max_res_tokens: 64
 model: "Qwen/Qwen2.5-0.5B-Instruct"
 
 # Dataset configuration
@@ -60,3 +60,7 @@ services:
     procs_per_replica: 1
     num_replicas: 1
     with_gpus: false
+  ref_model:
+    procs_per_replica: 1
+    num_replicas: 1
+    with_gpus: true
diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index 9f0595d96..c2dedd530 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -5,9 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 
+from forge.util.ops import selective_log_softmax
+
 
 class ReinforceLoss(nn.Module):
     """Reinforce loss function with optional importance ratio clipping.
@@ -28,7 +29,7 @@ def __init__(self):
     def forward(
         self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
     ):
-        trainer_log_probs = self.selective_log_softmax(trainer_logits, target_ids)
+        trainer_log_probs = selective_log_softmax(trainer_logits, target_ids)
         target_mask = target_mask.detach()
         target_weights = target_weights
         target_mask_sum = target_mask.sum()
@@ -47,47 +48,3 @@ def forward(
         denominator = target_mask_sum
 
         return numerator / denominator
-
-    def selective_log_softmax(self, logits, index) -> torch.Tensor:
-        """
-        A memory-efficient implementation of the common `log_softmax -> gather` operation.
-
-        This function is equivalent to the following naive implementation:
-        ```python
-        logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
-        ```
-
-        Args:
-            logits (`torch.Tensor`):
-                Logits tensor of shape `(..., num_classes)`.
-            index (`torch.Tensor`):
-                Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
-
-        Returns:
-            `torch.Tensor`:
-                Gathered log probabilities with the same shape as `index`.
-        """
-        if logits.dtype in [torch.float32, torch.float64]:
-            selected_logits = torch.gather(
-                logits, dim=-1, index=index.unsqueeze(-1)
-            ).squeeze(-1)
-            # loop to reduce peak mem consumption
-            logsumexp_values = torch.stack(
-                [torch.logsumexp(lg, dim=-1) for lg in logits]
-            )
-            per_token_logps = (
-                selected_logits - logsumexp_values
-            )  # log_softmax(x_i) = x_i - logsumexp(x)
-        else:
-            # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
-            per_token_logps = []
-            for row_logits, row_labels in zip(
-                logits, index
-            ):  # loop to reduce peak mem consumption
-                row_logps = F.log_softmax(row_logits, dim=-1)
-                row_per_token_logps = row_logps.gather(
-                    dim=-1, index=row_labels.unsqueeze(-1)
-                ).squeeze(-1)
-                per_token_logps.append(row_per_token_logps)
-            per_token_logps = torch.stack(per_token_logps)
-        return per_token_logps
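The docstring of the helper being moved states the equivalence it relies on; as a quick sanity check, the float32 fast path is just the identity log_softmax(x)_i = x_i - logsumexp(x) applied at the gathered indices:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
index = torch.randint(0, 10, (4,))

# Naive path: materialize the full log_softmax, then gather one entry per row.
naive = F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

# Fused path: gather the raw logit and subtract the row's logsumexp.
fused = logits.gather(-1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(
    logits, dim=-1
)

assert torch.allclose(naive, fused, atol=1e-6)
```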
- """ - if logits.dtype in [torch.float32, torch.float64]: - selected_logits = torch.gather( - logits, dim=-1, index=index.unsqueeze(-1) - ).squeeze(-1) - # loop to reduce peak mem consumption - logsumexp_values = torch.stack( - [torch.logsumexp(lg, dim=-1) for lg in logits] - ) - per_token_logps = ( - selected_logits - logsumexp_values - ) # log_softmax(x_i) = x_i - logsumexp(x) - else: - # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach - per_token_logps = [] - for row_logits, row_labels in zip( - logits, index - ): # loop to reduce peak mem consumption - row_logps = F.log_softmax(row_logits, dim=-1) - row_per_token_logps = row_logps.gather( - dim=-1, index=row_labels.unsqueeze(-1) - ).squeeze(-1) - per_token_logps.append(row_per_token_logps) - per_token_logps = torch.stack(per_token_logps) - return per_token_logps diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py new file mode 100644 index 000000000..49044b33f --- /dev/null +++ b/src/forge/util/ops.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + + +def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor: + """ + A memory-efficient implementation of the common `log_softmax -> gather` operation. + + This function is equivalent to the following naive implementation: + ```python + logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + ``` + + Args: + logits (`torch.Tensor`): + Logits tensor of shape `(..., num_classes)`. + index (`torch.Tensor`): + Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. + + Returns: + `torch.Tensor`: + Gathered log probabilities with the same shape as `index`. + """ + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather( + logits, dim=-1, index=index.unsqueeze(-1) + ).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = ( + selected_logits - logsumexp_values + ) # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach + per_token_logps = [] + for row_logits, row_labels in zip( + logits, index + ): # loop to reduce peak mem consumption + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather( + dim=-1, index=row_labels.unsqueeze(-1) + ).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps diff --git a/tests/unit_tests/util/__init__.py b/tests/unit_tests/util/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/tests/unit_tests/util/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/unit_tests/util/test_ops.py b/tests/unit_tests/util/test_ops.py new file mode 100644 index 000000000..834de3199 --- /dev/null +++ b/tests/unit_tests/util/test_ops.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+import torch.nn.functional as F
+from forge.util.ops import selective_log_softmax
+
+
+class TestOps:
+    @pytest.mark.timeout(10)
+    def test_basic_2d(self):
+        """Test basic 2D case."""
+        logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+        index = torch.tensor([0, 2])  # Select positions 0 and 2
+        result = selective_log_softmax(logits, index)
+        # Compare with torch's implementation
+        expected = torch.gather(
+            F.log_softmax(logits, dim=-1), dim=-1, index=index.unsqueeze(-1)
+        ).squeeze(-1)
+        assert torch.allclose(result, expected, atol=1e-5)
+        assert result.shape == (2,)  # Same shape as index
+
+    @pytest.mark.timeout(10)
+    def test_single_row(self):
+        """Test with single row."""
+        logits = torch.tensor([[1.0, 2.0, 3.0]])
+        index = torch.tensor([1])  # Select middle element
+        result = selective_log_softmax(logits, index)
+        # Manual calculation: log_softmax then select index 1
+        log_probs = F.log_softmax(logits, dim=-1)
+        expected = log_probs[0, 1]
+        assert torch.allclose(result, expected)
+        assert result.shape == (1,)
+
+    @pytest.mark.timeout(10)
+    def test_different_dtypes(self):
+        """Test with different data types."""
+        logits_f32 = torch.randn(2, 4, dtype=torch.float32)
+        logits_bf16 = torch.randn(2, 4, dtype=torch.bfloat16)
+        index = torch.tensor([0, 3])
+        result_f32 = selective_log_softmax(logits_f32, index)
+        result_bf16 = selective_log_softmax(logits_bf16, index)
+        # Check output dtypes match input dtypes
+        assert result_f32.dtype == torch.float32
+        assert result_bf16.dtype == torch.bfloat16
+        # Check shapes
+        assert result_f32.shape == (2,)
+        assert result_bf16.shape == (2,)
+
+    @pytest.mark.timeout(10)
+    def test_3d_tensor(self):
+        """Test with 3D tensor."""
+        batch, seq, vocab = 2, 3, 5
+        logits = torch.randn(batch, seq, vocab)
+        index = torch.randint(0, vocab, (batch, seq))
+        result = selective_log_softmax(logits, index)
+        # Should have same shape as index
+        assert result.shape == (batch, seq)
+        # All values should be non-positive (log probabilities)
+        assert (result <= 0).all()
+
+    @pytest.mark.timeout(10)
+    def test_known_values(self):
+        """Test with known values for manual verification."""
+        # Simple case where we can calculate by hand
+        logits = torch.tensor([[0.0, 0.0]])  # Equal logits
+        index = torch.tensor([0])
+        result = selective_log_softmax(logits, index)
+        # log_softmax of [0, 0] gives [-log(2), -log(2)]
+        # Selecting index 0 should give -log(2)
+        expected = -torch.log(torch.tensor(2.0))
+        assert torch.allclose(result, expected, atol=1e-6)
+
+    @pytest.mark.timeout(10)
+    def test_edge_cases(self):
+        """Test edge cases."""
+        # Test with single class
+        logits = torch.tensor([[5.0]])
+        index = torch.tensor([0])
+        result = selective_log_softmax(logits, index)
+        # log_softmax of a single element is 0
+        assert torch.allclose(result, torch.tensor([0.0]))
+        # Test with large values (numerical stability)
+        logits = torch.tensor([[100.0, 200.0]])
+        index = torch.tensor([1])
+        result = selective_log_softmax(logits, index)
+        # Should not be NaN or inf
+        assert torch.isfinite(result).all()
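One usage detail the tests above do not exercise: in the app, `selective_log_softmax` consumes causal-LM logits that sit one position behind the tokens they predict, which is consistent with `RefModel.forward` padding both `input_ids` and `target_ids` to `max_seq_len - 1`. A minimal sketch of that alignment, with toy shapes and random logits standing in for a real model:

```python
import torch
from forge.util.ops import selective_log_softmax

vocab, seq = 16, 8
tokens = torch.randint(0, vocab, (1, seq))

# Stand-in for model(input_ids=tokens[:, :-1]).logits: one logit row
# per predicted position.
logits = torch.randn(1, seq - 1, vocab)

# Log-probability of each actual next token under the (fake) model.
logps = selective_log_softmax(logits, tokens[:, 1:])
assert logps.shape == (1, seq - 1)
```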