133 changes: 133 additions & 0 deletions docs/source/en/trainer.md
@@ -445,6 +445,139 @@ trainer.train()

Note that layerwise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel); you can therefore run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.


### APOLLO

Approximated Gradient Scaling for Memory Efficient LLM Optimization (APOLLO) is a memory-efficient low-rank training strategy that allows full-parameter learning for both pre-training and fine-tuning, while maintaining AdamW-level performance with SGD-like memory efficiency.

* **Ultra-low rank efficiency** → Requires a much lower rank than GaLore; even rank 1 (APOLLO-Mini) suffices (a configuration sketch is shown further below).
* **No expensive SVD computations** → Unlike GaLore, APOLLO uses random projections, avoiding the training stalls caused by expensive SVD computations.

First, make sure to install the official `apollo-torch` package:

```bash
pip install apollo-torch
```

Then set `optim` to one of `["apollo_adamw", "apollo_adamw_layerwise"]` together with `optim_target_modules`, which can be a list of strings, regexes, or full paths corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-apollo",
max_steps=100,
per_device_train_batch_size=2,
optim="apollo_adamw",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```

To pass extra arguments supported by APOLLO, pass them through `optim_args`, for example:

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    optim_args="proj=random,scale_type=tensor,rank=128,update_proj_gap=100,scale=1.0",
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```
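
The keys accepted in `optim_args` are `rank`, `proj`, `scale_type`, `update_proj_gap`, `scale`, and `proj_type`; when omitted they fall back to the defaults from the official repository (`rank=128`, `proj=random`, `scale_type=channel`, `update_proj_gap=200`, `scale=1.0`, `proj_type=std`). As a sketch of the rank-1 APOLLO-Mini variant mentioned above, you can drop the rank to 1 and switch to tensor-wise scaling; the `scale` value below is only illustrative, so check the official repository for tuned settings:

```python
from transformers import TrainingArguments

# Sketch of an APOLLO-Mini-style setup: rank-1 random projection with tensor-wise
# gradient scaling. The `scale` value is illustrative, not a recommended setting.
args = TrainingArguments(
    output_dir="./test-apollo-mini",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    optim_args="proj=random,scale_type=tensor,rank=1,update_proj_gap=200,scale=128.0",
)
```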

Currently, only Linear layers are optimized with APOLLO, while the remaining modules are still optimized with AdamW.
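
If you prefer not to enumerate module names, `optim_target_modules` also accepts the `all-linear` shorthand (handled by the Trainer integration in this PR), which targets every linear layer. A minimal sketch:

```python
from transformers import TrainingArguments

# Minimal sketch: apply APOLLO to every linear layer via the "all-linear" shorthand.
args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules="all-linear",
)
```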

You can read more about the method in the [original repository](https://github.com/zhuhanqing/APOLLO) or the [paper](https://arxiv.org/abs/2412.05270).


You can also perform layer-wise APOLLO by simply appending `layerwise` to the optimizer name, as shown below. Note that, as with layer-wise GaLore, layer-wise APOLLO does not support DDP or gradient accumulation, so it can only be run on a single GPU with `gradient_accumulation_steps=1`:

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-apollo",
max_steps=100,
per_device_train_batch_size=2,
optim="apollo_adamw_layerwise",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```


### LOMO optimizer

The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -62,6 +62,7 @@
GGUF_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_apollo_torch_available,
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
@@ -403,6 +404,14 @@ def require_galore_torch(test_case):
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


def require_apollo_torch(test_case):
"""
    Decorator marking a test that requires APOLLO. These tests are skipped when APOLLO isn't installed.
https://github.com/zhuhanqing/APOLLO
"""
return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)


def require_lomo(test_case):
"""
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
114 changes: 114 additions & 0 deletions src/transformers/trainer.py
@@ -151,6 +151,7 @@
find_labels,
is_accelerate_available,
is_apex_available,
is_apollo_torch_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
@@ -1582,6 +1583,119 @@ def optimizer_hook(param):

if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif args.optim in [
OptimizerNames.APOLLO_ADAMW,
OptimizerNames.APOLLO_ADAMW_LAYERWISE,
]:
if not is_apollo_torch_available():
                raise ImportError(
                    "You need to install `apollo_torch` in order to use APOLLO optimizers. You can install it"
                    " with `pip install apollo-torch`, or from source with `pip install git+https://github.com/zhuhanqing/APOLLO`."
                )
from apollo_torch import APOLLOAdamW

is_layerwise = args.optim.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
raise NotImplementedError("Layer-wise APOLLO does not support DDP at this time")

optimizer_mapping = {
OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
}

optimizer_cls = optimizer_mapping[args.optim]

if args.optim_target_modules is None:
                raise ValueError(
                    "You need to define `optim_target_modules` in order to properly use APOLLO optimizers"
                )

if not isinstance(args.optim_target_modules, (list, str)):
                raise ValueError(
                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, a specific module name, or 'all-linear'; you passed {args.optim_target_modules}."
                )

if model is None:
                raise ValueError("You need to pass a model in order to correctly initialize an APOLLO optimizer.")

all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)
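            # `all_linear` is True when the user passed the "all-linear" (or "all_linear") shorthand,
            # meaning every nn.Linear layer is targeted by APOLLO.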

apollo_params = []
apollo_params_names = []
for module_name, module in model.named_modules():
target_module_exists, is_regex = check_target_module_exists(
args.optim_target_modules, module_name, return_is_regex=True
)

if not isinstance(module, nn.Linear):
# Warn in case we match but it's not a linear layer
if target_module_exists and not is_regex:
logger.warning(
f"{module_name} has been matched but ignored as APOLLO only supports linear layers. Please double check your `optim_target_modules`!"
)

continue

if not target_module_exists and not all_linear:
continue

apollo_params.append(module.weight)
apollo_params_names.append(module_name + ".weight")

if len(apollo_params) == 0:
                raise ValueError(
                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `optim_target_modules`."
                )

non_apollo_params = [p for n, p in model.named_parameters() if n not in apollo_params_names]
apollo_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"proj": optim_args.pop("proj", "random"),
"scale_type": optim_args.pop("scale_type", "channel"),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 1.0)),
"proj_type": optim_args.pop("proj_type", "std"),
}

            # The default values for the APOLLO arguments above are taken from the official repository: https://github.com/zhuhanqing/APOLLO
param_groups = [
{"params": non_apollo_params},
{"params": apollo_params, **apollo_optim_kwargs},
]

if is_layerwise:
# For layer-wise optimizers, the optimization step is done through post accumulation
# gradient hooks. The trick is to first attach these hooks to the model parameters then
# create a dummy optimizer that will perform no-ops in the Trainer.
# See the original implementation or the nice implementation from @hiyouga
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
if args.gradient_accumulation_steps != 1:
                    raise ValueError("Layerwise APOLLO optimizers do not support gradient accumulation!")

optimizer_dict = {}
for param in non_apollo_params:
param_groups = [{"params": [param]}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
for param in apollo_params:
param_groups = [{"params": [param], **apollo_optim_kwargs}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)

def optimizer_hook(param):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()

for param in model.parameters():
if param.requires_grad:
param.register_post_accumulate_grad_hook(optimizer_hook)

optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})

optimizer_kwargs.update({"params": param_groups})
elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
if not is_lomo_available():
raise ImportError(
2 changes: 2 additions & 0 deletions src/transformers/training_args.py
@@ -184,6 +184,8 @@ class OptimizerNames(ExplicitEnum):
GROKADAMW = "grokadamw"
SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
SCHEDULE_FREE_SGD = "schedule_free_sgd"
APOLLO_ADAMW = "apollo_adamw"
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"


# Sometimes users will pass in a `str` repr of a dict in the CLI
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -118,6 +118,7 @@
get_torch_version,
is_accelerate_available,
is_apex_available,
is_apollo_torch_available,
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -98,6 +98,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[

_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_apollo_torch_available = _is_package_available("apollo_torch")
_aqlm_available = _is_package_available("aqlm")
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
_av_available = importlib.util.find_spec("av") is not None
@@ -402,6 +403,10 @@ def is_galore_torch_available():
return _galore_torch_available


def is_apollo_torch_available():
return _apollo_torch_available


def is_lomo_available():
return _lomo_available
