
Commit 08c4959

Optim: APOLLO optimizer integration (#36062)
* Added APOLLO optimizer integration
* Fix comment
* Remove redundancy: modularize low-rank optimizer construction
* Remove redundancy: remove useless comment
* Fix comment: add typing
* Fix comment: rewrite APOLLO description
1 parent 2440512 commit 08c4959

File tree: 7 files changed, +403 -98 lines changed


docs/source/en/trainer.md

Lines changed: 91 additions & 0 deletions
@@ -443,6 +443,97 @@ trainer.train()
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.

### APOLLO

Approximated Gradient Scaling for Memory-Efficient LLM Optimization (APOLLO) is a memory-efficient training strategy that allows full-parameter learning for both pre-training and fine-tuning, while maintaining AdamW-level performance with SGD-like memory efficiency.

* **Ultra-low rank efficiency** → Requires a much lower rank than GaLore; even rank 1 (APOLLO-Mini) suffices.
* **No expensive SVD computations** → Unlike GaLore, APOLLO leverages random projection, avoiding training stalls.

You can read more about the method in the [original repository](https://github.com/zhuhanqing/APOLLO) or in the paper [APOLLO: SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
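
To build intuition, here is a rough, self-contained sketch of the core idea: compress the gradient with a fixed random projection, keep Adam-style statistics only in that low-rank space, and apply the resulting per-channel scaling back to the full gradient. This is an illustration only, not the `apollo-torch` implementation, and the helper name `approx_channel_scaling` is made up for this example:

```python
import torch

def approx_channel_scaling(grad, proj, exp_avg, exp_avg_sq, beta1=0.9, beta2=0.999, eps=1e-8):
    # grad: full-rank gradient (n, m); proj: fixed random projection (r, n) with r << n
    low_rank_grad = proj @ grad  # compressed gradient, shape (r, m)
    # Adam-style moments are kept only in the (r, m) space, hence the SGD-like memory footprint
    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1 - beta2)
    low_rank_update = exp_avg / (exp_avg_sq.sqrt() + eps)
    # estimate a per-channel (column-wise) scaling factor in the compressed space ...
    scale = low_rank_update.norm(dim=0) / (low_rank_grad.norm(dim=0) + eps)
    return grad * scale  # ... and apply it to the raw full-rank gradient

# Toy usage: the optimizer state is (r, m) instead of (n, m)
n, m, r = 1024, 512, 1
proj = torch.randn(r, n) / n**0.5
exp_avg, exp_avg_sq = torch.zeros(r, m), torch.zeros(r, m)
scaled_grad = approx_channel_scaling(torch.randn(n, m), proj, exp_avg, exp_avg_sq)
```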

First, make sure to install APOLLO from its official repository:

```bash
pip install apollo-torch
```

Then, the APOLLO optimizers can be used simply by setting `optim="apollo_adamw"` and specifying `optim_target_modules`, which can be a list of strings, regexes, or full paths corresponding to the target module names you want to adapt. Currently, only `nn.Linear` layers matched by `optim_target_modules` are handled by the APOLLO optimizer; the remaining parameters are still optimized with AdamW.
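
If you prefer not to enumerate module names, the target-module matching (shared with GaLore) also accepts the `"all-linear"` shortcut, which selects every linear layer in the model. A minimal sketch, with a placeholder output directory:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./test-apollo",          # placeholder path
    optim="apollo_adamw",
    optim_target_modules="all-linear",   # target every nn.Linear layer instead of matching names
)
```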

You can also enable layer-wise APOLLO by appending `"layerwise"` to the optimizer name (`optim="apollo_adamw_layerwise"`), just like layer-wise GaLore. This saves additional gradient memory by performing the weight updates layer by layer. Note that, as with layer-wise GaLore, this mode does not support DDP or gradient accumulation.

Below is an end-to-end example script (make sure to `pip install trl datasets`):

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
)

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```

You can further customize APOLLO's behavior by passing hyperparameters using `optim_args`.

| Parameter | Description |
|------------------|-------------|
| `rank` | Rank of the auxiliary sub-space used for gradient scaling. <br> **APOLLO (default=256)** → Works well for 1B and 7B models. <br> **APOLLO-Mini (default=1)** |
| `scale_type` | How scaling factors are applied. <br> **`channel`** → Per-channel scaling (used in APOLLO). <br> **`tensor`** → Per-tensor scaling (used in APOLLO-Mini). |
| `scale` | Adjusts gradient updates to stabilize training. <br> **APOLLO (default=1.0)** <br> **APOLLO-Mini (default=128)** |
| `update_proj_gap` | Number of steps between updates of the projection matrices. Default: **200**. |
| `proj` | Type of projection. Default: **`random`**. |

<Tip>

The `scale` parameter can be set to `n/r`, where `n` is the dimension of the original space and `r` is the rank of the low-rank space. Alternatively, you can achieve a similar effect by adjusting the learning rate while keeping `scale` at its default value.

</Tip>

For example, you can enable APOLLO-Mini (rank=1 for extreme memory efficiency) by passing `optim_args`:

```python
args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    optim_args="proj=random,rank=1,scale=128.0,scale_type=tensor,update_proj_gap=200",
)
```
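
For reference, the `optim_args` string is parsed by the `Trainer` into keyword arguments for the APOLLO optimizer. A simplified sketch of that mapping, mirroring the defaults used in `get_optimizer_cls_and_kwargs` (the string parsing itself is condensed here for illustration):

```python
optim_args = "proj=random,rank=1,scale=128.0,scale_type=tensor,update_proj_gap=200"
parsed = dict(kv.split("=") for kv in optim_args.replace(" ", "").split(","))

# Defaults below mirror the APOLLO branch added to `get_optimizer_cls_and_kwargs` in this commit.
apollo_optim_kwargs = {
    "rank": int(parsed.pop("rank", 128)),
    "proj": parsed.pop("proj", "random"),
    "scale_type": parsed.pop("scale_type", "channel"),
    "update_proj_gap": int(parsed.pop("update_proj_gap", 200)),
    "scale": float(parsed.pop("scale", 1.0)),
    "proj_type": parsed.pop("proj_type", "std"),
}
print(apollo_optim_kwargs)
```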

### LOMO optimizer

The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).

src/transformers/testing_utils.py

Lines changed: 9 additions & 0 deletions
@@ -62,6 +62,7 @@
     GGUF_MIN_VERSION,
     is_accelerate_available,
     is_apex_available,
+    is_apollo_torch_available,
     is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,
@@ -404,6 +405,14 @@ def require_galore_torch(test_case):
     return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
 
 
+def require_apollo_torch(test_case):
+    """
+    Decorator marking a test that requires APOLLO. These tests are skipped when APOLLO isn't installed.
+    https://github.com/zhuhanqing/APOLLO
+    """
+    return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
+
+
 def require_lomo(test_case):
     """
     Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.

src/transformers/trainer.py

Lines changed: 128 additions & 93 deletions
@@ -151,6 +151,7 @@
     find_labels,
     is_accelerate_available,
     is_apex_available,
+    is_apollo_torch_available,
     is_bitsandbytes_available,
     is_datasets_available,
     is_galore_torch_available,
@@ -1315,6 +1316,103 @@ def get_optimizer_cls_and_kwargs(
             "betas": (args.adam_beta1, args.adam_beta2),
             "eps": args.adam_epsilon,
         }
+
+        def setup_low_rank_optimizer(
+            optimizer_name: str,
+            optimizer_mapping: Dict[str, Any],
+            optim_kwargs: Dict[str, Any],
+            is_layerwise_supported: bool = True,
+        ) -> Tuple[Any, Any]:
+            """
+            Helper function to set up low-rank optimizers like GaLore and Apollo.
+
+            Args:
+                optimizer_name (str): Name of the optimizer.
+                optimizer_mapping (dict): Mapping of optimizer names to their classes.
+                optim_kwargs (dict): Keyword arguments for the optimizer.
+                is_layerwise_supported (bool): Whether layerwise optimization is supported.
+
+            Returns:
+                Tuple[Any, Any]: Optimizer class and updated optimizer kwargs.
+            """
+            is_layerwise = optimizer_name.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported:
+                raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time")
+
+            optimizer_cls = optimizer_mapping[optimizer_name]
+
+            if args.optim_target_modules is None:
+                raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers")
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise ValueError(
+                    f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. Got: {args.optim_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.")
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            target_params = []
+            target_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} matched but ignored. {optimizer_name} only supports linear layers."
+                        )
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                target_params.append(module.weight)
+                target_params_names.append(module_name + ".weight")
+
+            if len(target_params) == 0:
+                raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).")
+
+            non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names]
+            optim_kwargs.update(optim_args)
+
+            param_groups = [
+                {"params": non_target_params},
+                {"params": target_params, **optim_kwargs},
+            ]
+
+            if is_layerwise:
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!")
+
+                optimizer_dict = {}
+                for param in non_target_params:
+                    optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs)
+                for param in target_params:
+                    optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if param.grad is not None:
+                        optimizer_dict[param].step()
+                        optimizer_dict[param].zero_grad()
+
+                for param in model.parameters():
+                    if param.requires_grad:
+                        param.register_post_accumulate_grad_hook(optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
+            return optimizer_cls, optimizer_kwargs
+
         if args.optim == OptimizerNames.ADAFACTOR:
             optimizer_cls = Adafactor
             optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
@@ -1476,10 +1574,6 @@ def get_optimizer_cls_and_kwargs(
             )
             from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
 
-            is_layerwise = args.optim.lower().endswith("layerwise")
-            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
-                raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
-
             optimizer_mapping = {
                 OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
                 OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
@@ -1489,105 +1583,46 @@ def get_optimizer_cls_and_kwargs(
                 OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
             }
 
-            optimizer_cls = optimizer_mapping[args.optim]
-
-            if args.optim_target_modules is None:
-                raise ValueError(
-                    "You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
-                )
-
-            if not isinstance(args.optim_target_modules, (list, str)):
-                raise ValueError(
-                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
-                )
-
-            if model is None:
-                raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
-
-            logger.warning(
-                "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
-            )
-
-            all_linear = (
-                isinstance(args.optim_target_modules, str)
-                and args.optim_target_modules.replace("_", "-") == "all-linear"
-            )
-
-            galore_params = []
-            galore_params_names = []
-            for module_name, module in model.named_modules():
-                target_module_exists, is_regex = check_target_module_exists(
-                    args.optim_target_modules, module_name, return_is_regex=True
-                )
-
-                if not isinstance(module, nn.Linear):
-                    # Warn in case we match but it's not a linear layer
-                    if target_module_exists and not is_regex:
-                        logger.warning(
-                            f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
-                        )
-
-                    continue
-
-                if not target_module_exists and not all_linear:
-                    continue
-
-                galore_params.append(module.weight)
-                galore_params_names.append(module_name + ".weight")
-
-            if len(galore_params) == 0:
-                raise ValueError(
-                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
-                )
-
-            non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
-
             galore_optim_kwargs = {
                 "rank": int(optim_args.pop("rank", 128)),
                 "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
                 "scale": float(optim_args.pop("scale", 0.25)),
                 "proj_type": optim_args.pop("proj_type", "std"),
             }
 
-            # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
-            param_groups = [
-                {"params": non_galore_params},
-                {"params": galore_params, **galore_optim_kwargs},
-            ]
-
-            if is_layerwise:
-                # For layer-wise optimizers, the optimization step is done through post accumulation
-                # gradient hooks. The trick is to first attach these hooks to the model parameters then
-                # create a dummy optimizer that will perform no-ops in the Trainer.
-                # See the original implementation or the nice implementation from @hiyouga
-                # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
-                if args.gradient_accumulation_steps != 1:
-                    raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
-
-                optimizer_dict = {}
-                for param in non_galore_params:
-                    param_groups = [{"params": [param]}]
-                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
-                for param in galore_params:
-                    param_groups = [{"params": [param], **galore_optim_kwargs}]
-                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
-
-                def optimizer_hook(param):
-                    if param.grad is not None:
-                        optimizer_dict[param].step()
-                        optimizer_dict[param].zero_grad()
-
-                for param in model.parameters():
-                    if param.requires_grad:
-                        param.register_post_accumulate_grad_hook(optimizer_hook)
+            optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
+                args.optim, optimizer_mapping, galore_optim_kwargs
+            )
+            if args.optim == OptimizerNames.GALORE_ADAFACTOR:
+                optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim in [
+            OptimizerNames.APOLLO_ADAMW,
+            OptimizerNames.APOLLO_ADAMW_LAYERWISE,
+        ]:
+            if not is_apollo_torch_available():
+                raise ImportError(
+                    "You need to install `apollo_torch` in order to use APOLLO optimizers"
+                    " install it with `pip install git+https://github.com/zhuhanqing/APOLLO`"
+                )
+            from apollo_torch import APOLLOAdamW
 
-                optimizer_cls = LayerWiseDummyOptimizer
-                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+            optimizer_mapping = {
+                OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
+                OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
+            }
 
-            optimizer_kwargs.update({"params": param_groups})
+            apollo_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 128)),
+                "proj": optim_args.pop("proj", "random"),
+                "scale_type": optim_args.pop("scale_type", "channel"),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 1.0)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+            }
 
-            if args.optim == OptimizerNames.GALORE_ADAFACTOR:
-                optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+            optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
+                args.optim, optimizer_mapping, apollo_optim_kwargs
+            )
         elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
             if not is_lomo_available():
                 raise ImportError(
