-
Notifications
You must be signed in to change notification settings - Fork 50
Add FAQ in docs #109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add FAQ in docs #109
Changes from 1 commit
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
0b5df71
add faq
hiyuchang 0f7c58a
update faq
hiyuchang a7414d6
add faq to readme
hiyuchang e4f0e90
fix typo
hiyuchang 4568943
rm a verl param
hiyuchang 2dcd722
fix readme
hiyuchang 3ca7f98
fix comments
hiyuchang 4fc18fc
fix pre-commit issue
hiyuchang bab822b
fix comments
hiyuchang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| # FAQ | ||
|
|
||
| ## Part 1: Configurations | ||
| **Q:** Why do most examples have two configuration YAML files, e.g., `gsm8k.yaml` and `train_gsm8k.yaml` in the `examples/grpo_gsm8k` directory? | ||
|
|
||
| **A:** Trinity-RFT uses [veRL](https://github.com/volcengine/verl) as the training backend, and the auxiliary YAML file starting with `train_` is used for configuring veRL, referred to [veRL documentation](https://github.com/volcengine/verl/blob/v0.4.0/docs/examples/config.rst). | ||
| If you specify the path to `train_gsm8k.yaml` in `trainer.trainer_config_path`, Trinity-RFT will automatically pass the parameters to veRL. | ||
|
|
||
| We provide an alternative way to configure the veRL trainer. You may also directly specify the parameters in the `trainer.trainer_config` dictionary. This approach is mutually exclusive with using `trainer.trainer_config_path`. | ||
|
|
||
| Note that some parameters are not listed in the auxiliary configuration file (e.g., `train_gsm8k.yaml`), as they will be overridden by the parameters in the trinity configuration file (e.g., `gsm8k.yaml`). Please refer to `./trinity_configs.md` for more details. | ||
| Future versions will gradually reduce parameters in `trainer.trainer_config` and `trainer.trainer_config_path` until it's fully deprecated. | ||
|
|
||
| --- | ||
|
|
||
| **Q:** What's the relationship between `buffer.batch_size`, `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu` and other batch sizes? | ||
|
|
||
| **A:** The following parameters are closely related: | ||
|
|
||
| - `buffer.batch_size`: The number of tasks in a batch, effective for both the explorer and the trainer. | ||
| - `actor_rollout_ref.actor.ppo_mini_batch_size`: In the configuration, this value represents the number of tasks in a mini-batch, overridden by `buffer.batch_size`; but in the `update_policy` function, its value becomes the number of experiences in a mini-batch per GPU, i.e., `buffer.batch_size * algorithm.repeat_times / ngpus_trainer`. | ||
| - `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`: The number of experiences in a micro-batch per GPU. | ||
|
|
||
| A minimal example showing their usage is as follows: | ||
|
|
||
| ```python | ||
| def update_policy(batch): | ||
| dataloader = batch.split(ppo_mini_batch_size) | ||
| for _ in range(ppo_epochs): | ||
| for batch_idx, data in enumerate(dataloader): | ||
| # Split data | ||
| mini_batch = data | ||
| if actor_rollout_ref.actor.use_dynamic_bsz: | ||
| micro_batches, _ = rearrange_micro_batches( | ||
| batch=mini_batch, max_token_len=max_token_len | ||
| ) | ||
| else: | ||
| micro_batches = mini_batch.split(actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu) | ||
|
|
||
| # Computing gradient | ||
| for data in micro_batches: | ||
| entropy, log_prob = self._forward_micro_batch( | ||
| micro_batch=data, ... | ||
| ) | ||
| pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( | ||
| log_prob=log_prob, **data | ||
| ) | ||
| policy_loss = pg_loss + ... | ||
| loss = policy_loss / self.gradient_accumulation | ||
| loss.backward() | ||
|
|
||
| # Optimizer step | ||
| grad_norm = self._optimizer_step() | ||
| self.actor_optimizer.zero_grad() | ||
| ``` | ||
| Please refer to `trinity/trainer/verl/dp_actor.py` for detailed implementation. veRL also provides an explanation in [FAQ](https://verl.readthedocs.io/en/latest/faq/faq.html#what-is-the-meaning-of-train-batch-size-mini-batch-size-and-micro-batch-size). | ||
|
|
||
|
|
||
| ## Part 2: Common Errors | ||
|
|
||
| **Error:** | ||
| ```bash | ||
| File ".../flash_attn/flash_attn_interface.py", line 15, in ‹module> | ||
| import flash_attn_2_cuda as flash_attn_gpu | ||
| ImportError: ... | ||
| ``` | ||
|
|
||
| **A:** The `flash-attn` module is not properly installed. Try to fix it by running `MAX_JOBS=128 pip install flash-attn`. | ||
|
|
||
| --- | ||
|
|
||
| **Error:** | ||
| ```bash | ||
| UsageError: api_key not configured (no-tty). call wandb.login(key=[your_api_key]) ... | ||
| ``` | ||
|
|
||
| **A:** Try to log in to WandB before running the experiment. One way to do this is run the command `export WANDB_API_KEY=[your_api_key]`. | ||
|
|
||
| --- | ||
|
|
||
| **Error:** | ||
| ```bash | ||
| ValueError: Failed to look up actor with name 'explorer' ... | ||
| ``` | ||
|
|
||
| **A:** Try to restart Ray before running the experiment: | ||
hiyuchang marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| ```bash | ||
| ray stop | ||
| ray start --head | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| **Error:** Out-of-Memory (OOM) error | ||
|
|
||
| **A:** The following parameters may be helpful: | ||
|
|
||
| - For trainer, adjust `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu` when `actor_rollout_ref.actor.use_dynamic_bsz=false`; adjust `actor_rollout_ref.actor.ppo_max_token_len_per_gpu` and `actor_rollout_ref.actor.ulysses_sequence_parallel_size` when `actor_rollout_ref.actor.use_dynamic_bsz=true`. | ||
| - For exploere, adjust `explorer.rollout_model.tensor_parallel_size`, | ||
hiyuchang marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| ## Part 3: Debugging Methods [Coming Soon] | ||
| To see the full logs of all processes and save it to `debug.log`: | ||
| ```bash | ||
| export RAY_DEDUP_LOGS=0 | ||
| trinity run --config grpo_gsm8k/gsm8k.yaml 2>&1 | tee debug.log | ||
| ``` | ||
|
|
||
|
|
||
| ## Part 4: Other Questions | ||
| **Q:** What's the purpose of `buffer.trainer_input.experience_buffer.path`? | ||
|
|
||
| **A:** This path specifies the path to the SQLite database storaging the generated experiences. You may comment out this line if you don't want to use the SQLite database. | ||
|
|
||
| To see the experiences in the database, you can use the following Python script: | ||
|
|
||
| ```python | ||
| from sqlalchemy import create_engine | ||
| from sqlalchemy.exc import OperationalError | ||
| from sqlalchemy.orm import sessionmaker | ||
| from sqlalchemy.pool import NullPool | ||
| from trinity.common.schema import ExperienceModel | ||
|
|
||
| engine = create_engine(buffer.trainer_input.experience_buffer.path) | ||
| session = sessionmaker(bind=engine) | ||
| sess = session() | ||
|
|
||
| MAX_EXPERIENCES = 4 | ||
| experiences = ( | ||
| sess.query(ExperienceModel) | ||
| .with_for_update() | ||
| .limit(MAX_EXPERIENCES) | ||
| .all() | ||
| ) | ||
|
|
||
| exp_list = [] | ||
| for exp in experiences: | ||
| exp_list.append(ExperienceModel.to_experience(exp)) | ||
|
|
||
| # Print the experiences | ||
| for exp in exp_list: | ||
| print(f"{exp.prompt_text=}", f"{exp.response_text=}") | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| **Q:** How to load the checkpoints outside of the Trinity-RFT framework? | ||
|
|
||
| **A:** You need to specify `model.model_path` and `checkpoint_root_dir`. The following code snippet gives an example with transformers. | ||
|
|
||
| ```python | ||
| from transformers import AutoTokenizer, AutoModelForCausalLM | ||
| from trinity.common.models.utils import load_state_dict_from_verl_checkpoint | ||
|
|
||
| model = AutoModelForCausalLM.from_pretrained(model.model_path) | ||
| # Assume we need the checkpoint at step 780 | ||
| ckp_path = checkpoint_root_dir + "global_step_780/actor/" | ||
hiyuchang marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| model.load_state_dict(load_state_dict_from_verl_checkpoint(ckp_path)) | ||
| ``` | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.