Mixed Precision Training #336
-
In the example at https://github.com/google/flax/tree/main/examples/wmt, the model dtype is specified as bfloat16 (if running on TPU): https://github.com/google/flax/blob/2ac765a5c056dc57bcaa70ba5e0bd2f4933d2ed0/examples/wmt/train.py#L455

However, the optimiser states are entirely in float32, matching the dtype of the params (float32) and grads (float32): https://github.com/google/flax/blob/2ac765a5c056dc57bcaa70ba5e0bd2f4933d2ed0/examples/wmt/train.py#L497-L508

cc @marcvanzee
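(For context, here is a minimal sketch of the behaviour in question, using a hypothetical single-layer module rather than the WMT Transformer: in Flax linen, the `dtype` argument sets the computation dtype, while the parameters themselves are still initialised in float32.)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Any

# Hypothetical toy model (not the WMT Transformer): `dtype` controls the dtype
# of the computation, while parameters are initialised in float32 by default.
class TinyDense(nn.Module):
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=4, dtype=self.dtype)(x)

variables = TinyDense().init(jax.random.PRNGKey(0), jnp.ones((1, 8), jnp.float32))
print(jax.tree_util.tree_map(lambda p: p.dtype, variables))
# All parameter leaves are float32, even though computation runs in bfloat16.
```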
-
The Adam optimiser, or one of its commonly used variants such as AdamW, holds two accumulators (mu and nu) per model weight. For a model of size 1GB, this results in an optimiser state of size 2GB. Thus, before even loading a batch of training data, 3GB of memory is already required on the accelerator device, of which 2/3 comes from the optimiser. This can limit the maximum permissible per-device batch size, and naturally raises the question of how one might reduce the memory requirements of the optimiser.

Of course, one might switch to an alternative optimiser that holds only one accumulator per weight and factors the second-order estimate, such as Adafactor, but the factored estimate of the second-order moments might degrade training results compared to the unfactored approach taken in Adam.

An alternative approach is mixed precision training. Here, the model weights are kept in full precision (float32), but the optimiser states are stored in half precision (bfloat16). Since each of the two accumulators then requires half as many bytes as the model parameter to which it is bound, the optimiser memory is halved: for a full precision model of size 1GB, the optimiser now only occupies 1GB. Training entirely in half precision, in which the model parameters are also kept in half precision, is generally not advised, as it leads to poor numerical stability. Consequently, for mixed precision training, the parameters should be kept in full precision and the optimiser states in half precision. However, it is less clear what data types (dtypes) are appropriate for the other variables in a train step.
Option 1:
Option 2:
Here is a Colab that implements both options for a simple network consisting of a single linear layer: https://drive.google.com/file/d/1xqq24YP_MPrf2j3u2i0MOypCDZOOpQJK/view?usp=sharing
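For illustration, here is a minimal sketch of the general idea, using Optax rather than the flax.optim API from the linked example; the toy single-layer setup and the `to_bf16` helper are assumptions for this sketch, not code from the Colab. Parameters and gradients stay in float32, while the Adam accumulators are stored in bfloat16 between steps.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical toy parameters, kept in full precision (float32).
params = {"kernel": jnp.zeros((512, 512), jnp.float32)}
tx = optax.adam(1e-3)

def to_bf16(tree):
    # Cast float32 leaves (the mu/nu accumulators) to bfloat16;
    # integer leaves such as the step count are left untouched.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, tree)

opt_state = to_bf16(tx.init(params))  # optimiser state stored in bfloat16

def loss_fn(p, x):
    return jnp.mean((x @ p["kernel"]) ** 2)

@jax.jit
def train_step(params, opt_state, x):
    grads = jax.grad(loss_fn)(params, x)                 # float32 grads
    updates, new_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)    # params stay float32
    return new_params, to_bf16(new_state)                # re-cast accumulators

params, opt_state = train_step(params, opt_state, jnp.ones((8, 512), jnp.float32))
```

Note that Optax's `adam` also accepts a `mu_dtype` argument for storing the first moment in reduced precision; the sketch above simply casts the whole state by hand so that both accumulators end up in bfloat16.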