Commit 33ee2d6

update docs
1 parent 0fd5115 commit 33ee2d6

File tree: 1 file changed

docs/rlhf.qmd
Lines changed: 207 additions & 0 deletions
@@ -721,6 +721,213 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).

#### Async GRPO

Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.

```yaml
trl:
  use_data_producer: true    # Enable data producer protocol
  use_vllm: true
  async_prefetch: true       # Generate rollouts in background thread
  prefetch_depth: 1          # Number of rollouts to prefetch
  vllm_sync_interval: 2      # Sync weights to vLLM every N steps
```

::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::
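Conceptually, this is a bounded producer/consumer loop. The sketch below illustrates the pattern with hypothetical helper names (`generate_rollouts`, `train_step`); it is not axolotl's actual implementation:

```python
import queue
import threading

def async_prefetch_loop(generate_rollouts, train_step, num_steps, prefetch_depth=1):
    # Bounded queue: the producer can run at most `prefetch_depth` batches
    # ahead of training, which also bounds how stale its weights can be.
    rollouts = queue.Queue(maxsize=prefetch_depth)

    def producer():
        # Runs concurrently with train_step on the main thread, so the
        # weights used for generation may be slightly behind the optimizer.
        for step in range(num_steps):
            rollouts.put(generate_rollouts(step))

    thread = threading.Thread(target=producer, daemon=True)
    thread.start()

    losses = []
    for _ in range(num_steps):
        batch = rollouts.get()  # ready immediately unless generation is slower
        losses.append(train_step(batch))
    thread.join()
    return losses
```

`prefetch_depth` maps onto the queue bound: a deeper prefetch hides more generation latency but increases how stale the generating weights can be.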
##### vLLM LoRA Sync

By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true    # Enable native LoRA sync
```

When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```

Then start training on a separate GPU:

```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::
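A back-of-envelope count suggests why syncing only the adapter is so much cheaper than broadcasting the merged model. The numbers below are illustrative assumptions (square hidden-sized projections, a hypothetical 1.5B-class model with 28 layers), not measurements:

```python
def lora_param_count(hidden_size, rank, num_layers, adapted_linears=7):
    # Each adapted linear layer contributes two low-rank factors,
    # A (rank x in) and B (out x rank). We assume square hidden-sized
    # projections for a rough estimate, which overcounts some and
    # undercounts others.
    return num_layers * adapted_linears * 2 * hidden_size * rank

# r=32 as in the config above; hidden=1536 and 28 layers are assumptions.
adapter_params = lora_param_count(1536, 32, 28)
print(f"adapter params: {adapter_params:,}")  # tens of millions vs ~1.5B base weights
```

Even under these rough assumptions, the adapter is roughly two orders of magnitude smaller than the base model, so writing it to disk and letting vLLM load it natively avoids a full-parameter NCCL broadcast.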
##### Streaming Partial Batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.

```yaml
trl:
  streaming_partial_batch: true
```
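The per-group flow might look like the following sketch (hypothetical names and plain lists; the real trainer works on tokenized tensors):

```python
def stream_score_groups(groups, reward_fn):
    # Score one prompt group at a time. Groups whose rewards all equal the
    # group mean have zero variance, hence all-zero advantages, and can be
    # dropped before any forward pass.
    kept = []
    for group in groups:
        rewards = [reward_fn(c) for c in group]
        mean = sum(rewards) / len(rewards)
        if any(r != mean for r in rewards):  # non-zero variance => learning signal
            advantages = [r - mean for r in rewards]
            kept.append((group, advantages))
    return kept
```

Because each group is scored independently, peak memory is bounded by one group rather than the whole batch.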
##### Importance Sampling Correction

When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.

```yaml
trl:
  vllm_importance_sampling_correction: true  # Enable IS correction
  importance_sampling_level: token           # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5             # Mask sequences with IS ratio below this
```

- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
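The two levels differ only in where the ratio between the training policy and the (stale) behavior policy is formed. A minimal sketch with illustrative log-probabilities:

```python
import math

def is_ratios(new_logps, old_logps, level="token"):
    # new_logps / old_logps: per-token log-probs of one sampled sequence under
    # the current model and under the stale model that generated it.
    token_ratios = [math.exp(n - o) for n, o in zip(new_logps, old_logps)]
    if level == "token":
        return token_ratios  # one correction weight per token
    # sequence level: a single ratio from the summed log-probs
    return math.exp(sum(new_logps) - sum(old_logps))

def off_policy_mask(seq_ratio, threshold=0.5):
    # Sequences whose ratio falls below the threshold are too far
    # off-policy to trust and get masked out of the loss.
    return seq_ratio >= threshold
```

Token-level ratios weight each token's gradient contribution separately; the sequence-level ratio applies one weight to the whole completion.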
##### Replay Buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.

```yaml
trl:
  replay_buffer_size: 100        # Max cached groups (0 = disabled)
  replay_recompute_logps: true   # Recompute log-probs for replayed data (recommended)
```

::: {.callout-note}
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::
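A minimal in-memory version of this replace-on-zero-signal logic (hypothetical structure: real groups carry token tensors and log-probs, not bare reward lists):

```python
import random

def apply_replay(batch_groups, buffer, max_size=100):
    # batch_groups: list of (rewards, data) pairs for one batch.
    out = []
    for rewards, data in batch_groups:
        has_signal = len(set(rewards)) > 1  # non-zero reward variance
        if has_signal:
            out.append((rewards, data))
            if len(buffer) < max_size:
                buffer.append((rewards, data))  # cache for later reuse
        elif buffer:
            # Dead group: swap in a previously cached group with signal.
            # Its log-probs should be recomputed (replay_recompute_logps).
            out.append(random.choice(buffer))
    return out
```

Without recomputation, the cached log-probs would describe a model many steps old, producing exactly the IS mismatch the callout above warns about.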
##### Deferred Re-rolling

Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches, when the model may be better equipped to solve them.

```yaml
trl:
  reroll_start_fraction: 0.5   # Start re-rolling after 50% of training
  reroll_max_groups: 1         # Max groups to replace per batch
```
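The scheduling side of this can be sketched in a few lines (hypothetical helper; the real trainer operates on prompt groups, not strings):

```python
def maybe_reroll(batch_prompts, failed_prompts, step, max_steps,
                 start_fraction=0.5, max_groups=1):
    # Before start_fraction of training has elapsed, leave the batch alone;
    # the model likely still can't solve the buffered failures.
    if step < start_fraction * max_steps or not failed_prompts:
        return list(batch_prompts)
    batch = list(batch_prompts)
    for i in range(min(max_groups, len(failed_prompts), len(batch))):
        batch[i] = failed_prompts.pop(0)  # retry an old failure
    return batch
```

Oldest failures are retried first; `max_groups` caps how much of any one batch is given over to retries.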
##### Zero-Advantage Batch Skipping

When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.

```yaml
trl:
  skip_zero_advantage_batches: true  # default
```
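The check itself is trivial, which is why it is cheap to run before every micro-batch (a sketch, not axolotl's code):

```python
def should_skip_microbatch(advantages, eps=1e-8):
    # With every advantage (numerically) zero, the policy-gradient term of
    # the loss is zero, so the forward/backward pass would waste compute.
    return all(abs(a) < eps for a in advantages)
```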
##### Parallel Reward Workers

Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.

```yaml
trl:
  reward_num_workers: 4  # Number of subprocess workers (1 = no parallelism)
```
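The main-thread restriction comes from Python itself: signal handlers can only be installed in a process's main thread, and the trainer calls reward functions off the main thread. Spawning a subprocess gives each reward call a fresh main thread where `signal.alarm()` is legal. A simplified sketch (one process per call, assumes the default `fork` start method on Linux; the stand-in reward just measures length):

```python
import multiprocessing as mp
import signal

def _reward_worker(conn, completion, timeout):
    # Child process entry point: this IS the child's main thread,
    # so installing a SIGALRM handler is allowed here.
    def on_timeout(signum, frame):
        raise TimeoutError
    signal.signal(signal.SIGALRM, on_timeout)
    signal.alarm(timeout)
    try:
        conn.send(float(len(completion)))  # stand-in for a real reward fn
    except TimeoutError:
        conn.send(0.0)
    finally:
        signal.alarm(0)
        conn.close()

def score_parallel(completions, timeout=5):
    procs = []
    for c in completions:
        parent, child = mp.Pipe()
        p = mp.Process(target=_reward_worker, args=(child, c, timeout))
        p.start()
        procs.append((p, parent))
    results = []
    for p, parent in procs:  # collect in launch order
        results.append(parent.recv())
        p.join()
    return results
```

A production version would reuse a fixed pool of `reward_num_workers` processes rather than forking per call, but the main-thread reasoning is the same.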
##### Full Async GRPO Example

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.35
  dtype: auto

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

rl: grpo
trl:
  use_data_producer: true
  use_vllm: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
  vllm_lora_sync: true
  streaming_partial_batch: true
  vllm_importance_sampling_correction: true
  off_policy_mask_threshold: 0.5
  importance_sampling_level: token
  num_generations: 8
  max_completion_length: 512
  reward_funcs:
    - rewards.accuracy_reward
  reroll_start_fraction: 0.5
  replay_buffer_size: 100
  reward_num_workers: 4
  skip_zero_advantage_batches: true

datasets:
  - path: AI-MO/NuminaMath-TIR
    type: rewards.prompt_transform
    split: train

gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
##### Multi-GPU Async GRPO

Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.

**FSDP:**

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
```

**DeepSpeed ZeRO-3:**

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true  # Required for ZeRO-3
```

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```

::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::
### GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
