add RAFT #363
base: main
Changes from 2 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -200,6 +200,10 @@ def _validate(self): | |
| return test_metrics | ||
|
|
||
| def _train_step(self, batch_dict: dict) -> dict: | ||
| # Check if RAFT mode is enabled | ||
| if self.config.algorithm.adv_estimator == "raft": | ||
| return self._train_step_raft(batch_dict) | ||
|
|
||
| # Isolate in a separate method to automatically recycle the variables before validation. | ||
| batch: DataProto = DataProto.from_single_dict(batch_dict) | ||
| metrics = {} | ||
|
|
@@ -388,6 +392,251 @@ def _train_step(self, batch_dict: dict) -> dict: | |
|
|
||
| return metrics | ||
|
|
||
| def _train_step_raft(self, batch_dict: dict) -> dict: | ||
| """ | ||
| RAFT training step: Simplified training loop that only trains on r=1 samples. | ||
|
|
||
| RAFT (Reward rAnked FineTuning, implemented here as rejection-sampling fine-tuning) differs from GRPO/PPO by: | ||
| 1. Rejection sampling: Only keeping samples with reward r=1 | ||
| 2. Simple loss: Using standard cross-entropy (NLL) loss instead of advantage-weighted loss | ||
| 3. No critic: No value function estimation needed | ||
| 4. No advantage: No advantage function or GAE computation needed | ||
| """ | ||
| batch: DataProto = DataProto.from_single_dict(batch_dict) | ||
| metrics = {} | ||
| timing_raw = {} | ||
|
|
||
| with _timer("step", timing_raw): | ||
| # When agent mode is enabled, we read the batch as it is. | ||
| gen_batch = batch | ||
|
|
||
| # Generate rollouts and collect data | ||
| with _timer("gen", timing_raw): | ||
| self.async_rollout_manager.wake_up() | ||
| self.agent_mode_daemon.set_up_data_and_server( | ||
| gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses | ||
| ) | ||
| self.agent_mode_daemon.run_until_all_finished() | ||
| batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch( | ||
| max_prompt_length=self.config.data.max_prompt_length, | ||
| max_response_length=self.config.data.max_response_length, | ||
| device=gen_batch.batch["fake_ids"].device, | ||
| ) | ||
| metrics.update(agent_metrics) | ||
| self.agent_mode_daemon.clear_data_and_server() | ||
| self.async_rollout_manager.sleep() | ||
|
|
||
| # RAFT Step 1: Rejection Sampling - Filter to keep only r=1 samples | ||
| with _timer("rejection_sampling", timing_raw): | ||
| # Extract rewards from token_level_scores (sum to get sequence-level reward) | ||
| # The reward is stored at the last token position in token_level_scores | ||
| sequence_rewards = batch.batch["token_level_scores"].sum(dim=-1) # (batch_size,) | ||
|
|
||
| # Binary reward: 1.0 for success, 0.0 for failure | ||
| # In RAFT, we only keep samples with reward == 1.0 | ||
| is_positive_reward = (sequence_rewards == 1.0) | ||
| positive_indices = is_positive_reward.nonzero(as_tuple=True)[0] | ||
|
|
||
| # Log rejection sampling statistics | ||
| n_total = len(batch) | ||
| n_positive = len(positive_indices) | ||
| n_rejected = n_total - n_positive | ||
| metrics["raft/n_total_samples"] = n_total | ||
| metrics["raft/n_positive_samples"] = n_positive | ||
| metrics["raft/n_rejected_samples"] = n_rejected | ||
| metrics["raft/rejection_rate"] = n_rejected / n_total if n_total > 0 else 0.0 | ||
| metrics["raft/positive_rate"] = n_positive / n_total if n_total > 0 else 0.0 | ||
|
|
||
| # If no positive samples, skip this training step | ||
| if n_positive == 0: | ||
| metrics["raft/loss"] = 0.0 | ||
| metrics["raft/skipped_no_positive_samples"] = 1 | ||
| return metrics | ||
|
|
||
| # Filter batch to keep only positive samples | ||
| positive_batch = batch[positive_indices.cpu().tolist()] | ||
|
|
||
| # RAFT Step 2: Compute response mask for the filtered batch | ||
| positive_batch.batch["response_mask"] = compute_response_mask(positive_batch) | ||
|
|
||
| # Set uid (required by update_actor, similar to GRPO) | ||
| # uid is used by algorithms like GRPO and should be aligned to the data id | ||
| if "data_id_list" in positive_batch.non_tensor_batch: | ||
| positive_batch.non_tensor_batch["uid"] = positive_batch.non_tensor_batch["data_id_list"] | ||
|
|
||
| # Drop samples with prompts that are too long | ||
| keep_indices = (~positive_batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0] | ||
| metrics["raft/n_triplets_prompt_too_long"] = ( | ||
| positive_batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0] | ||
| ) | ||
| if len(keep_indices) == 0: | ||
| metrics["raft/loss"] = 0.0 | ||
| metrics["raft/skipped_all_dropped"] = 1 | ||
| return metrics | ||
| positive_batch = positive_batch[keep_indices] | ||
|
|
||
| # Round down to a multiple of the mini-batch size for efficient training | ||
| mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size | ||
| n_transition = len(positive_batch) | ||
| random_indices = list(range(n_transition)) | ||
| random.shuffle(random_indices) | ||
| positive_batch.reorder(torch.tensor(random_indices).type(torch.int32)) | ||
| n_remained_transition = n_transition // mini_batch_size * mini_batch_size | ||
| positive_batch = positive_batch[list(range(n_remained_transition))] | ||
| metrics["raft/n_triplets_dropped_remainder"] = n_transition - n_remained_transition | ||
|
|
||
| # Balance batch if enabled | ||
| if self.config.trainer.balance_batch: | ||
| self._balance_batch(positive_batch, metrics=metrics) | ||
|
|
||
| # Pad batch for distributed training | ||
| positive_batch, pad_size = pad_dataproto_to_divisor(positive_batch, self.actor_rollout_wg.world_size) | ||
|
|
||
| # RAFT Step 3: Prepare batch for RAFT loss computation | ||
| # Remove advantage-related fields since RAFT doesn't use them | ||
| raft_batch = positive_batch | ||
| max_response_length = raft_batch.batch["responses"].shape[-1] | ||
|
|
||
| # Unpad before computing loss | ||
| raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size) | ||
|
|
||
| # RAFT Step 4: Prepare batch for actor update | ||
| # Need to compute old_log_probs and set required meta_info fields | ||
| with _timer("prepare_raft_batch", timing_raw): | ||
| # Ensure uid is set (may have been lost during filtering) | ||
| if "data_id_list" in raft_batch.non_tensor_batch: | ||
| raft_batch.non_tensor_batch["uid"] = raft_batch.non_tensor_batch["data_id_list"] | ||
|
|
||
| # Compute global_token_num (required by update_actor) | ||
| raft_batch.meta_info["global_token_num"] = torch.sum(raft_batch.batch["attention_mask"], dim=-1).tolist() | ||
|
|
||
| # Pad batch for distributed training before computing log_probs | ||
| raft_batch, pad_size_prep = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size) | ||
|
|
||
| # Compute old_log_probs (required by update_actor, similar to GRPO) | ||
| # This is needed even for RAFT because update_actor expects this field | ||
| old_log_prob = self.actor_rollout_wg.compute_log_prob(raft_batch) | ||
| entropys = old_log_prob.batch["entropys"] | ||
| response_masks = raft_batch.batch["response_mask"] | ||
| loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode | ||
| entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) | ||
| old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} | ||
| metrics.update(old_log_prob_metrics) | ||
| old_log_prob.batch.pop("entropys") | ||
| raft_batch = raft_batch.union(old_log_prob) | ||
|
|
||
| # Set required meta_info fields (similar to GRPO) | ||
| raft_batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable | ||
| # Temperature is required by update_actor (from config or default 0.7) | ||
| raft_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.get("temperature", 0.7) | ||
|
|
||
| # Unpad before setting advantages | ||
| raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size_prep) | ||
|
|
||
| # RAFT Step 5: Pure SFT update | ||
| # Use standard cross-entropy loss like TRL's SFTTrainer | ||
| # Reference: trl/trainer/sft_trainer.py compute_loss method | ||
| with _timer("update_actor_sft", timing_raw): | ||
| # Prepare inputs for SFT loss computation (like SFTTrainer) | ||
| # SFTTrainer expects: input_ids, attention_mask, labels | ||
| input_ids = raft_batch.batch["input_ids"] # (batch_size, seq_len) | ||
| attention_mask = raft_batch.batch["attention_mask"] # (batch_size, seq_len) | ||
|
|
||
| # Create labels: -100 for prompt tokens (ignore in loss), actual token IDs for response | ||
| # This matches SFTTrainer's label format | ||
| labels = input_ids.clone() | ||
| # Shift labels for next-token prediction: labels[t] = input_ids[t+1], i.e. predict the next token | ||
| labels[:, :-1] = input_ids[:, 1:].clone() | ||
| labels[:, -1] = -100 # Last token has no next token | ||
|
|
||
| # Mask prompt tokens with -100 (they will be ignored in loss) | ||
| # Only compute loss on response tokens | ||
| prompt_length = input_ids.shape[-1] - max_response_length | ||
| labels[:, :prompt_length] = -100 | ||
|
|
||
| # Also mask padding tokens | ||
| labels = labels.masked_fill(~attention_mask.bool(), -100) | ||
|
|
||
| # Set advantages to 1.0 (no advantage weighting, pure SFT) | ||
| # This makes the PPO loss equivalent to standard cross-entropy | ||
| raft_batch.batch["advantages"] = torch.ones( | ||
| (len(raft_batch), max_response_length), | ||
| device=input_ids.device, | ||
| dtype=torch.float32 | ||
| ) | ||
| raft_batch.batch["returns"] = raft_batch.batch["advantages"].clone() | ||
|
|
||
| # Store labels in batch for potential use | ||
| raft_batch.batch["labels"] = labels | ||
|
||
|
|
||
| # Remove any existing values field (no critic in RAFT) | ||
| if "values" in raft_batch.batch: | ||
| raft_batch.batch.pop("values") | ||
|
|
||
| # Pad again for distributed training before update_actor | ||
| raft_batch, pad_size_actor = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size) | ||
|
|
||
| # Temporarily disable PPO clipping for pure SFT (like SFTTrainer) | ||
| # Setting clip_ratio to 1.0 effectively disables clipping | ||
| original_clip_low = self.config.actor_rollout_ref.actor.get("clip_ratio_low", 0.2) | ||
| original_clip_high = self.config.actor_rollout_ref.actor.get("clip_ratio_high", 0.3) | ||
|
|
||
| # Disable clipping: set both ratios to 1.0 (no clipping in pure SFT) | ||
|
||
| self.config.actor_rollout_ref.actor["clip_ratio_low"] = 1.0 | ||
| self.config.actor_rollout_ref.actor["clip_ratio_high"] = 1.0 | ||
|
|
||
| try: | ||
| # Update actor with pure SFT loss | ||
| # With advantages=1.0 and clip_ratio=1.0, this becomes standard cross-entropy | ||
| # This mimics SFTTrainer.compute_loss() behavior | ||
| actor_output = self.actor_rollout_wg.update_actor(raft_batch) | ||
|
The update_actor function being called here is still the RL one, not one from an "SFTTrainer"? Not sure verl even has an SFTTrainer...
Author
Verl does not have an SFTTrainer. SFTTrainer inherits from transformers.Trainer and requires the complete HuggingFace Trainer infrastructure, whereas VERL uses Ray distributed training and a custom worker group. Directly adopting SFTTrainer would disrupt VERL's existing architecture. Additionally, SFTTrainer and VERL use different data formats.

Makes sense. Then why can SFT be implemented this way? Will update_actor actually compute the SFT loss as we expect?
Author
Here is the derivation of SFT from PPO. The standard PPO objective (with clipping) is defined as:

$$L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,A_t,\ \mathrm{clip}\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)A_t\right)\right],\qquad r_t(\theta)=\frac{\pi_\theta(a_t\mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t\mid s_t)}$$

To adapt the PPO framework for Supervised Fine-Tuning (SFT) mode, we set the following neutral conditions: set the advantages $A_t = 1$ and widen the clip range so that clipping never activates. Therefore:

$$L(\theta) = \mathbb{E}_t\left[r_t(\theta)\right] = \mathbb{E}_t\left[\frac{\pi_\theta(a_t\mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t\mid s_t)}\right]$$

In SFT we directly optimize the current policy without relying on importance sampling. When the old policy is equal to (or very close to) the current policy, $r_t(\theta)\approx 1$ and $\nabla_\theta r_t(\theta)=\nabla_\theta\log\pi_\theta(a_t\mid s_t)$, so maximizing this objective is equivalent to minimizing the negative log-likelihood loss, which is what we want to minimize:

$$L_{\mathrm{NLL}}(\theta) = -\,\mathbb{E}_t\left[\log\pi_\theta(a_t\mid s_t)\right]$$

In summary: by setting $A_t = 1.0$ to remove advantage weighting, setting the clip ratios so that clipping is inactive, and computing old_log_prob from the current policy, update_actor reduces to a standard cross-entropy (SFT) update on the retained samples. |
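To make the equivalence concrete, here is a small self-contained check (an illustrative sketch with made-up tensor shapes, not code from this PR or from verl) showing that the unclipped PPO surrogate with $A=1$ and a detached old policy yields the same gradient as plain NLL:

```python
# Standalone PyTorch sketch (illustrative only): with advantages fixed to 1 and
# old_log_prob equal to the detached current log-probs, the unclipped PPO
# surrogate has exactly the same gradient as the NLL loss.
import torch

log_prob = torch.randn(4, 8, requires_grad=True)   # current-policy log-probs
old_log_prob = log_prob.detach()                   # "old" policy == current, no grad
advantages = torch.ones_like(log_prob)             # A = 1: no advantage weighting

ratio = torch.exp(log_prob - old_log_prob)         # importance ratio (== 1 here)
ppo_loss = -(ratio * advantages).mean()            # unclipped PPO surrogate
nll_loss = -log_prob.mean()                        # standard cross-entropy / NLL

g_ppo, = torch.autograd.grad(ppo_loss, log_prob, retain_graph=True)
g_nll, = torch.autograd.grad(nll_loss, log_prob)
assert torch.allclose(g_ppo, g_nll)                # identical gradients
```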
||
| actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) | ||
| metrics.update(actor_output_metrics) | ||
|
|
||
| # Extract and log the SFT loss | ||
| # Use actor loss from update_actor output (same as GRPO) | ||
| # Note: update_actor returns "actor/pg_loss" not "actor/loss" | ||
| if "actor/pg_loss" in actor_output_metrics: | ||
| metrics["raft/loss"] = actor_output_metrics["actor/pg_loss"] | ||
| elif "actor/loss" in actor_output_metrics: | ||
| metrics["raft/loss"] = actor_output_metrics["actor/loss"] | ||
| else: | ||
| # Fallback: use a default value if loss not found | ||
| metrics["raft/loss"] = 0.0 | ||
| finally: | ||
| # Restore original clipping ratios | ||
| self.config.actor_rollout_ref.actor["clip_ratio_low"] = original_clip_low | ||
| self.config.actor_rollout_ref.actor["clip_ratio_high"] = original_clip_high | ||
|
|
||
| # Log that we're using pure SFT update (like SFTTrainer) | ||
| metrics["raft/pure_sft_update"] = 1.0 | ||
|
|
||
| # Log rollout generations if enabled | ||
| rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) | ||
| if rollout_data_dir: | ||
| with _timer("dump_rollout_generations", timing_raw): | ||
| # Unpad for logging | ||
| log_batch = unpad_dataproto(raft_batch, pad_size_actor) | ||
| inputs = self.tokenizer.batch_decode(log_batch.batch["prompts"], skip_special_tokens=True) | ||
| outputs = self.tokenizer.batch_decode(log_batch.batch["responses"], skip_special_tokens=True) | ||
| # Get scores from the filtered batch | ||
| log_scores = log_batch.batch["token_level_scores"].sum(dim=-1).cpu().tolist() | ||
| self._dump_generations( | ||
| inputs=inputs, | ||
| outputs=outputs, | ||
| scores=log_scores, | ||
| reward_extra_infos_dict={}, | ||
| dump_path=rollout_data_dir, | ||
| ) | ||
|
|
||
| # Compute training metrics | ||
| # Note: We skip critic metrics for RAFT since there's no critic | ||
| metrics.update(compute_timing_metrics(batch=raft_batch, timing_raw=timing_raw)) | ||
| n_gpus = self.resource_pool_manager.get_n_gpus() | ||
| metrics.update(compute_throughout_metrics(batch=raft_batch, timing_raw=timing_raw, n_gpus=n_gpus)) | ||
|
|
||
| return metrics | ||
|
|
||
| def fit(self): | ||
| logger = Tracking( | ||
| project_name=self.config.trainer.project_name, | ||
|
|
||
Why are three padding operations needed (here and at L514, L577)?
The first padding operation was a legacy from the previous code and has now been removed. The latter two padding operations are necessary:
compute_log_prob requires distributed training, and the batch size must be divisible by the world_size.
update_actor requires distributed training, and the batch size must be divisible by the world_size.
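For reference, a minimal sketch of the pad-to-divisor arithmetic both calls rely on (a hypothetical helper for illustration; verl's pad_dataproto_to_divisor / unpad_dataproto do the real work on DataProto objects): the batch is padded up to the next multiple of world_size so every worker receives the same number of samples, then unpadded afterwards.

```python
# Hypothetical helper illustrating why padding is needed before distributed
# compute_log_prob / update_actor: the batch length must split evenly across workers.
def pad_size_for(batch_len: int, world_size: int) -> int:
    """Number of dummy samples needed so batch_len is divisible by world_size."""
    return (-batch_len) % world_size

assert pad_size_for(10, 4) == 2   # 10 + 2 = 12 samples, 3 per worker on 4 workers
assert pad_size_for(12, 4) == 0   # already divisible, nothing to pad
```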