Changes from 2 commits
agentlightning/verl/trainer.py (249 additions, 0 deletions)
@@ -200,6 +200,10 @@ def _validate(self):
return test_metrics

def _train_step(self, batch_dict: dict) -> dict:
# Check if RAFT mode is enabled
if self.config.algorithm.adv_estimator == "raft":
return self._train_step_raft(batch_dict)

# Isolate in a separate method to automatically recycle the variables before validation.
batch: DataProto = DataProto.from_single_dict(batch_dict)
metrics = {}
@@ -388,6 +392,251 @@ def _train_step(self, batch_dict: dict) -> dict:

return metrics

def _train_step_raft(self, batch_dict: dict) -> dict:
"""
RAFT training step: Simplified training loop that only trains on r=1 samples.

RAFT (Reward rAnked FineTuning, i.e. rejection-sampling fine-tuning) differs from GRPO/PPO in that it uses:
1. Rejection sampling: only samples with reward r=1 are kept
2. A simple loss: standard cross-entropy (NLL) instead of an advantage-weighted loss
3. No critic: no value function estimation is needed
4. No advantage: no advantage function or GAE computation is needed
"""
batch: DataProto = DataProto.from_single_dict(batch_dict)
metrics = {}
timing_raw = {}

with _timer("step", timing_raw):
# When agent mode is enabled, we read the batch as it is.
gen_batch = batch

# Generate rollouts and collect data
with _timer("gen", timing_raw):
self.async_rollout_manager.wake_up()
self.agent_mode_daemon.set_up_data_and_server(
gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
)
self.agent_mode_daemon.run_until_all_finished()
batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
max_prompt_length=self.config.data.max_prompt_length,
max_response_length=self.config.data.max_response_length,
device=gen_batch.batch["fake_ids"].device,
)
metrics.update(agent_metrics)
self.agent_mode_daemon.clear_data_and_server()
self.async_rollout_manager.sleep()

# RAFT Step 1: Rejection Sampling - Filter to keep only r=1 samples
with _timer("rejection_sampling", timing_raw):
# Extract rewards from token_level_scores (sum to get sequence-level reward)
# The reward is stored at the last token position in token_level_scores
sequence_rewards = batch.batch["token_level_scores"].sum(dim=-1) # (batch_size,)

# Binary reward: 1.0 for success, 0.0 for failure
# In RAFT, we only keep samples with reward == 1.0
is_positive_reward = (sequence_rewards == 1.0)
positive_indices = is_positive_reward.nonzero(as_tuple=True)[0]

# Log rejection sampling statistics
n_total = len(batch)
n_positive = len(positive_indices)
n_rejected = n_total - n_positive
metrics["raft/n_total_samples"] = n_total
metrics["raft/n_positive_samples"] = n_positive
metrics["raft/n_rejected_samples"] = n_rejected
metrics["raft/rejection_rate"] = n_rejected / n_total if n_total > 0 else 0.0
metrics["raft/positive_rate"] = n_positive / n_total if n_total > 0 else 0.0

# If no positive samples, skip this training step
if n_positive == 0:
metrics["raft/loss"] = 0.0
metrics["raft/skipped_no_positive_samples"] = 1
return metrics

# Filter batch to keep only positive samples
positive_batch = batch[positive_indices.cpu().tolist()]

# RAFT Step 2: Compute response mask for the filtered batch
positive_batch.batch["response_mask"] = compute_response_mask(positive_batch)

# Set uid (required by update_actor, similar to GRPO)
# uid is used by algorithms like GRPO and should be aligned with the data id
if "data_id_list" in positive_batch.non_tensor_batch:
positive_batch.non_tensor_batch["uid"] = positive_batch.non_tensor_batch["data_id_list"]

# Drop samples with prompts that are too long
keep_indices = (~positive_batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0]
metrics["raft/n_triplets_prompt_too_long"] = (
positive_batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0]
)
if len(keep_indices) == 0:
metrics["raft/loss"] = 0.0
metrics["raft/skipped_all_dropped"] = 1
return metrics
positive_batch = positive_batch[keep_indices]

# Round to mini batch size for efficient training
mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
n_transition = len(positive_batch)
random_indices = list(range(n_transition))
random.shuffle(random_indices)
positive_batch.reorder(torch.tensor(random_indices).type(torch.int32))
n_remained_transition = n_transition // mini_batch_size * mini_batch_size
positive_batch = positive_batch[list(range(n_remained_transition))]
metrics["raft/n_triplets_dropped_remainder"] = n_transition - n_remained_transition

# Balance batch if enabled
if self.config.trainer.balance_batch:
self._balance_batch(positive_batch, metrics=metrics)

# Pad batch for distributed training
positive_batch, pad_size = pad_dataproto_to_divisor(positive_batch, self.actor_rollout_wg.world_size)
@XufangLuo (Dec 6, 2025): Why are three paddings needed (here and at L514, L577)?

Author: The first padding operation was a legacy of the previous code and has now been removed. The latter two padding operations are necessary:

compute_log_prob runs as a distributed operation, so the batch size must be divisible by the world_size.

update_actor runs as a distributed operation, so the batch size must be divisible by the world_size.

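For context, here is a minimal sketch (not part of the diff; names and shapes are illustrative) of the divisor-padding arithmetic that the remaining pad_dataproto_to_divisor / unpad_dataproto calls rely on:

```python
# Illustrative only: pad a batch so its size becomes divisible by the distributed
# world size, which is what pad_dataproto_to_divisor does for DataProto objects.
def divisor_pad_size(batch_size: int, world_size: int) -> int:
    """Number of extra rows to append so that batch_size becomes divisible."""
    return (-batch_size) % world_size

assert divisor_pad_size(batch_size=10, world_size=4) == 2  # 10 + 2 = 12, divisible by 4
assert divisor_pad_size(batch_size=12, world_size=4) == 0  # already divisible

# After compute_log_prob / update_actor, the padded rows are dropped again via
# unpad_dataproto(..., pad_size=...) so later steps see the original batch size.
```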

# RAFT Step 3: Prepare batch for RAFT loss computation
# Remove advantage-related fields since RAFT doesn't use them
raft_batch = positive_batch
max_response_length = raft_batch.batch["responses"].shape[-1]

# Unpad before computing loss
raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size)

# RAFT Step 4: Prepare batch for actor update
# Need to compute old_log_probs and set required meta_info fields
with _timer("prepare_raft_batch", timing_raw):
# Ensure uid is set (may have been lost during filtering)
if "data_id_list" in raft_batch.non_tensor_batch:
raft_batch.non_tensor_batch["uid"] = raft_batch.non_tensor_batch["data_id_list"]

# Compute global_token_num (required by update_actor)
raft_batch.meta_info["global_token_num"] = torch.sum(raft_batch.batch["attention_mask"], dim=-1).tolist()

# Pad batch for distributed training before computing log_probs
raft_batch, pad_size_prep = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size)

# Compute old_log_probs (required by update_actor, similar to GRPO)
# This is needed even for RAFT because update_actor expects this field
old_log_prob = self.actor_rollout_wg.compute_log_prob(raft_batch)
entropys = old_log_prob.batch["entropys"]
response_masks = raft_batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
raft_batch = raft_batch.union(old_log_prob)

# Set required meta_info fields (similar to GRPO)
raft_batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
# Temperature is required by update_actor (from config or default 0.7)
raft_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.get("temperature", 0.7)

# Unpad before setting advantages
raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size_prep)

# RAFT Step 5: Pure SFT update
# Use standard cross-entropy loss like TRL's SFTTrainer
# Reference: trl/trainer/sft_trainer.py compute_loss method
with _timer("update_actor_sft", timing_raw):
# Prepare inputs for SFT loss computation (like SFTTrainer)
# SFTTrainer expects: input_ids, attention_mask, labels
input_ids = raft_batch.batch["input_ids"] # (batch_size, seq_len)
attention_mask = raft_batch.batch["attention_mask"] # (batch_size, seq_len)

# Create labels: -100 for prompt tokens (ignore in loss), actual token IDs for response
# This matches SFTTrainer's label format
labels = input_ids.clone()
# Shift labels for next-token prediction: predict token[t] given tokens[<t]
labels[:, :-1] = input_ids[:, 1:].clone()
labels[:, -1] = -100 # Last token has no next token

# Mask prompt tokens with -100 (they will be ignored in loss)
# Only compute loss on response tokens
prompt_length = input_ids.shape[-1] - max_response_length
labels[:, :prompt_length] = -100

# Also mask padding tokens
labels = labels.masked_fill(~attention_mask.bool(), -100)

# Set advantages to 1.0 (no advantage weighting, pure SFT)
# This makes the PPO loss equivalent to standard cross-entropy
raft_batch.batch["advantages"] = torch.ones(
(len(raft_batch), max_response_length),
device=input_ids.device,
dtype=torch.float32
)
raft_batch.batch["returns"] = raft_batch.batch["advantages"].clone()

# Store labels in batch for potential use
raft_batch.batch["labels"] = labels

Reviewer: Even though labels are specified here, will they be used by the following update_actor method, and will they be used as expected?

Author: These labels are not used, but the update_actor method requires these fields. Therefore, I have retained the original GRPO processing logic to ensure the data format meets the requirements.

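For reference, a small sketch in plain PyTorch (not verl's code path; update_actor in this diff never reads raft_batch.batch["labels"]) of how -100 labels are normally consumed by a cross-entropy loss:

```python
import torch
import torch.nn.functional as F

vocab_size = 32
logits = torch.randn(2, 6, vocab_size)        # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 6))
labels[:, :3] = -100                          # prompt / padding positions to ignore

# F.cross_entropy skips positions whose label equals ignore_index (default -100),
# which is why SFT-style trainers mask prompts and padding with -100.
loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1), ignore_index=-100)
print(loss)
```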

# Remove any existing values field (no critic in RAFT)
if "values" in raft_batch.batch:
raft_batch.batch.pop("values")

# Pad again for distributed training before update_actor
raft_batch, pad_size_actor = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size)

# Temporarily disable PPO clipping for pure SFT (like SFTTrainer)
# Setting clip_ratio to 1.0 effectively disables clipping
original_clip_low = self.config.actor_rollout_ref.actor.get("clip_ratio_low", 0.2)
original_clip_high = self.config.actor_rollout_ref.actor.get("clip_ratio_high", 0.3)

# Disable clipping: set both ratios to 1.0 (no clipping in pure SFT)

Reviewer: Why can clipping be disabled by setting them to 1.0? This might not be related to RAFT.

Author: Here is the derivation of SFT from PPO.

The standard PPO objective (with clipping) is:
L = -E[ min( r·A, clip(r, 1-ε, 1+ε)·A ) ]

To adapt the PPO framework for supervised fine-tuning (SFT), we set the following neutral conditions:

  • Set the advantages to A = 1.
  • Disable clipping with ε = 1.0: clip(r, 1-1, 1+1) = clip(r, 0, 2) ≈ r, where r = exp(log_prob - old_log_prob), which usually lies in [0, 2].

With A = 1 and clipping disabled, the loss simplifies to:
L = -E[ min( r, clip(r, 1-ε, 1+ε) ) ] = -E[r],  where  log r = log π(a|s) - log π_old(a|s)

Therefore:
L = -E[ exp( log π(a|s) - log π_old(a|s) ) ]

In SFT, we directly optimize the current policy without relying on importance sampling. When the old policy is equal to (or very close to) the current policy, the PPO objective is replaced by the negative log-likelihood, which is what we want to minimize:
L = -E[ log π(a|s) ]
For language models, this is the standard cross-entropy loss.

In summary:

  • Setting A = 1.0 removes the advantage weighting.
  • Disabling clipping removes the PPO clipping mechanism.

The loss then degenerates to the standard cross-entropy (negative log-likelihood) loss, i.e. the SFT loss, so under these conditions the PPO framework becomes equivalent to standard supervised learning.

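A small self-contained check (an illustrative sketch, not verl code) that with A = 1, the clip inactive, and the old policy taken as a detached copy of the current policy, the surrogate -E[exp(log π - log π_old)] has exactly the same gradient as the NLL loss:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 5, requires_grad=True)   # toy "policy" over 5 actions
action = torch.tensor([2])

log_pi = torch.log_softmax(logits, dim=-1)[0, action]

# PPO-style surrogate with A = 1 and old policy == current policy (detached):
ratio = torch.exp(log_pi - log_pi.detach())      # value is exactly 1.0
surrogate = -ratio.mean()
surrogate.backward(retain_graph=True)            # keep the graph for the second backward
grad_surrogate = logits.grad.clone()

# Plain NLL / cross-entropy on the same action:
logits.grad = None
nll = -log_pi.mean()
nll.backward()
grad_nll = logits.grad.clone()

print(torch.allclose(grad_surrogate, grad_nll))  # True: identical gradients
```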
Reviewer: Here are some things I am not very sure about.

  1. Clipping is applied to the importance sampling ratio, i.e. the ratio between prob and old_prob, so its numerical range is [0, +∞). Setting high and low to 1 leads to clip(r, 0, 2), which may not actually disable clipping.
  2. If you want to use A to implement the RAFT/SFT loss, what are the "labels" used for?

Author: Thank you for pointing out these issues. I have pushed new modifications.

  1. To simplify the formula, I previously assumed π ≈ π_old and set clip_ratio_low and clip_ratio_high to the same value. A more rigorous way to disable clipping is to set clip_ratio_high to an extremely large number.

  2. Initially, I intended to calculate the loss following the SFTTrainer implementation. However, I later changed the calculation method, so variables from both approaches coexisted. The "labels" were redundant code and have now been removed.

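A tiny sketch (illustrative only; the real clipping lives inside verl's policy-loss computation) of the point discussed above: with clip_ratio_low = clip_ratio_high = 1.0 the ratio is still clamped to [0, 2], whereas a very large clip_ratio_high leaves it effectively unclipped:

```python
import torch

def clipped_ratio(r: torch.Tensor, clip_low: float, clip_high: float) -> torch.Tensor:
    # PPO clamps the importance ratio to [1 - clip_low, 1 + clip_high].
    return torch.clamp(r, 1.0 - clip_low, 1.0 + clip_high)

r = torch.tensor([0.1, 0.9, 1.0, 1.5, 5.0])
print(clipped_ratio(r, clip_low=1.0, clip_high=1.0))    # [0, 2]: 5.0 is clipped to 2.0
print(clipped_ratio(r, clip_low=1.0, clip_high=1.0e8))  # [0, ~inf]: unchanged
```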
self.config.actor_rollout_ref.actor["clip_ratio_low"] = 1.0
self.config.actor_rollout_ref.actor["clip_ratio_high"] = 1.0

try:
# Update actor with pure SFT loss
# With advantages=1.0 and clip_ratio=1.0, this becomes standard cross-entropy
# This mimics SFTTrainer.compute_loss() behavior
actor_output = self.actor_rollout_wg.update_actor(raft_batch)

Reviewer: The update_actor function being called is still the RL one, not the one in "SFTTrainer"? I'm not sure verl has an SFTTrainer...

Author: verl does not have an SFTTrainer. SFTTrainer inherits from transformers.Trainer and requires the complete HuggingFace Trainer infrastructure, whereas verl uses Ray distributed training and a custom worker group. Directly adopting SFTTrainer would disrupt verl's existing architecture. Additionally, SFTTrainer and verl use different data formats.

Reviewer: Makes sense. Then why can SFT be implemented in this way? Will update_actor actually compute the SFT loss as we expect?

Author: The derivation is the same as in the thread above: with the advantages fixed at A = 1 and clipping disabled, the clipped PPO objective reduces to -E[r] = -E[exp(log π(a|s) - log π_old(a|s))], and when the old policy is equal to (or very close to) the current policy this becomes the negative log-likelihood -E[log π(a|s)], i.e. the standard cross-entropy loss. With those neutral inputs, update_actor therefore computes the SFT loss.

actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)

# Extract and log the SFT loss
# Use actor loss from update_actor output (same as GRPO)
# Note: update_actor returns "actor/pg_loss" not "actor/loss"
if "actor/pg_loss" in actor_output_metrics:
metrics["raft/loss"] = actor_output_metrics["actor/pg_loss"]
elif "actor/loss" in actor_output_metrics:
metrics["raft/loss"] = actor_output_metrics["actor/loss"]
else:
# Fallback: use a default value if loss not found
metrics["raft/loss"] = 0.0
finally:
# Restore original clipping ratios
self.config.actor_rollout_ref.actor["clip_ratio_low"] = original_clip_low
self.config.actor_rollout_ref.actor["clip_ratio_high"] = original_clip_high

# Log that we're using pure SFT update (like SFTTrainer)
metrics["raft/pure_sft_update"] = 1.0

# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
with _timer("dump_rollout_generations", timing_raw):
# Unpad for logging
log_batch = unpad_dataproto(raft_batch, pad_size_actor)
inputs = self.tokenizer.batch_decode(log_batch.batch["prompts"], skip_special_tokens=True)
outputs = self.tokenizer.batch_decode(log_batch.batch["responses"], skip_special_tokens=True)
# Get scores from the filtered batch
log_scores = log_batch.batch["token_level_scores"].sum(dim=-1).cpu().tolist()
self._dump_generations(
inputs=inputs,
outputs=outputs,
scores=log_scores,
reward_extra_infos_dict={},
dump_path=rollout_data_dir,
)

# Compute training metrics
# Note: We skip critic metrics for RAFT since there's no critic
metrics.update(compute_timing_metrics(batch=raft_batch, timing_raw=timing_raw))
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=raft_batch, timing_raw=timing_raw, n_gpus=n_gpus))

return metrics

def fit(self):
logger = Tracking(
project_name=self.config.trainer.project_name,