For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).

#### Async GRPO

Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.

```yaml
trl:
  use_data_producer: true   # Enable data producer protocol
  use_vllm: true
  async_prefetch: true      # Generate rollouts in background thread
  prefetch_depth: 1         # Number of rollouts to prefetch
  vllm_sync_interval: 2     # Sync weights to vLLM every N steps
```

::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (the default when async is enabled).
:::

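The producer/consumer overlap can be pictured as a bounded queue filled by a background thread. This is an illustrative Python sketch, not axolotl's implementation; `generate_rollouts` is a stand-in for vLLM generation, and the queue size plays the role of `prefetch_depth`:

```python
import queue
import threading

def generate_rollouts(step):
    # Stand-in for vLLM generation; returns a batch of completions.
    return [f"rollout-{step}-{i}" for i in range(4)]

def producer(q, num_steps):
    # Background thread: keeps the queue filled up to prefetch_depth.
    for step in range(num_steps):
        q.put(generate_rollouts(step))  # blocks when the queue is full

num_steps = 3
prefetch_depth = 1
q = queue.Queue(maxsize=prefetch_depth)
threading.Thread(target=producer, args=(q, num_steps), daemon=True).start()

trained = []
for _ in range(num_steps):
    batch = q.get()             # next batch is usually already waiting
    trained.append(len(batch))  # stand-in for the training step

print(trained)  # [4, 4, 4]
```

While `q.get()` returns the current batch, the producer is already generating the next one, which is where the wall-clock savings come from.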
##### vLLM LoRA Sync

By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true   # Enable native LoRA sync
```

When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual.

::: {.callout-note}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::

##### Streaming Partial Batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.

```yaml
trl:
  streaming_partial_batch: true
```

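A rough sketch of why group-at-a-time scoring helps: zero-signal groups can be dropped as they are scored, before they ever occupy space in a full-batch tensor. The reward values and `advantages` helper below are hypothetical illustrations, not axolotl's code:

```python
def advantages(rewards):
    # Group-relative advantages, as in GRPO: reward minus the group mean.
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

# Hypothetical per-prompt reward groups (one group per prompt).
groups = [[1.0, 0.0, 1.0, 0.0],   # has variance -> learning signal
          [0.0, 0.0, 0.0, 0.0]]   # all equal -> zero advantages, skippable

kept = []
for rewards in groups:            # streaming: score one group at a time
    adv = advantages(rewards)
    if any(a != 0.0 for a in adv):
        kept.append(adv)          # only groups with signal reach the optimizer

print(len(kept))  # 1
```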
##### Importance Sampling Correction

When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.

```yaml
trl:
  vllm_importance_sampling_correction: true   # Enable IS correction
  importance_sampling_level: token            # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5              # Mask sequences with IS ratio below this
```

- `importance_sampling_level: token` applies per-token IS ratios (recommended with the Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences whose IS ratio indicates they are too far off-policy

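The correction can be sketched numerically. This is an illustrative computation, not TRL's exact implementation; in particular, computing the mask from a sequence-level (geometric-mean) ratio is an assumption made here for the sketch:

```python
import math

def is_correction(logp_new, logp_old, threshold=0.5):
    # Per-token importance ratios: pi_new(token) / pi_old(token).
    ratios = [math.exp(n - o) for n, o in zip(logp_new, logp_old)]
    # Sequence-level ratio (geometric mean of token ratios), used for masking.
    seq_ratio = math.exp(sum(n - o for n, o in zip(logp_new, logp_old))
                         / len(ratios))
    masked = seq_ratio < threshold  # too far off-policy -> drop from the loss
    return ratios, masked

# Hypothetical token log-probs: current model vs. the stale sampling model.
ratios, masked = is_correction([-1.0, -2.0], [-1.2, -2.1], threshold=0.5)
print(masked)  # False: close to on-policy, the sequence is kept
```

Per-token ratios weight each token's contribution to the gradient, while the threshold removes sequences whose sampling distribution has drifted too far from the current policy.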
##### Replay Buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.

```yaml
trl:
  replay_buffer_size: 100        # Max cached groups (0 = disabled)
  replay_recompute_logps: true   # Recompute log-probs for replayed data (recommended)
```

::: {.callout-note}
When `replay_recompute_logps: true` (the default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::

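A minimal sketch of the caching logic, assuming a hypothetical `ReplayBuffer` class (not axolotl's actual implementation):

```python
import collections

class ReplayBuffer:
    """Caches rollout groups that had reward variance (learning signal)."""

    def __init__(self, size):
        self.buf = collections.deque(maxlen=size)  # size 0 disables caching

    def maybe_add(self, group):
        rewards = group["rewards"]
        if max(rewards) != min(rewards):  # non-zero variance -> worth replaying
            self.buf.append(group)

    def replace_if_dead(self, group):
        # Swap a zero-signal group for a cached one, if available.
        rewards = group["rewards"]
        if max(rewards) == min(rewards) and self.buf:
            return self.buf.popleft()
        return group

rb = ReplayBuffer(size=100)
rb.maybe_add({"prompt": "p1", "rewards": [1.0, 0.0]})   # cached: has variance
dead = {"prompt": "p2", "rewards": [0.0, 0.0]}          # no learning signal
out = rb.replace_if_dead(dead)
print(out["prompt"])  # p1
```

Replayed groups carry stale log-probabilities, which is why recomputing them against the current weights (`replay_recompute_logps`) is recommended.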
##### Deferred Re-rolling

Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches, when the model may be better equipped to solve them.

```yaml
trl:
  reroll_start_fraction: 0.5   # Start re-rolling after 50% of training
  reroll_max_groups: 1         # Max groups to replace per batch
```

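The buffering logic can be sketched as follows; `make_batch`, the step arithmetic, and the prompt strings are hypothetical illustrations of how the two settings interact:

```python
# Hypothetical sketch of deferred re-rolling: failed prompts are buffered
# and re-injected once training passes reroll_start_fraction.
failed = []                    # prompts whose every generation scored zero
reroll_start_fraction = 0.5
reroll_max_groups = 1
total_steps = 100

def make_batch(step, batch):
    if step / total_steps >= reroll_start_fraction and failed:
        # Replace up to reroll_max_groups prompts with buffered failures.
        for i in range(min(reroll_max_groups, len(failed))):
            batch[i] = failed.pop(0)
    return batch

failed.append("hard-prompt")
early = make_batch(10, ["a", "b"])   # before the threshold: batch unchanged
late = make_batch(60, ["a", "b"])    # after: the hard prompt is re-injected
print(early, late)  # ['a', 'b'] ['hard-prompt', 'b']
```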
##### Zero-Advantage Batch Skipping

When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.

```yaml
trl:
  skip_zero_advantage_batches: true   # default
```

##### Parallel Reward Workers

Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.

```yaml
trl:
  reward_num_workers: 4   # Number of subprocess workers (1 = no parallelism)
```

::: {.callout-note}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::


### GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
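The difference from plain reward summation can be sketched numerically; the reward values and the `normalize` helper below are hypothetical illustrations, not axolotl's code:

```python
def normalize(xs):
    # Z-score within the group, guarding against zero variance.
    mean = sum(xs) / len(xs)
    std = (sum((x - mean) ** 2 for x in xs) / len(xs)) ** 0.5
    return [(x - mean) / std if std > 0 else 0.0 for x in xs]

# Two hypothetical reward functions over a group of 4 completions.
format_reward = [1.0, 1.0, 0.0, 0.0]      # small-scale signal
accuracy_reward = [10.0, 0.0, 10.0, 0.0]  # large-scale signal

# GRPO-style: sum first, then normalize -- the large-scale reward dominates.
combined = [f + a for f, a in zip(format_reward, accuracy_reward)]

# GDPO-style: normalize each reward independently, then combine, so each
# reward function contributes on the same scale.
decoupled = [sum(pair) for pair in zip(normalize(format_reward),
                                       normalize(accuracy_reward))]
print(decoupled)  # [2.0, 0.0, 0.0, -2.0]
```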