diff --git a/README.md b/README.md index 793b87cdf8..2a15afe76b 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob | *Multi-step agentic RL* | + [Concatenated multi-turn workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [General multi-step workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [ReAct workflow with an agent framework](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html)
+ [Example: train a web-search agent](https://github.com/modelscope/Trinity-RFT/tree/main/examples/agentscope_websearch) | | *Full-lifecycle data pipelines* | + [Rollout task mixing and selection](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [Online task curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) (📝 [paper](https://arxiv.org/pdf/2510.26374))
+ [Research project: learn-to-ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask) (📝 [paper](https://arxiv.org/pdf/2510.25441))
+ [Experience replay with prioritization](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [Advanced data processing & human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html) | | *Algorithm development* | + [RL algorithm development with Trinity-RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_mix_algo.html) (📝 [paper](https://arxiv.org/pdf/2508.11408))
+ [Research project: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) (📝 [paper](https://arxiv.org/abs/2509.24203))
+ Non-verifiable domains: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [trainable RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward) | -| *Going deeper into Trinity-RFT* | + [Full configurations](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [Benchmark toolkit for quick verification and experimentation](./benchmark/README.md)
+ [Understand the coordination between explorer and trainer](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | +| *Going deeper into Trinity-RFT* | + [Full configurations](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [Benchmark toolkit for quick verification and experimentation](./benchmark/README.md)
+ [GPU Resource and Training Configuration Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_gpu_configs.html)
+ [Understand the coordination between explorer and trainer](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | > [!NOTE] diff --git a/README_zh.md b/README_zh.md index 423e76e103..de5ed46344 100644 --- a/README_zh.md +++ b/README_zh.md @@ -48,7 +48,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: | *多轮智能体强化学习* | + [拼接多轮任务](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/example_multi_turn.html)
+ [通用多轮任务](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/example_step_wise.html)
+ [调用智能体框架中的 ReAct 工作流](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/example_react.html)
+ [例子:训练一个网络搜索智能体](https://github.com/modelscope/Trinity-RFT/tree/main/examples/agentscope_websearch) | | *全生命周期的数据流水线* | + [Rollout 任务混合与选取](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/develop_selector.html)
+ [在线任务选择](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) (📝 [论文](https://arxiv.org/pdf/2510.26374))
+ [研究项目:learn-to-ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask) (📝 [论文](https://arxiv.org/pdf/2510.25441))
+ [经验回放机制](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/example_data_functionalities.html) | | *强化学习算法开发* | + [使用 Trinity-RFT 进行 RL 算法开发](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/example_mix_algo.html) (📝 [论文](https://arxiv.org/pdf/2508.11408))
+ [研究项目: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) (📝 [论文](https://arxiv.org/abs/2509.24203))
+ 不可验证的领域: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [可训练 RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward) | -| *深入认识 Trinity-RFT* | + [完整配置指南](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html)
+ [用于快速验证和实验的 Benchmark 工具](./benchmark/README.md)
+ [理解 explorer-trainer 同步逻辑](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/synchronizer.html) | +| *深入认识 Trinity-RFT* | + [完整配置指南](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html)
+ [用于快速验证和实验的 Benchmark 工具](./benchmark/README.md)
+ [GPU 资源与训练配置对应指南](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_gpu_configs.html)
+ [理解 explorer-trainer 同步逻辑](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/synchronizer.html) | > [!NOTE] diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst index 21794b4138..815223f135 100644 --- a/docs/sphinx_doc/source/index.rst +++ b/docs/sphinx_doc/source/index.rst @@ -23,6 +23,7 @@ Welcome to Trinity-RFT's documentation! tutorial/develop_operator.md tutorial/develop_selector.md tutorial/trinity_configs.md + tutorial/trinity_gpu_configs.md tutorial/synchronizer.md diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md index 429a6d2c4c..adcbcc13fa 100644 --- a/docs/sphinx_doc/source/main.md +++ b/docs/sphinx_doc/source/main.md @@ -31,7 +31,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob | *Multi-step agentic RL* | + [Concatenated multi-turn workflow](/tutorial/example_multi_turn.md)
+ [General multi-step workflow](/tutorial/example_step_wise.md)
+ [ReAct workflow with an agent framework](/tutorial/example_react.md)
+ [Example: train a web-search agent](https://github.com/modelscope/Trinity-RFT/tree/main/examples/agentscope_websearch) | | *Full-lifecycle data pipelines* | + [Rollout task mixing and selection](/tutorial/develop_selector.md)
+ [Online task curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) (📝 [paper](https://arxiv.org/pdf/2510.26374))
+ [Research project: learn-to-ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask) (📝 [paper](https://arxiv.org/pdf/2510.25441))
+ [Experience replay with prioritization](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [Advanced data processing & human-in-the-loop](/tutorial/example_data_functionalities.md) | | *Algorithm development* | + [RL algorithm development with Trinity-RFT](/tutorial/example_mix_algo.md) (📝 [paper](https://arxiv.org/pdf/2508.11408))
+ [Research project: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) (📝 [paper](https://arxiv.org/abs/2509.24203))
+ Non-verifiable domains: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [trainable RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward) | -| *Going deeper into Trinity-RFT* | + [Full configurations](/tutorial/trinity_configs.md)
+ [Benchmark toolkit for quick verification and experimentation](https://github.com/modelscope/Trinity-RFT/tree/main/benchmark/README.md)
+ [Understand the coordination between explorer and trainer](/tutorial/synchronizer.md) | +| *Going deeper into Trinity-RFT* | + [Full configurations](/tutorial/trinity_configs.md)
+ [Benchmark toolkit for quick verification and experimentation](https://github.com/modelscope/Trinity-RFT/tree/main/benchmark/README.md)
+ [GPU Resource and Training Configuration Guide](/tutorial/trinity_gpu_configs.md)<br>
+ [Understand the coordination between explorer and trainer](/tutorial/synchronizer.md) |

diff --git a/docs/sphinx_doc/source/tutorial/develop_selector.md b/docs/sphinx_doc/source/tutorial/develop_selector.md
index c84f8f4267..d7ecdcb293 100644
--- a/docs/sphinx_doc/source/tutorial/develop_selector.md
+++ b/docs/sphinx_doc/source/tutorial/develop_selector.md
@@ -1,4 +1,4 @@
-# 🧪 Experimental: Task Selection & Scheduling System
+# 🧪 Experimental: Task Selection

```{note}
This module is currently in **experimental status**. Interfaces may change in future versions.

diff --git a/docs/sphinx_doc/source/tutorial/trinity_gpu_configs.md b/docs/sphinx_doc/source/tutorial/trinity_gpu_configs.md
new file mode 100644
index 0000000000..9ac94eb1fe
--- /dev/null
+++ b/docs/sphinx_doc/source/tutorial/trinity_gpu_configs.md
@@ -0,0 +1,280 @@
# GPU Configuration Guide

This document provides recommended training configurations for Qwen3-series models on **NVIDIA A100 80GB** and **H20 96GB** GPUs.
Based on model size (0.6B ~ 14B) and context length (`model.max_model_len`), we present feasible Trainer module setups across varying numbers of GPUs.

> ⚠️ **Note**:
> Because rollout and training are separated in Trinity, the GPU counts below refer to the number of GPUs available to the `Trainer`, not the total number of GPUs used by Trinity.

> 💡 **Terminology**
>
> - **vanilla**: No special configuration required; default settings suffice.
> - **Env**: Set the following environment variable **before launching training (before starting Ray)**:
>   ```bash
>   export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
>   ```
> - **Offload**: Enable **FSDP v2 + CPU Offload** to reduce GPU memory usage.
> - **SP=N**: Use **Sequence Parallelism** with parallelism degree N (typically N ≤ number of GPUs).
> - **Combined entries (e.g., `Env + SP=2`)**: All listed conditions must be satisfied simultaneously.
> - **“-”**: The combination of current hardware and configuration **cannot support training** for this model at the given sequence length.

---

## Long Context Support

Qwen3-series models natively support a maximum context length of **40,960 tokens**.
For training beyond this length (e.g., 51,200 or 81,920 tokens), we use the **YaRN RoPE extension**. The relevant configuration is as follows:

```yaml
model:
  model_path: ${oc.env:MODEL_PATH,Qwen/Qwen3-0.6B}
  max_prompt_tokens: 2048
  max_model_len: ${oc.env:MAX_MODEL_LEN,4096}
  rope_scaling:
    rope_type: yarn
    factor: ${oc.decode:${oc.env:FACTOR}} # Recommended value = MAX_MODEL_LEN / 40960
    original_max_position_embeddings: 40960
```

> ✅ When using YaRN, ensure `factor` is set reasonably to avoid numerical instability.
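To make the `factor` rule concrete: training at 81,920 tokens gives `factor = 81920 / 40960 = 2.0`. Below is a minimal launch sketch under that assumption; the config filename is illustrative, and it presumes the YAML above resolves `MAX_MODEL_LEN` and `FACTOR` from the environment as shown:

```bash
# Hypothetical long-context run at 2x the native window
export MAX_MODEL_LEN=81920
export FACTOR=2.0   # 81920 / 40960
trinity run --config my_long_context.yaml
```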
---

## 💡 Relationship Between GPU Memory Usage and `max_token_len_per_gpu`

Trinity Trainer enables dynamic batch sizing by default (`trainer.use_dynamic_bsz=True`).
With a fixed model, actual GPU memory consumption is primarily determined by the following two parameters:

- `trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu`
- `trainer.trainer_config.actor_rollout_ref.ref.log_prob_max_token_len_per_gpu`

If these parameters are not configured manually, Trinity automatically uses the following default:

```python
trainer.max_token_len_per_gpu = ceil(2 * model.max_model_len / trainer.ulysses_sequence_parallel_size)
```

For example, with `max_model_len = 32768` and SP = 2, each GPU processes at most `ceil(2 * 32768 / 2) = 32768` tokens per micro-batch.

📌 **This implies that**:
- The longer the context length, the more tokens each GPU must process, resulting in higher memory pressure.
- To support **longer context lengths**, you can manually lower the parameters above (though this may reduce training efficiency); see the sketch below.

> All experimental results presented in this guide are based on the aforementioned default settings. For extreme optimization, fine-tune these parameters according to your specific requirements.
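A minimal override sketch for the two parameters, assuming the nested `trainer_config` layout used by `scripts/context_length_test/context_length.yaml` later in this diff; the value `16384` is purely illustrative, not a recommendation:

```yaml
trainer:
  trainer_config:
    actor_rollout_ref:
      actor:
        ppo_max_token_len_per_gpu: 16384        # cap on tokens packed per GPU in the PPO update
      ref:
        log_prob_max_token_len_per_gpu: 16384   # same cap for reference-model log-prob computation
```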
---

## A100 80GB GPU Configuration Recommendations

> ⚠️ **Single-GPU Limitation**: Training models ≥4B or with context lengths >20K on a single A100 GPU places extreme pressure on VRAM. **Multi-GPU setups are strongly recommended**.

### 1 GPU

<details>
Click to view detailed configurations

| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
|------------------:|:--------------|:--------------|:--------------|:--------------|:--------------|
| 4096 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload |
| 8192 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload |
| 12288 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload |
| 16384 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload |
| 20480 | vanilla | Env + Offload | Env + Offload | Env + Offload | Env + Offload |
| 24576 | Env | Env + Offload | Env + Offload | Env + Offload | Env + Offload |
| 28672 | Env + Offload | Env + Offload | Env + Offload | - | - |
| 32768 | - | - | - | - | - |

</details>
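Many cells above call for Offload. A minimal sketch of enabling FSDP v2 + CPU Offload, mirroring the flags that `scripts/context_length_test/context_length.yaml` later in this diff wires up through environment variables (shown here hard-coded; this is one way to enable it, not the only one):

```yaml
trainer:
  trainer_strategy: fsdp2          # Offload results in this guide use FSDP v2
  trainer_config:
    actor_rollout_ref:
      actor:
        fsdp_config:
          param_offload: true      # move parameters to CPU between uses
          optimizer_offload: true  # keep optimizer state on CPU
          offload_policy: true
      ref:
        fsdp_config:
          param_offload: true
          optimizer_offload: true
          offload_policy: true
```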
+ +--- + +### 2 GPUs + +
✅ Recommended: 2 GPUs significantly improve long-context training capability for 4B~14B models. Enable SP=2 when using longer contexts. + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:---------------------|:---------------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | Env | Env + Offload | +| 8192 | vanilla | vanilla | vanilla | Env + Offload | Env + Offload | +| 12288 | vanilla | vanilla | vanilla | Env + Offload | Env + Offload | +| 16384 | vanilla | vanilla | Env | Env + Offload | Env + Offload | +| 20480 | vanilla | vanilla | SP=2 | Env + Offload | Env + Offload | +| 24576 | vanilla | Env | SP=2 | Env + Offload | Env + Offload | +| 28672 | Env | SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 32768 | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 36864 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 51200 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | - | +| 61440 | Env + Offload + SP=2 | - | - | - | - | +| 71680 | - | - | - | - | - | + +
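Entries such as `SP=2` map to the trainer's Ulysses sequence-parallel size. A minimal sketch, using the knob exposed by this repo's trainer config (see `scripts/context_length_test/context_length.yaml` later in this diff):

```yaml
trainer:
  ulysses_sequence_parallel_size: 2  # SP=2; pick a common divisor of the GPU count and the attention head count
```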
+ +--- + +### 4 GPUs + +
✅ Recommended: Ideal setup for training 8B/14B models with ultra-long contexts (>60K) + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:---------------------|:---------------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | Env | +| 8192 | vanilla | vanilla | vanilla | vanilla | Env + SP=2 | +| 12288 | vanilla | vanilla | vanilla | Env | Env + SP=4 | +| 16384 | vanilla | vanilla | vanilla | SP=2 | Env + SP=4 | +| 20480 | vanilla | vanilla | vanilla | SP=2 | Env + SP=4 | +| 24576 | vanilla | Env | SP=2 | Env + SP=2 | Env + Offload | +| 28672 | Env | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | +| 32768 | SP=2 | SP=2 | SP=2 | SP=4 | Env + Offload + SP=2 | +| 36864 | SP=2 | SP=2 | SP=2 | SP=4 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | Env + SP=2 | SP=4 | Env + Offload + SP=2 | +| 51200 | Env + SP=2 | Env + SP=2 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | +| 61440 | SP=4 | SP=4 | SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 71680 | SP=4 | SP=4 | SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 81920 | SP=4 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 92160 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 102400 | Env + SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | - | +| 112640 | Env + SP=4 | Env + Offload + SP=4 | - | - | - | +| 122880 | Env + Offload + SP=4 | - | - | - | - | +| 133120 | - | - | - | - | - | + +
+ +--- + +### 6 GPUs + +
✅ Good support for small-to-medium models (≤4B), but still limited for 14B models with ultra-long contexts + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:-------------|:-------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 8192 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 12288 | vanilla | vanilla | vanilla | vanilla | SP=2 | +| 16384 | vanilla | vanilla | vanilla | Env | SP=2 | +| 20480 | vanilla | vanilla | vanilla | SP=2 | Env + SP=2 | +| 24576 | vanilla | Env | Env | SP=2 | Env + Offload | +| 28672 | Env | Env | SP=2 | SP=2 | Env + Offload + SP=2 | +| 32768 | SP=2 | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | +| 36864 | SP=2 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 51200 | Env + SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | - | +| 61440 | Env + SP=2 | - | - | - | - | +| 71680 | - | - | - | - | - | + +
+ +--- + +## H20 96GB GPU Configuration Recommendations + +The H20 has larger VRAM (96GB) but lower compute performance compared to the A100. + +### 1 GPU + +
Single GPU supports 4B models up to ~32K context length + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:--------------|:--------------|:--------------|:--------------|:--------------| +| 4096 | vanilla | vanilla | vanilla | Env + Offload | Env + Offload | +| 8192 | vanilla | vanilla | vanilla | Env + Offload | Env + Offload | +| 12288 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 16384 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 20480 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 24576 | vanilla | Env | Env + Offload | Env + Offload | Env + Offload | +| 28672 | vanilla | Env + Offload | Env + Offload | Env + Offload | Env + Offload | +| 32768 | Env | Env + Offload | Env + Offload | - | - | +| 36864 | Env + Offload | Env + Offload | - | - | - | +| 40960 | - | - | - | - | - | + +
+ +--- + +### 2 GPUs + +
Supports 14B models up to 50K context length + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:-------------|:-------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | Env + Offload | +| 8192 | vanilla | vanilla | vanilla | vanilla | Env + Offload | +| 12288 | vanilla | vanilla | vanilla | SP=2 | Env + Offload | +| 16384 | vanilla | vanilla | vanilla | SP=2 | Env + Offload | +| 20480 | vanilla | vanilla | Env | Env + Offload | Env + Offload | +| 24576 | vanilla | vanilla | SP=2 | Env + Offload | Env + Offload | +| 28672 | vanilla | Env | SP=2 | Env + Offload | Env + Offload | +| 32768 | Env | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 36864 | Env | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 51200 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 61440 | Env + SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | - | +| 71680 | Env + SP=2 | - | - | - | - | +| 81920 | - | - | - | - | - | + +
+ +--- + +### 4 GPUs + +
✅ Supports training 14B models up to 100K context length + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:-------------|:---------------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 8192 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 12288 | vanilla | vanilla | vanilla | vanilla | SP=2 | +| 16384 | vanilla | vanilla | vanilla | vanilla | SP=2 | +| 20480 | vanilla | vanilla | vanilla | Env | Env + SP=2 | +| 24576 | vanilla | vanilla | vanilla | SP=2 | SP=4 | +| 28672 | vanilla | vanilla | Env | SP=2 | SP=4 | +| 32768 | Env | Env | SP=2 | Env + SP=2 | SP=4 | +| 36864 | Env | SP=2 | SP=2 | Env + SP=2 | Env + SP=4 | +| 40960 | SP=2 | SP=2 | SP=2 | SP=4 | Env + Offload + SP=2 | +| 51200 | SP=2 | SP=2 | Env + SP=2 | SP=4 | Env + Offload + SP=2 | +| 61440 | SP=2 | Env + SP=2 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | +| 71680 | Env + SP=2 | SP=4 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | +| 81920 | SP=4 | SP=4 | SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 92160 | SP=4 | SP=4 | SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 102400 | SP=4 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 112640 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 122880 | Env + SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | - | +| 133120 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | - | - | +| 143360 | Env + SP=4 | - | - | - | - | +| 153600 | - | - | - | - | - | + +
+ +--- + +### 6 GPUs + +
Good support for small-to-medium models (≤4B), but still limited for 14B models with ultra-long contexts + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:-------------|:---------------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 8192 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 12288 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 16384 | vanilla | vanilla | vanilla | vanilla | Env | +| 20480 | vanilla | vanilla | vanilla | vanilla | SP=2 | +| 24576 | vanilla | vanilla | vanilla | SP=2 | SP=2 | +| 28672 | vanilla | vanilla | Env | SP=2 | Env + SP=2 | +| 32768 | Env | Env | SP=2 | SP=2 | Env + SP=2 | +| 36864 | Env | SP=2 | SP=2 | SP=2 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | +| 51200 | SP=2 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 61440 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | - | +| 71680 | Env + SP=2 | Env + Offload + SP=2 | - | - | - | +| 81920 | - | - | - | - | - | + +
+ +--- + +## ✅ Best Practices + +1. **Start with the simplest configuration**: Try `vanilla` first, and incrementally enable advanced features only when encountering OOM errors. +2. **Always use YaRN for long contexts**: For contexts exceeding 40,960 tokens, configure `rope_scaling` and set `factor` appropriately. +3. **OOM troubleshooting sequence**: + - Step 1: Set `export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` + - Step 2: Increase **Sequence Parallelism (SP)** + - Step 3: Enable **FSDP v2 + CPU Offload** +4. **Choosing SP parallelism degree**: Prefer values that are **common divisors of both GPU count and attention head count** (e.g., 2, 4). +5. **Prefer multi-GPU over single-GPU**: Even when VRAM appears sufficient, multi-GPU setups improve training efficiency and stability through parallelization. diff --git a/docs/sphinx_doc/source_zh/index.rst b/docs/sphinx_doc/source_zh/index.rst index 3e4fbc276f..378f2dcc91 100644 --- a/docs/sphinx_doc/source_zh/index.rst +++ b/docs/sphinx_doc/source_zh/index.rst @@ -22,6 +22,7 @@ tutorial/develop_operator.md tutorial/develop_selector.md tutorial/trinity_configs.md + tutorial/trinity_gpu_configs.md tutorial/synchronizer.md .. toctree:: diff --git a/docs/sphinx_doc/source_zh/main.md b/docs/sphinx_doc/source_zh/main.md index 99d89a1ad2..0e69be26e1 100644 --- a/docs/sphinx_doc/source_zh/main.md +++ b/docs/sphinx_doc/source_zh/main.md @@ -30,7 +30,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: | *多轮智能体强化学习* | + [拼接多轮任务](/tutorial/example_multi_turn.md)
+ [通用多轮任务](/tutorial/example_step_wise.md)
+ [调用智能体框架中的 ReAct 工作流](/tutorial/example_react.md)
+ [例子:训练一个网络搜索智能体](https://github.com/modelscope/Trinity-RFT/tree/main/examples/agentscope_websearch) | | *全生命周期的数据流水线* | + [Rollout 任务混合与选取](/tutorial/develop_selector.md)
+ [在线任务选择](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) (📝 [论文](https://arxiv.org/pdf/2510.26374))
+ [研究项目:learn-to-ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask) (📝 [论文](https://arxiv.org/pdf/2510.25441))
+ [经验回放机制](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](/tutorial/example_data_functionalities.md) | | *强化学习算法开发* | + [使用 Trinity-RFT 进行 RL 算法开发](/tutorial/example_mix_algo.md) (📝 [论文](https://arxiv.org/pdf/2508.11408))
+ [研究项目: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) (📝 [论文](https://arxiv.org/abs/2509.24203))
+ 不可验证的领域: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [可训练 RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward) | -| *深入认识 Trinity-RFT* | + [完整配置指南](/tutorial/trinity_configs.md)
+ [用于快速验证和实验的 Benchmark 工具](https://github.com/modelscope/Trinity-RFT/tree/main/benchmark/README.md)
+ [理解 explorer-trainer 同步逻辑](/tutorial/synchronizer.md) | +| *深入认识 Trinity-RFT* | + [完整配置指南](/tutorial/trinity_configs.md)
+ [用于快速验证和实验的 Benchmark 工具](https://github.com/modelscope/Trinity-RFT/tree/main/benchmark/README.md)
+ [GPU 资源与训练配置对应指南](/tutorial/trinity_gpu_configs.md)<br>
+ [理解 explorer-trainer 同步逻辑](/tutorial/synchronizer.md) | diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md index 872e3819c4..1f92f05d4c 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md @@ -1,4 +1,4 @@ -# 🧪 实验性功能:任务选择与调度系统 +# 🧪 实验性功能:任务选择器 ```{note} 该模块目前处于 **实验阶段**,接口可能在后续版本中发生变化。 diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_gpu_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_gpu_configs.md new file mode 100644 index 0000000000..97616d1795 --- /dev/null +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_gpu_configs.md @@ -0,0 +1,280 @@ +# GPU 资源相关配置指南 + +本文档为在 **NVIDIA A100 80GB** 和 **H20 96GB** 显卡上训练 Qwen3 系列模型提供推荐的训练配置建议。 +根据模型大小(0.6B ~ 14B)与上下文长度(`model.max_model_len`),我们给出了Trainer模块在不同 GPU 数量下的可行方案。 + +> ⚠️ **注意** +> 由于在Trinity内,采样与训练是分离的。以下关于GPU数量的描述指的是`Trainer`部分可使用的数量,而非Trinity总共使用的GPU数量。 + +> 💡 **术语说明** +> +> - **vanilla**:无需特殊配置,使用默认设置即可。 +> - **Env**:需在启动训练前(启动ray之前)设置环境变量: +> ```bash +> export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +> ``` +> - **Offload**:需启用 **FSDP v2 + CPU Offload** 技术以节省显存。 +> - **SP=N**:表示使用 **Sequence Parallelism(序列并行)**,并行度为 N(通常 N ≤ GPU 数量)。 +> - **组合项(如 `Env + SP=2`)**:需同时满足所有列出的条件。 +> - **“-”**:当前硬件与配置组合下,无法支持该模型在此序列长度下进行训练。 + +--- + +## 关于长上下文支持 + +Qwen3 系列模型原生支持的最大上下文长度为 **40,960 tokens**。 +对于超过此长度的训练(如 51,200、81,920 等),我们通过 **YaRN RoPE 扩展** 实现。相关配置如下: + +```yaml +model: + model_path: ${oc.env:MODEL_PATH,Qwen/Qwen3-0.6B} + max_prompt_tokens: 2048 + max_model_len: ${oc.env:MAX_MODEL_LEN,4096} + rope_scaling: + rope_type: yarn + factor: ${oc.decode:${oc.env:FACTOR}} # 推荐值 = MAX_MODEL_LEN / 40960 + original_max_position_embeddings: 40960 +``` + +> ✅ 使用 YaRN 时,请确保 `factor` 设置合理,避免数值不稳定。 + +--- + +## 💡 显存使用与 `max_token_len_per_gpu` 的关系 + +Trinity Trainer 默认启用了动态批大小(`trainer.use_dynamic_bsz=True`),在固定模型的情况下,实际显存消耗主要由以下两个参数决定: + +- `trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu` +- `trainer.trainer_config.actor_rollout_ref.ref.log_prob_max_token_len_per_gpu` + +如果未手动设置,Trinity会自动用该默认值: +```python +trainer.max_token_len_per_gpu = ceil(2 * model.max_model_len / trainer.ulysses_sequence_parallel_size) +``` + +📌 **这意味着**: +- 上下文越长,每张 GPU 要处理的 token 越多,显存压力越大。 +- 如果想支持**更长上下文**,可以手动设置上述参数(但可能影响训练效率)。 + +> 本指南中的所有实验结果都是基于上述默认设置得出的。如需极限优化,请根据实际情况调整这些参数。 + +--- + +## A100 80GB 显卡配置建议 + +> ⚠️ **单卡限制**:在 1 张 A100 上训练 ≥4B 模型或 >20K 上下文时,显存压力极大,**强烈建议使用多卡方案**。 + +### 1 张 GPU + +
点击查看详细配置 + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:--------------|:--------------|:--------------|:--------------|:--------------| +| 4096 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 8192 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 12288 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 16384 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 20480 | vanilla | Env + Offload | Env + Offload | Env + Offload | Env + Offload | +| 24576 | Env | Env + Offload | Env + Offload | Env + Offload | Env + Offload | +| 28672 | Env + Offload | Env + Offload | Env + Offload | - | - | +| 32768 | - | - | - | - | - | + +
+ +--- + +### 2 张 GPU + +
✅ 推荐:2 卡显著提升 4B~14B 模型的长上下文训练能力,上下文较长时建议启用 SP=2 + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:---------------------|:---------------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | Env | Env + Offload | +| 8192 | vanilla | vanilla | vanilla | Env + Offload | Env + Offload | +| 12288 | vanilla | vanilla | vanilla | Env + Offload | Env + Offload | +| 16384 | vanilla | vanilla | Env | Env + Offload | Env + Offload | +| 20480 | vanilla | vanilla | SP=2 | Env + Offload | Env + Offload | +| 24576 | vanilla | Env | SP=2 | Env + Offload | Env + Offload | +| 28672 | Env | SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 32768 | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 36864 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 51200 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | - | +| 61440 | Env + Offload + SP=2 | - | - | - | - | +| 71680 | - | - | - | - | - | + +
+ +--- + +### 4 张 GPU + +
✅ 推荐:训练 8B/14B 模型 + 超长上下文(>60K)的理想配置 + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:---------------------|:---------------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | Env | +| 8192 | vanilla | vanilla | vanilla | vanilla | Env + SP=2 | +| 12288 | vanilla | vanilla | vanilla | Env | Env + SP=4 | +| 16384 | vanilla | vanilla | vanilla | SP=2 | Env + SP=4 | +| 20480 | vanilla | vanilla | vanilla | SP=2 | Env + SP=4 | +| 24576 | vanilla | Env | SP=2 | Env + SP=2 | Env + Offload | +| 28672 | Env | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | +| 32768 | SP=2 | SP=2 | SP=2 | SP=4 | Env + Offload + SP=2 | +| 36864 | SP=2 | SP=2 | SP=2 | SP=4 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | Env + SP=2 | SP=4 | Env + Offload + SP=2 | +| 51200 | Env + SP=2 | Env + SP=2 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | +| 61440 | SP=4 | SP=4 | SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 71680 | SP=4 | SP=4 | SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 81920 | SP=4 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 92160 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 102400 | Env + SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | - | +| 112640 | Env + SP=4 | Env + Offload + SP=4 | - | - | - | +| 122880 | Env + Offload + SP=4 | - | - | - | - | +| 133120 | - | - | - | - | - | + +
+ +--- + +### 6 张 GPU + +
✅ 对中小模型(≤4B)支持较好,但对 14B 模型在超长上下文下仍存在限制 + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:-------------|:-------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 8192 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 12288 | vanilla | vanilla | vanilla | vanilla | SP=2 | +| 16384 | vanilla | vanilla | vanilla | Env | SP=2 | +| 20480 | vanilla | vanilla | vanilla | SP=2 | Env + SP=2 | +| 24576 | vanilla | Env | Env | SP=2 | Env + Offload | +| 28672 | Env | Env | SP=2 | SP=2 | Env + Offload + SP=2 | +| 32768 | SP=2 | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | +| 36864 | SP=2 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 51200 | Env + SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | - | +| 61440 | Env + SP=2 | - | - | - | - | +| 71680 | - | - | - | - | - | + +
+ +--- + +## H20 96GB 显卡配置建议 + +H20 显存更大(96GB),但计算能力弱于 A100。 + +### 1 张 GPU + +
单卡可支持 4B 模型至 ~32K 上下文 + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:--------------|:--------------|:--------------|:--------------|:--------------| +| 4096 | vanilla | vanilla | vanilla | Env + Offload | Env + Offload | +| 8192 | vanilla | vanilla | vanilla | Env + Offload | Env + Offload | +| 12288 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 16384 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 20480 | vanilla | vanilla | Env + Offload | Env + Offload | Env + Offload | +| 24576 | vanilla | Env | Env + Offload | Env + Offload | Env + Offload | +| 28672 | vanilla | Env + Offload | Env + Offload | Env + Offload | Env + Offload | +| 32768 | Env | Env + Offload | Env + Offload | - | - | +| 36864 | Env + Offload | Env + Offload | - | - | - | +| 40960 | - | - | - | - | - | + +
+ +--- + +### 2 张 GPU + +
支持 14B 模型至 50K 上下文 + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:-------------|:-------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | Env + Offload | +| 8192 | vanilla | vanilla | vanilla | vanilla | Env + Offload | +| 12288 | vanilla | vanilla | vanilla | SP=2 | Env + Offload | +| 16384 | vanilla | vanilla | vanilla | SP=2 | Env + Offload | +| 20480 | vanilla | vanilla | Env | Env + Offload | Env + Offload | +| 24576 | vanilla | vanilla | SP=2 | Env + Offload | Env + Offload | +| 28672 | vanilla | Env | SP=2 | Env + Offload | Env + Offload | +| 32768 | Env | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 36864 | Env | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 51200 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 61440 | Env + SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | - | +| 71680 | Env + SP=2 | - | - | - | - | +| 81920 | - | - | - | - | - | + +
+ +--- + +### 4 张 GPU + +
✅ 可支持 14B 模型训练至 100K 上下文 + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:-------------|:---------------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 8192 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 12288 | vanilla | vanilla | vanilla | vanilla | SP=2 | +| 16384 | vanilla | vanilla | vanilla | vanilla | SP=2 | +| 20480 | vanilla | vanilla | vanilla | Env | Env + SP=2 | +| 24576 | vanilla | vanilla | vanilla | SP=2 | SP=4 | +| 28672 | vanilla | vanilla | Env | SP=2 | SP=4 | +| 32768 | Env | Env | SP=2 | Env + SP=2 | SP=4 | +| 36864 | Env | SP=2 | SP=2 | Env + SP=2 | Env + SP=4 | +| 40960 | SP=2 | SP=2 | SP=2 | SP=4 | Env + Offload + SP=2 | +| 51200 | SP=2 | SP=2 | Env + SP=2 | SP=4 | Env + Offload + SP=2 | +| 61440 | SP=2 | Env + SP=2 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | +| 71680 | Env + SP=2 | SP=4 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | +| 81920 | SP=4 | SP=4 | SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 92160 | SP=4 | SP=4 | SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 102400 | SP=4 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 112640 | SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | +| 122880 | Env + SP=4 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | - | +| 133120 | Env + SP=4 | Env + Offload + SP=4 | Env + Offload + SP=4 | - | - | +| 143360 | Env + SP=4 | - | - | - | - | +| 153600 | - | - | - | - | - | + +
+ +--- + +### 6 张 GPU + +
对中小模型(≤4B)支持较好,但对 14B 模型在超长上下文下仍存在限制 + +| `max_model_len` | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B | +|------------------:|:-------------|:---------------------|:---------------------|:---------------------|:---------------------| +| 4096 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 8192 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 12288 | vanilla | vanilla | vanilla | vanilla | vanilla | +| 16384 | vanilla | vanilla | vanilla | vanilla | Env | +| 20480 | vanilla | vanilla | vanilla | vanilla | SP=2 | +| 24576 | vanilla | vanilla | vanilla | SP=2 | SP=2 | +| 28672 | vanilla | vanilla | Env | SP=2 | Env + SP=2 | +| 32768 | Env | Env | SP=2 | SP=2 | Env + SP=2 | +| 36864 | Env | SP=2 | SP=2 | SP=2 | Env + Offload + SP=2 | +| 40960 | SP=2 | SP=2 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | +| 51200 | SP=2 | SP=2 | SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | +| 61440 | SP=2 | Env + SP=2 | Env + Offload + SP=2 | Env + Offload + SP=2 | - | +| 71680 | Env + SP=2 | Env + Offload + SP=2 | - | - | - | +| 81920 | - | - | - | - | - | + +
---

## ✅ 最佳实践建议

1. **从最简配置开始**:优先尝试 `vanilla`,仅在遇到 OOM 时逐步启用高级功能。
2. **长上下文必用 YaRN**:超过 40,960 tokens 时,务必配置 `rope_scaling` 并合理设置 `factor`。
3. **OOM 处理顺序**:
   - 第一步:设置 `export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`
   - 第二步:增加 **Sequence Parallelism(SP)**
   - 第三步:启用 **FSDP v2 + CPU Offload**
4. **SP 并行度选择**:建议设为 **GPU 数量与注意力头数的公因数**(如 2、4)。
5. **多卡优于单卡**:即使显存足够,多卡也能通过并行提升训练效率与稳定性。

diff --git a/scripts/context_length_test/README.md b/scripts/context_length_test/README.md
new file mode 100644
index 0000000000..10c7a535e2
--- /dev/null
+++ b/scripts/context_length_test/README.md
@@ -0,0 +1,241 @@
# Automated Context Length Testing for Large Language Models

This script automates the process of determining the **maximum context length** a large language model (LLM) can handle under various distributed training configurations, including different GPU counts and sequence parallelism settings. It iteratively increases the context length during training until an **Out-of-Memory (OOM)** error occurs, logging results along the way and supporting advanced features such as RoPE scaling, FSDP strategies, and offloading.

---

## 🧰 Requirements

Ensure Trinity-RFT is properly installed ([Installation Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)). No additional dependencies are required.

---

## 🛠️ Configuration Files

The script relies on two external files:

1. **`context_length.yaml`**
   - Located in the same directory as this script.
   - Defines the base training configuration used by `trinity`.

2. **`workflow/` plugin directory**
   - Contains the `CustomWorkflow` expected by `trinity`, which provides a synthetic training data generator.

Ensure both exist at runtime. You can modify these files to customize the training process.

---

## 🚀 Usage

### Run the Script

```bash
python search_context_length_capacity.py \
    --model_path /path/to/your/model \
    --start_length 4096 \
    --log_dir ./logs \
    --test_gpu_num 1 2 4 \
    --test_sp_num 1 2 \
    --trainer_strategy fsdp \
    --save_hf_checkpoint last \
    --timeout 2400
```

### Required Arguments

| Argument | Description |
|--------|-----------|
| `--model_path` | Path to the pretrained Hugging Face model directory. |

### Optional Arguments

| Argument | Default | Description |
|--------|--------|-----------|
| `--start_length` | `4096` | Initial context length to begin testing. |
| `--log_dir` | `./logs` | Directory to save logs and results. |
| `--checkpoint_path` | `os.environ.get("TRINITY_CHECKPOINT_ROOT_DIR", "./checkpoints/length-test")` | Checkpoint path for testing. Note that this directory is deleted during the test; specify a path that is not used by other processes. |
| `--test_gpu_num` | `1 2 4 6` | List of GPU counts to test scalability. |
| `--test_sp_num` | `1` | Sequence parallel group sizes to evaluate. Each value must divide the GPU count and the number of attention heads. |
| `--save_hf_checkpoint` | `last` | When to save HF-format checkpoints (`always`, `never`, `last`). |
| `--entropy_saving` | `False` | Enable memory-saving entropy computation (if supported). |
| `--offload` | `False` | Offload parameters to CPU to reduce GPU memory usage. |
| `--trainer_strategy` | `fsdp` | Distributed training strategy (`fsdp` or `fsdp2`). |
| `--timeout` | `2400` (40 min) | Maximum time per job before forced termination. |
| `--dlc` | `False` | Set this flag when running in Aliyun PAI DLC. |
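For reference, a typical end-to-end invocation looks like the sketch below (model path and flag values are illustrative). The allocator variable is exported before Ray starts, so that `Env`-dependent configurations behave as described in the Notes & Best Practices section below:

```bash
# Allocator setting must precede Ray startup
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
ray start --head

python search_context_length_capacity.py \
    --model_path Qwen/Qwen3-0.6B \
    --test_gpu_num 2 4 \
    --test_sp_num 1 2 \
    --offload \
    --trainer_strategy fsdp2
```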
---

## 📂 Output Structure

Logs are saved in a structured hierarchy under `--log_dir`:

```
logs/
└── <model_name>/
    └── gpu-<trainer_gpu_num>/
        └── sp-<sp_num>/
            └── model_len-<length>.log
```

Each log file corresponds to a specific `(GPU count, SP size, context length)` combination.

Final results are printed to stdout:
```
model_name = Qwen3-0.6B, trainer_gpu_num = 4, sp_num = 2, max_model_len = 40960
```

---

## ⚠️ Notes & Best Practices

- **Model Compatibility**: Ensure the model supports dynamic context extension (e.g., via RoPE scaling).
- **SP Validity**: Only valid SP values (divisors of both the GPU count and the attention head count) are tested.
- **Checkpoint Root**: Controlled by the `TRINITY_CHECKPOINT_ROOT_DIR` env var (default: `./checkpoints/length-test`). Cleared before each trial.
- **Early Termination**: If a run fails due to OOM, the search stops and returns the last successful length.
- **Larger Steps Beyond the Native Limit**: The base step size is 4096 tokens; once the context exceeds `max_position_embeddings`, the step grows to a quarter of the native limit (e.g., 40960 // 4 = 10240 for Qwen3).

---

## 🧪 Example: Test Qwen3-0.6B Context Length

```bash
python search_context_length_capacity.py \
    --model_path Qwen/Qwen3-0.6B \
    --test_gpu_num 1 2 4 6 \
    --test_sp_num 1 2 4 \
    --start_length 8192 \
    --log_dir ./results/qwen3-length-scan \
    --trainer_strategy fsdp2 \
    --timeout 3600
```

This command tests the maximum context length of the Qwen3-0.6B model with 1, 2, 4, and 6 trainer GPUs and SP sizes 1, 2, and 4, using the FSDP2 strategy, and saves logs to `./results/qwen3-length-scan`.

---

## 📚 Test Results

Below are empirical results from running this script on various Qwen3 models across different hardware and optimization configurations. These benchmarks help guide configuration choices for maximizing context length within memory constraints.

### Legend
- `*` indicates RoPE scaling (YaRN) was applied; the context length exceeds the model's native `max_position_embeddings`.
- `-` indicates OOM occurred even at a context length of 4096.
- All tests use `start_length=4096` and increase dynamically.

### A100 80GB

#### Vanilla Settings (Baseline)

| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
| ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
| 1 | 1 | 20480 | 16384 | - | - | - |
| 2 | 1 | 24576 | 20480 | 12288 | - | - |
| 2 | 2 | 40960 | 40960 | 24576 | - | - |
| 4 | 1 | 24576 | 20480 | 20480 | 8192 | - |
| 4 | 2 | 40960 | 40960 | 36864 | 20480 | - |
| 4 | 4 | 92160* | 81920* | 71680* | 40960 | - |
| 6 | 1 | 24576 | 20480 | 20480 | 12288 | 8192 |
| 6 | 2 | 40960 | 40960 | 40960 | 28672 | 16384 |


#### Enable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`

> ⚠️ Must be set **before** launching any processes (including Ray clusters).
| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
| ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
| 1 | 1 | 24576 | 16384 | - | - | - |
| 2 | 1 | 28672 | 24576 | 16384 | 4096 | - |
| 2 | 2 | 51200* | 40960 | 32768 | - | - |
| 4 | 1 | 28672 | 24576 | 20480 | 12288 | 4096 |
| 4 | 2 | 51200* | 51200* | 40960 | 28672 | 8192 |
| 4 | 4 | 112640* | 102400* | 81920* | 51200* | 20480 |
| 6 | 1 | 28672 | 28672 | 24576 | 16384 | 8192 |
| 6 | 2 | 61440* | 51200* | 40960 | 32768 | 20480 |


#### Enable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, FSDP2 Offload and `save_hf_checkpoint=never`

> Uses: `--offload --trainer_strategy fsdp2 --save_hf_checkpoint never`


| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
| ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
| 1 | 1 | 28672 | 28672 | 28672 | 24576 | 24576 |
| 2 | 1 | 28672 | 28672 | 28672 | 24576 | 24576 |
| 2 | 2 | 61440* | 51200* | 51200* | 51200* | 40960 |
| 4 | 1 | 28672 | 28672 | 28672 | 24576 | 24576 |
| 4 | 2 | 61440* | 51200* | 51200* | 51200* | 40960 |
| 4 | 4 | 122880* | 112640* | 102400* | 102400* | 92160* |
| 6 | 1 | 28672 | 28672 | 28672 | 24576 | 24576 |
| 6 | 2 | 61440* | 51200* | 51200* | 51200* | 40960 |


### H20 96GB (Higher VRAM, Lower Compute)


#### Vanilla Settings


| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
| ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
| 1 | 1 | 28672 | 20480 | 8192 | - | - |
| 2 | 1 | 28672 | 24576 | 16384 | 8192 | - |
| 2 | 2 | 51200* | 51200* | 36864 | 16384 | - |
| 4 | 1 | 28672 | 28672 | 24576 | 16384 | 8192 |
| 4 | 2 | 61440* | 51200* | 40960 | 28672 | 16384 |
| 4 | 4 | 112640* | 102400* | 92160* | 51200* | 32768 |
| 6 | 1 | 28672 | 28672 | 24576 | 20480 | 12288 |
| 6 | 2 | 61440* | 51200* | 51200* | 36864 | 24576 |


#### Enable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`


| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
| ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
| 1 | 1 | 32768 | 24576 | 8192 | - | - |
| 2 | 1 | 36864 | 28672 | 20480 | 8192 | - |
| 2 | 2 | 71680* | 61440* | 40960 | 16384 | - |
| 4 | 1 | 36864 | 32768 | 28672 | 20480 | 8192 |
| 4 | 2 | 71680* | 61440* | 51200* | 36864 | 20480 |
| 4 | 4 | 143360* | 122880* | 102400* | 71680* | 36864 |
| 6 | 1 | 36864 | 32768 | 28672 | 20480 | 16384 |
| 6 | 2 | 71680* | 61440* | 51200* | 40960 | 32768 |


#### Enable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` and FSDP2 Offload

> Uses: `--offload --trainer_strategy fsdp2`

| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
| ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
| 1 | 1 | 36864 | 36864 | 32768 | 28672 | 28672 |
| 2 | 1 | 36864 | 36864 | 32768 | 28672 | 28672 |
| 2 | 2 | 71680* | 61440* | 61440* | 61440 | 51200* |
| 4 | 1 | 36864 | | 32768 | 28672 | 28672 |
| 4 | 2 | 71680* | 71680* | 61440* | 61440* | |
| 4 | 4 | 143360* | 133120* | 133120* | 122880* | 112640* |
| 6 | 1 | 36864 | | 32768 | 28672 | 28672 |
| 6 | 2 | 71680* | 71680* | 61440* | 61440* | 51200* |


#### Enable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, FSDP2 Offload and `save_hf_checkpoint=never`

> Uses: `--offload --trainer_strategy fsdp2 --save_hf_checkpoint never`


| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
---- | -- | ---------- | ---------- | -------- | -------- | --------- | +| 1 | 1 | 36864 | 36864 | 32768 | 28672 | 28672 | +| 2 | 1 | 36864 | 36864 | 32768 | 28672 | 28672 | +| 2 | 2 | 71680* | 61440* | 61440* | 61440* | | +| 4 | 1 | 36864 | | 32768 | 28672 | 28672 | +| 4 | 2 | 71680* | 71680* | 61440* | 61440* | 51200* | +| 4 | 4 | 143360* | 133120* | 133120* | 122880* | 112640* | +| 6 | 1 | 36864 | | 32768 | 28672 | 28672 | +| 6 | 2 | 71680* | 71680* | 61440* | 61440* | 51200* | diff --git a/scripts/context_length_test/context_length.yaml b/scripts/context_length_test/context_length.yaml new file mode 100644 index 0000000000..e7133de8dd --- /dev/null +++ b/scripts/context_length_test/context_length.yaml @@ -0,0 +1,112 @@ +mode: both +project: Trinity-RFT-context-length-exp +group: length-test +name: length-test +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints/length-test} +continue_from_checkpoint: false +algorithm: + algorithm_type: grpo + repeat_times: ${oc.env:REPEAT_TIMES,8} + advantage_fn: grpo + sample_strategy: default + policy_loss_fn: ppo + kl_penalty_fn: none + kl_loss_fn: k2 + entropy_loss_fn: default + optimizer: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant +data_processor: {} +model: + model_path: ${oc.env:MODEL_PATH,Qwen/Qwen3-0.6B} + max_prompt_tokens: ${oc.env:PROMPT_LEN,2048} + max_model_len: ${oc.env:MAX_MODEL_LEN,4096} + rope_scaling: ${oc.decode:${oc.env:ROPE_SCALING,null}} +cluster: + node_num: 1 + gpu_per_node: ${oc.env:GPU_NUM,8} +buffer: + batch_size: 1 + total_steps: 2 + explorer_input: + taskset: + name: taskset + storage_type: file + path: openai/gsm8k + split: train + subset_name: main + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + workflow_args: + prompt_len: ${model.max_prompt_tokens} + max_model_len: ${model.max_model_len} + eval_tasksets: [] + default_workflow_type: synthetic_exp_workflow + default_reward_fn_type: math_reward + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + replay_buffer: + enable: false + priority_fn: linear_decay + reuse_cooldown_time: null + priority_fn_args: + decay: 2.0 +explorer: + runner_per_model: 8 + rollout_model: + engine_num: ${oc.env:ENGINE_NUM,1} + tensor_parallel_size: 1 + enforce_eager: true + enable_prefix_caching: false + enable_chunked_prefill: false + gpu_memory_utilization: 0.9 + dtype: bfloat16 + seed: 42 + enable_thinking: false + enable_history: false + enable_openai_api: false + enable_auto_tool_choice: false + tool_call_parser: null + reasoning_parser: null + auxiliary_models: [] + eval_interval: 1000 +trainer: + trainer_type: verl + trainer_strategy: ${oc.env:TRAINER_STRATEGY,fsdp} + save_interval: 100 + enable_preview: true + grad_clip: 1.0 + ulysses_sequence_parallel_size: ${oc.env:SP_NUM,1} + save_hf_checkpoint: ${oc.env:SAVE_HF_CHECKPOINT,last} + trainer_config: + actor_rollout_ref: + actor: + entropy_from_logits_with_chunking: ${oc.env:ENTROPY_SAVING,false} + entropy_checkpointing: ${oc.env:ENTROPY_SAVING,false} + fsdp_config: + param_offload: ${oc.env:OFFLOAD,false} + optimizer_offload: ${oc.env:OFFLOAD,false} + offload_policy: ${oc.env:OFFLOAD,false} + ref: + entropy_from_logits_with_chunking: ${oc.env:ENTROPY_SAVING,false} + entropy_checkpointing: ${oc.env:ENTROPY_SAVING,false} + fsdp_config: + param_offload: ${oc.env:OFFLOAD,false} + optimizer_offload: ${oc.env:OFFLOAD,false} + offload_policy: ${oc.env:OFFLOAD,false} +monitor: + monitor_type: tensorboard 
+synchronizer: + sync_method: nccl + sync_style: fixed + sync_interval: 1 + sync_timeout: 1200 +log: + level: INFO diff --git a/scripts/context_length_test/search_context_length_capacity.py b/scripts/context_length_test/search_context_length_capacity.py new file mode 100644 index 0000000000..f1bc748158 --- /dev/null +++ b/scripts/context_length_test/search_context_length_capacity.py @@ -0,0 +1,435 @@ +"""Automated context length testing for large language models using distributed training configurations. + +This script runs scalability tests on a given model by launching training jobs with increasing +context lengths until OOM (Out-of-Memory) errors occur. It supports sequence parallelism and multiple +GPU configurations. +""" + +import argparse +import os +import shutil +import subprocess +import threading +from typing import List, Optional + +import transformers +import yaml + +from trinity.utils.dlc_utils import is_running, setup_ray_cluster, stop_ray_cluster + +# Default list of GPU counts to test +DEFAULT_GPU_NUMS: List[int] = [1, 2, 4, 6] +EXCEPTION_STRING = "Traceback (most recent call last)" +OOM_STRING = "torch.OutOfMemoryError: CUDA out of memory" +CUDA_ERROR_STRING = "RuntimeError: CUDA error:" + + +def monitor_output( + pipe, + exception_event: threading.Event, + oom_event: threading.Event, + log_file, +): + """Monitors the output stream from a subprocess and sets events if target strings are found. + + Reads lines from the provided pipe (e.g., stdout), writes them to the log file, and checks + whether the output contains the stop or OOM trigger strings. If found, it sets the corresponding + threading event to signal termination. + + Args: + pipe: Readable file-like object (e.g., subprocess.stdout). + exception_event: Threading event set when an exception is detected. + oom_event: Threading event set when 'torch.OutOfMemoryError: CUDA out of memory' is detected. + log_file: Open file handle where output is logged. + """ + try: + for line in iter(pipe.readline, ""): + if not line: + break + # Write to log and flush immediately + log_file.write(line) + log_file.flush() + + # Check for exception + if EXCEPTION_STRING in line: + exception_event.set() + + if exception_event.is_set(): + print(line, end="", flush=True) + + # Check for oom + if OOM_STRING in line or CUDA_ERROR_STRING in line: + exception_event.set() + oom_event.set() + break + except Exception as e: + print(f"Error in monitoring thread: {e}") + + +def run_command_with_monitor( + command: List[str], + envs: dict[str, str], + log_path: str, + checkpoint_path: str, + timeout: Optional[int] = None, + max_retry: int = 10, +) -> bool: + """Runs a shell command with real-time output monitoring and early termination support. + + Executes the specified command, merges stdout and stderr, logs output to a file, and monitors + for exception string. If the string appears or a timeout occurs, the process is terminated. + + Retries execution until no other exception event is raised (i.e., until success or OOM). + + Args: + command: Command to execute, as a list of strings. + envs: Environment variables to set for the command. + log_path: Path to the log file where output will be saved. + checkpoint_path: Path to the checkpoint directory. + timeout: Optional timeout in seconds before forcing termination. + max_retry: Maximum number of retries in case of OOM. + + Returns: + True if the command completed successfully without OOM error; False otherwise. 
+ """ + retry_flag = True + success_flag = False + envs["TRINITY_CHECKPOINT_ROOT_DIR"] = checkpoint_path + process_env = os.environ.copy() + process_env.update(envs) + + for _ in range(max_retry): + # Clean up checkpoint directory before each run + shutil.rmtree(checkpoint_path, ignore_errors=True) + + exception_event = threading.Event() + oom_event = threading.Event() + is_timeout = False + + with open(log_path, "w", encoding="utf-8") as log_file: + # Start subprocess with merged stdout/stderr + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + env=process_env, + ) + + # Start monitoring thread + monitor_thread = threading.Thread( + target=monitor_output, + args=( + process.stdout, + exception_event, + oom_event, + log_file, + ), + daemon=True, + ) + monitor_thread.start() + + try: + # Wait for monitor thread or timeout + if timeout: + monitor_thread.join(timeout) + if monitor_thread.is_alive(): + is_timeout = True + timeout *= 1.3 + else: + monitor_thread.join() + + # Handle process termination based on events + if exception_event.is_set() or is_timeout: + process.terminate() + try: + process.wait(timeout=2) + except subprocess.TimeoutExpired: + process.kill() + + if oom_event.is_set(): # CUDA OOM + retry_flag = False + elif is_timeout: + print("Timeout reached, retrying...") + else: + print("Exception detected, retrying...") + + success_flag = False + else: # no exception, runs successfully + retry_flag = False + success_flag = True + + # Ensure process has fully terminated + if process.poll() is None: + process.wait() + + except KeyboardInterrupt: + process.terminate() + process.wait() + + if not retry_flag: + break + + return success_flag + + +def find_max_model_len( + model_path: str, + model_config, + checkpoint_path: str, + trainer_gpu_num: int, + sp_num: int, + base_log_dir: str, + start_length: int = 4096, + save_hf_checkpoint: str = "last", + entropy_saving: bool = False, + offload: bool = False, + trainer_strategy: str = "fsdp", + timeout: int = 2400, +) -> int: + """Finds the maximum context length the model can handle under current hardware configuration. + + Iteratively increases the `MAX_MODEL_LEN` value and runs training jobs until an OOM error occurs. + Uses different YAML config files depending on whether the length exceeds the original max. + + Args: + model_path: Path to the pretrained model. + model_config: Loaded Hugging Face model configuration. + checkpoint_path: Path to the checkpoint directory. + trainer_gpu_num: Number of GPUs allocated. + sp_num: Number of sequence parallel groups. + base_log_dir: Base directory for saving logs. + start_length: Initial context length to test. + save_hf_checkpoint: Checkpoint saving strategy. + entropy_saving: Whether to enable entropy-saving options. + offload: Whether to offload parameters to CPU. + trainer_strategy: Trainer strategy. Only support "fsdp" and "fsdp2" for now. + timeout: Timeout in seconds for each training job. + + Returns: + Maximum supported context length before OOM; 0 if search failed. 
+ """ + checked_length = 0 + script_dir = os.path.dirname(os.path.abspath(__file__)) + yaml_file = os.path.join(script_dir, "context_length.yaml") + plugin_dir = os.path.join(script_dir, "workflow") + + length = start_length + origin_max_len = model_config.max_position_embeddings + small_step = 4096 + big_step = origin_max_len // 4 + model_name = os.path.basename(model_path) + + while True: + log_dir = os.path.join(base_log_dir, model_name, f"gpu-{trainer_gpu_num}", f"sp-{sp_num}") + os.makedirs(log_dir, exist_ok=True) + logfile = os.path.join(log_dir, f"model_len-{length}.log") + if trainer_gpu_num >= 8: + explorer_gpu_num = 8 + else: + explorer_gpu_num = 1 + total_gpu_num = trainer_gpu_num + explorer_gpu_num + + # Build command + cmd_env = { + "GPU_NUM": f"{total_gpu_num}", + "ENGINE_NUM": f"{explorer_gpu_num}", + "SP_NUM": f"{sp_num}", + "REPEAT_TIMES": f"{trainer_gpu_num // sp_num * 8}", + "MODEL_PATH": f"{model_path}", + "MAX_MODEL_LEN": f"{length}", + } + if length > origin_max_len: + rope_config = { + "rope_type": "yarn", + "factor": length / origin_max_len, + "original_max_position_embeddings": origin_max_len, + } + cmd_env["ROPE_SCALING"] = yaml.dump(rope_config, default_flow_style=True).strip() + if save_hf_checkpoint != "last": + cmd_env["SAVE_HF_CHECKPOINT"] = f"{save_hf_checkpoint}" + if entropy_saving: + cmd_env["ENTROPY_SAVING"] = "true" + if offload: + cmd_env["OFFLOAD"] = "true" + if trainer_strategy != "fsdp": + cmd_env["TRAINER_STRATEGY"] = f"{trainer_strategy}" + + cmd_base = [ + "trinity", + "run", + "--config", + yaml_file, + "--plugin-dir", + plugin_dir, + ] + + print(f"Running: {' '.join(f'{k}={v}' for k, v in cmd_env.items())} {' '.join(cmd_base)}") + + # Run with monitoring + success = run_command_with_monitor( + cmd_base, + cmd_env, + logfile, + checkpoint_path, + timeout=timeout, + ) + + if not success: + break + + checked_length = length + + # Increase step size after exceeding original limit + if length < origin_max_len: + length += small_step + else: + length += big_step + + if checked_length == 0: + print( + f"Search failed for model {model_name} with {trainer_gpu_num} GPUs. " + "Please check the log file for details." 
+        )
+
+    return checked_length
+
+
+def main(args):
+    """Main entry point: orchestrates multi-GPU, multi-SP context length testing."""
+    if args.dlc:
+        cluster_namespace = "search_context_length_capacity"
+        setup_ray_cluster(namespace=cluster_namespace)
+
+    if not is_running():
+        raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
+
+    os.makedirs(args.log_dir, exist_ok=True)
+
+    model_name = os.path.basename(args.model_path)
+    model_config = transformers.AutoConfig.from_pretrained(args.model_path)
+
+    # Per-SP-size starting context length, updated as searches complete
+    sp_num_to_start_length = {sp_num: args.start_length for sp_num in args.test_sp_num}
+
+    for trainer_gpu_num in args.test_gpu_num:
+        # Keep only SP sizes that divide both the trainer GPU count and the attention head count
+        sp_list = [
+            sp_num
+            for sp_num in args.test_sp_num
+            if (trainer_gpu_num % sp_num == 0 and model_config.num_attention_heads % sp_num == 0)
+        ]
+
+        last_length = 0
+        for sp_num in sp_list:
+            start_length = max(last_length, sp_num_to_start_length[sp_num])
+            max_length = find_max_model_len(
+                model_path=args.model_path,
+                model_config=model_config,
+                checkpoint_path=args.checkpoint_path,
+                trainer_gpu_num=trainer_gpu_num,
+                sp_num=sp_num,
+                base_log_dir=args.log_dir,
+                start_length=start_length,
+                save_hf_checkpoint=args.save_hf_checkpoint,
+                entropy_saving=args.entropy_saving,
+                offload=args.offload,
+                trainer_strategy=args.trainer_strategy,
+                timeout=args.timeout,
+            )
+            # Reuse the best length found so far as the starting point for the next SP size
+            last_length = max(max_length, args.start_length)
+            sp_num_to_start_length[sp_num] = last_length
+            print(
+                f"model_name = {model_name}, "
+                f"trainer_gpu_num = {trainer_gpu_num}, "
+                f"sp_num = {sp_num}, "
+                f"max_model_len = {max_length}"
+            )
+
+    if args.dlc:
+        stop_ray_cluster(namespace=cluster_namespace)
+
+
+if __name__ == "__main__":
+    default_log_dir = os.path.join(os.path.dirname(__file__), "logs")
+    parser = argparse.ArgumentParser(
+        description="Automated context length scalability testing for LLMs."
+    )
+    parser.add_argument(
+        "--start_length",
+        type=int,
+        default=4096,
+        help="Starting context length for testing.",
+    )
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        required=True,
+        help="Path to the pretrained model.",
+    )
+    parser.add_argument(
+        "--log_dir",
+        type=str,
+        default=default_log_dir,
+        help="Directory to store experiment logs.",
+    )
+    parser.add_argument(
+        "--checkpoint_path",
+        type=str,
+        default=os.environ.get("TRINITY_CHECKPOINT_ROOT_DIR", "./checkpoints/length-test"),
+        help="Checkpoint path for testing. "
" + "Note that this directory will be deleted during the test, " + "please specify a path that is not used by other processes.", + ) + parser.add_argument( + "--test_gpu_num", + type=int, + nargs="*", + default=DEFAULT_GPU_NUMS, + help="List of GPU counts to test.", + ) + parser.add_argument( + "--test_sp_num", + type=int, + nargs="*", + default=[1], + help="List of sequence parallel sizes to test.", + ) + parser.add_argument( + "--save_hf_checkpoint", + type=str, + choices=["always", "never", "last"], + default="last", + help="Whether to save HF checkpoint.", + ) + parser.add_argument( + "--entropy_saving", + action="store_true", + help="Whether to reduce entropy memory usage.", + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload model to CPU.", + ) + parser.add_argument( + "--trainer_strategy", + type=str, + choices=["fsdp", "fsdp2"], + default="fsdp", + help="Trainer strategy to use.", + ) + parser.add_argument( + "--timeout", + type=int, + default=2400, + help="Base timeout duration per experiment in seconds. " + "Each retry increases the timeout by 30% (multiplied by 1.3).", + ) + parser.add_argument( + "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC." + ) + + args = parser.parse_args() + main(args) diff --git a/scripts/context_length_test/workflow/synthetic_exp_workflow.py b/scripts/context_length_test/workflow/synthetic_exp_workflow.py new file mode 100644 index 0000000000..5147acec65 --- /dev/null +++ b/scripts/context_length_test/workflow/synthetic_exp_workflow.py @@ -0,0 +1,26 @@ +import torch + +from trinity.common.experience import Experience +from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task + + +@WORKFLOWS.register_module("synthetic_exp_workflow") +class SyntheticExpWorkflow(SimpleWorkflow): + def reset(self, task: Task): + self.workflow_args = task.workflow_args + self.task = task + self.max_model_len = self.workflow_args.get("max_model_len", 4096) + self.prompt_len = self.workflow_args.get("prompt_len", 2048) + self.response_len = self.max_model_len - self.prompt_len + self.dummy_token = self.workflow_args.get("dummy_token", 1024) + + def run(self): + return [ + Experience( + tokens=torch.full((self.max_model_len,), self.dummy_token, dtype=torch.int32), + logprobs=torch.ones((self.response_len,), dtype=torch.float32), + prompt_length=self.prompt_len, + reward=torch.tensor(0.0), + ) + for _ in range(self.repeat_times) + ]