Description
huggingface/transformers#42805 drastically changed the default behaviour of loading models with AutoModelForCausalLM.from_pretrained.
Following up on fixes and suggestions in #4770:
Longer-term: TRL should provide training-oriented defaults
[...]
However, it looks like the load dtype often follows the model dtype, which can implicitly put users/tests into fp16/bf16 without intent:
[...]
- Make the default load dtype fp32: when the user passes a model ID
[...]
The key idea is: we should not end up training in the model dtype unless it's intentional, especially in tests that are not meant to validate this specific (and likely unstable) case.
Originally posted by @qgallouedec in #4770 (comment)
It would be nice if the documentation were updated accordingly for the cases where model initialization is handled by the user rather than by TRL. In practice, that means everywhere AutoModelForCausalLM.from_pretrained is called without an explicit dtype argument, for example on the Training Customization docs page:
trl/docs/source/customization.md, line 19 at 4fea6d1:

```python
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
```
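A possible shape of the fix is to pass an explicit dtype at load time, so the training dtype is intentional rather than inherited from the checkpoint. The sketch below uses a tiny randomly initialized GPT-2 config (an illustrative stand-in, not the docs' model) so it runs without downloading weights; with a real model ID the same keyword goes to `from_pretrained`. `torch_dtype` is the long-standing parameter name; recent transformers releases also accept `dtype`.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# With a real checkpoint, the fix would look like:
#   model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
# Here we build a tiny random config instead, so the sketch runs offline.
config = AutoConfig.for_model(
    "gpt2", n_embd=8, n_head=2, n_layer=1, vocab_size=16, n_positions=8
)

# Requesting fp32 explicitly means we never silently train in the
# checkpoint's half-precision dtype.
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
print(model.dtype)
```

The same explicit-dtype pattern would apply to every `from_pretrained` call in the docs.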
The documentation should also warn users that, due to this change in transformers, they may unintentionally end up training entirely in fp16/bf16, which can hurt training stability and convergence.
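To illustrate the stability concern numerically, here is a small plain-PyTorch sketch (no assumptions about TRL or transformers internals): gradient-scale values underflow to zero in fp16, and bf16 loses precision on small parameter updates.

```python
import torch

# fp16 has a tiny dynamic range: values below ~6e-8 underflow to zero,
# which silently kills small gradients.
g = 1e-8
print(torch.tensor(g, dtype=torch.float32).item())  # survives in fp32
print(torch.tensor(g, dtype=torch.float16).item())  # 0.0 (underflow)

# bf16 keeps fp32's range but only ~8 mantissa bits, so a small update
# to a parameter near 1.0 can be rounded away entirely.
print(torch.tensor(1.001, dtype=torch.bfloat16).item())  # rounds to 1.0
```

This is why an unintended half-precision load can degrade convergence even when nothing visibly errors out.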