Note that layerwise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel); you can therefore run the training script only on a single GPU. Please see [the appropriate section of the GaLore README](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping and DeepSpeed might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issues.
### APOLLO
Approximated Gradient Scaling for Memory Efficient LLM Optimization (APOLLO) is a memory-efficient training strategy that allows full-parameter learning for both pre-training and fine-tuning, while maintaining AdamW-level performance with SGD-like memory efficiency.
* **Ultra-low rank efficiency** → Requires a much lower rank than GaLore; even rank 1 (APOLLO-Mini) suffices.
* **No expensive SVD computations** → Unlike GaLore, APOLLO leverages random projection, avoiding training stalls.
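To make the second point concrete, here is an illustrative sketch (not APOLLO's actual implementation; the shapes and rank are made up) contrasting an SVD-derived projector with a random one:

```python
import numpy as np

# Illustrative contrast: GaLore-style projectors come from an SVD of the
# gradient matrix, while APOLLO-style projectors are sampled at random.
rng = np.random.default_rng(0)
grad = rng.standard_normal((512, 512))  # a mock gradient matrix
rank = 4

# SVD-based projector: requires a full decomposition of the gradient
U, _, _ = np.linalg.svd(grad, full_matrices=False)
svd_proj = U[:, :rank]                  # (512, 4)

# Random projector: a single Gaussian draw, no decomposition needed
rand_proj = rng.standard_normal((512, rank)) / np.sqrt(rank)

low_rank_grad = rand_proj.T @ grad      # (4, 512) compressed gradient state
print(low_rank_grad.shape)
```

The random draw costs one matrix sample, whereas the SVD requires decomposing the full gradient matrix, which is why periodic projector refreshes can stall training in SVD-based methods.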
You can read more about the method in the [original repository](https://github.com/zhuhanqing/APOLLO) or the paper [APOLLO: SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
First, make sure to install APOLLO from its official repository:
```bash
pip install apollo-torch
```
Then, APOLLO optimizers can be used simply by setting `optim="apollo_adamw"` and specifying `optim_target_modules`.
`optim_target_modules` can be a list of strings, regexes, or full paths corresponding to the target module names you want to adapt.
Currently, only Linear layers are handled by the APOLLO optimizer, i.e., those matched by `optim_target_modules`; the remaining parameters are still optimized with AdamW.
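As an illustrative sketch of how regex patterns select modules (the matching below uses Python's `re` directly; the module names are made up, and the exact matching logic inside Transformers may differ):

```python
import re

# Patterns like those passed to optim_target_modules
patterns = [r".*.attn.*", r".*.mlp.*"]

# Hypothetical module names from a causal LM
module_names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.down_proj",
    "model.embed_tokens",
]

# Only names matching a pattern would be routed to APOLLO;
# everything else stays on AdamW.
targeted = [n for n in module_names if any(re.fullmatch(p, n) for p in patterns)]
print(targeted)
```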
You can also enable layer-wise APOLLO by appending "layerwise" to the optimizer name (`optim="apollo_adamw_layerwise"`), the same as with layer-wise GaLore. This saves additional gradient memory by performing weight updates layer by layer.
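A minimal sketch of the corresponding training arguments (the output directory and target-module patterns are illustrative):

```python
from transformers import TrainingArguments

# Layer-wise APOLLO: single-GPU only, like layer-wise GaLore.
args = TrainingArguments(
    output_dir="./test-apollo-layerwise",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw_layerwise",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
)
```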
Below is an end-to-end example script (make sure to `pip install trl datasets`):
```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
)

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```
You can further customize APOLLO’s behavior by passing hyperparameters using `optim_args`.
| Parameter | Description |
|------------------|-------------|
| `rank` | Rank of the auxiliary sub-space used for gradient scaling. <br> **APOLLO (default=256)** → Works well for 1B and 7B models. <br> **APOLLO-Mini (default=1)** |
| `scale_type` | How scaling factors are applied. <br> **`channel`** → Per-channel scaling (used in APOLLO). <br> **`tensor`** → Per-tensor scaling (used in APOLLO-Mini). |
| `scale` | Adjusts gradient updates to stabilize training. <br> **APOLLO (default=1.0)** <br> **APOLLO-Mini (default=128)** |
| `update_proj_gap` | Steps between projection-matrix updates. Default: **200**. |
| `proj` | Type of projection. Default: **`random`**. |
<Tip>
The `scale` parameter can be set to `n/r`, where `n` is the original space dimension and `r` is the low-rank space dimension.
Alternatively, you can achieve a similar effect by adjusting the learning rate, while keeping scale at its default value.
</Tip>
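As a concrete instance of that rule of thumb (the dimensions here are made up):

```python
# scale ≈ n / r, following the tip above
n = 2048  # illustrative original-space dimension
r = 1     # APOLLO-Mini's default rank
scale = n / r
print(scale)  # → 2048.0
```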
For example, you can enable APOLLO-Mini (rank=1 for extreme memory efficiency) by passing `optim_args`:
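A sketch of such a configuration, with the `optim_args` string assembled from the APOLLO-Mini defaults in the table above (the output directory and target-module patterns are illustrative):

```python
from transformers import TrainingArguments

# APOLLO-Mini: rank-1 random projection with tensor-wise scaling.
args = TrainingArguments(
    output_dir="./test-apollo-mini",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    optim_args="proj=random,rank=1,scale=128.0,scale_type=tensor,update_proj_gap=200",
)
```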
### LOMO optimizer

The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).