diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md
index 67bb2ae4f594..96897948b1e8 100644
--- a/docs/source/en/trainer.md
+++ b/docs/source/en/trainer.md
@@ -443,6 +443,97 @@ trainer.train()
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
+### APOLLO
+
+Approximated Gradient Scaling for Memory Efficient LLM Optimization (APOLLO) is a memory-efficient training strategy that allows full-parameter learning for both pre-training and fine-tuning, while maintaining AdamW-level performance with SGD-like memory efficiency.
+
+* **Ultra-low rank efficiency** → Requires much lower rank than GaLore—even rank 1 (APOLLO-Mini) suffices.
+* **No expensive SVD computations** → Unlike GaLore, APOLLO leverages random projection, avoiding training stalls.
+
+You can read more about the method in the [original repository](https://github.com/zhuhanqing/APOLLO) or in the paper [APOLLO: SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
+
+First, make sure to install APOLLO from its official repository:
+
+```bash
+pip install apollo-torch
+```
+
+Then, APOLLO optimizers can be used simply by setting `optim="apollo_adamw"` and specifying `optim_target_modules`,
+which can be a list of strings, regexes, or full paths corresponding to the target module names you want to adapt.
+Currently, only `nn.Linear` layers matched by `optim_target_modules` are optimized with APOLLO, while the remaining parameters are still optimized with AdamW.
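+
+For instance (the module names below are illustrative and depend on your model architecture), `optim_target_modules` can be given as regex patterns or as the special value `"all-linear"`:
+
+```python
+# Match attention and MLP linear layers by regex (module names are model-specific):
+optim_target_modules = [r".*attn.*", r".*mlp.*"]
+
+# Or target every nn.Linear module in the model:
+optim_target_modules = "all-linear"
+```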
+
+
+You can also enable layer-wise APOLLO by appending `"layerwise"` to the optimizer name (`optim="apollo_adamw_layerwise"`), just like layer-wise GaLore. This saves additional gradient memory by performing weight updates layer by layer.
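+
+For instance, a minimal sketch showing only the arguments that change for the layer-wise variant (note that layer-wise optimization does not support gradient accumulation or DDP):
+
+```python
+from transformers import TrainingArguments
+
+args = TrainingArguments(
+    output_dir="./test-apollo",
+    optim="apollo_adamw_layerwise",  # layer-wise APOLLO
+    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
+    gradient_accumulation_steps=1,  # layer-wise updates do not support gradient accumulation
+)
+```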
+
+Below is an end-to-end example script (make sure to `pip install trl datasets`):
+
+```python
+import torch
+import datasets
+import trl
+
+from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
+
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+args = TrainingArguments(
+    output_dir="./test-apollo",
+    max_steps=100,
+    per_device_train_batch_size=2,
+    optim="apollo_adamw",
+    optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
+)
+
+model_id = "google/gemma-2b"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
+
+trainer = trl.SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    dataset_text_field='text',
+    max_seq_length=512,
+)
+
+trainer.train()
+```
+
+
+You can further customize APOLLO’s behavior by passing hyperparameters using `optim_args`.
+
+| Parameter | Description |
+|------------------|-------------|
+| `rank` | Rank of the auxiliary sub-space used for gradient scaling. <br> **APOLLO (default=256)** → works well for 1B and 7B models. <br> **APOLLO-Mini (default=1)** |
+| `scale_type` | How scaling factors are applied. <br> **`channel`** → per-channel scaling (used in APOLLO). <br> **`tensor`** → per-tensor scaling (used in APOLLO-Mini). |
+| `scale` | Adjusts gradient updates to stabilize training. <br> **APOLLO (default=1.0)** <br> **APOLLO-Mini (default=128)** |
+| `update_proj_gap` | Number of steps between updates of the projection matrices. Default: **200**. |
+| `proj` | Type of projection. Default: **`random`**. |
+
+The `scale` parameter can be set to `n/r`, where `n` is the original space dimension and `r` is the low-rank space dimension.
+Alternatively, you can achieve a similar effect by adjusting the learning rate, while keeping scale at its default value.
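+
+As a rough illustration (the dimensions below are assumed for the example, not taken from the APOLLO paper), a weight matrix with original dimension `n = 2048` trained with rank `r = 256` would give `scale = 2048 / 256 = 8.0`:
+
+```python
+# Hypothetical illustration of the scale = n / r heuristic (dimensions assumed):
+n, r = 2048, 256
+scale = n / r  # 8.0
+optim_args = f"rank={r},scale={scale}"  # -> "rank=256,scale=8.0"
+```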
+
+
+
+For example, you can enable APOLLO-Mini (rank=1 for extreme memory efficiency) by passing `optim_args`:
+
+```python
+args = TrainingArguments(
+    output_dir="./test-apollo",
+    max_steps=100,
+    per_device_train_batch_size=2,
+    optim="apollo_adamw",
+    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
+    optim_args="proj=random,rank=1,scale=128.0,scale_type=tensor,update_proj_gap=200",
+)
+```
+
### LOMO optimizer
The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 6d1965e29d79..a653162cffef 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -62,6 +62,7 @@
GGUF_MIN_VERSION,
is_accelerate_available,
is_apex_available,
+ is_apollo_torch_available,
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
@@ -403,6 +404,14 @@ def require_galore_torch(test_case):
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
+def require_apollo_torch(test_case):
+ """
+ Decorator marking a test that requires APOLLO. These tests are skipped when APOLLO isn't installed.
+ https://github.com/zhuhanqing/APOLLO
+ """
+ return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
+
+
def require_lomo(test_case):
"""
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5da300a90966..627d3d79cd34 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -152,6 +152,7 @@
find_labels,
is_accelerate_available,
is_apex_available,
+ is_apollo_torch_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
@@ -1310,6 +1311,103 @@ def get_optimizer_cls_and_kwargs(
"betas": (args.adam_beta1, args.adam_beta2),
"eps": args.adam_epsilon,
}
+
+ def setup_low_rank_optimizer(
+ optimizer_name: str,
+ optimizer_mapping: Dict[str, Any],
+ optim_kwargs: Dict[str, Any],
+ is_layerwise_supported: bool = True,
+ ) -> Tuple[Any, Any]:
+ """
+ Helper function to set up low-rank optimizers like GaLore and Apollo.
+
+ Args:
+ optimizer_name (str): Name of the optimizer.
+ optimizer_mapping (dict): Mapping of optimizer names to their classes.
+ optim_kwargs (dict): Keyword arguments for the optimizer.
+ is_layerwise_supported (bool): Whether layerwise optimization is supported.
+
+ Returns:
+ Tuple[Any, Any]: Optimizer class and updated optimizer kwargs.
+ """
+ is_layerwise = optimizer_name.lower().endswith("layerwise")
+ if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported:
+ raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time")
+
+ optimizer_cls = optimizer_mapping[optimizer_name]
+
+ if args.optim_target_modules is None:
+ raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers")
+
+ if not isinstance(args.optim_target_modules, (list, str)):
+ raise ValueError(
+ f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. Got: {args.optim_target_modules}"
+ )
+
+ if model is None:
+ raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.")
+
+ all_linear = (
+ isinstance(args.optim_target_modules, str)
+ and args.optim_target_modules.replace("_", "-") == "all-linear"
+ )
+
+ target_params = []
+ target_params_names = []
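+ # Collect the weights of matching nn.Linear modules; these go into the parameter group
+ # that receives the low-rank optimizer kwargs, while all remaining parameters are placed
+ # in a plain parameter group.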
+ for module_name, module in model.named_modules():
+ target_module_exists, is_regex = check_target_module_exists(
+ args.optim_target_modules, module_name, return_is_regex=True
+ )
+
+ if not isinstance(module, nn.Linear):
+ if target_module_exists and not is_regex:
+ logger.warning(
+ f"{module_name} matched but ignored. {optimizer_name} only supports linear layers."
+ )
+ continue
+
+ if not target_module_exists and not all_linear:
+ continue
+
+ target_params.append(module.weight)
+ target_params_names.append(module_name + ".weight")
+
+ if len(target_params) == 0:
+ raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).")
+
+ non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names]
+ optim_kwargs.update(optim_args)
+
+ param_groups = [
+ {"params": non_target_params},
+ {"params": target_params, **optim_kwargs},
+ ]
+
+ if is_layerwise:
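+ # For layer-wise optimizers, the optimization step is done through post-accumulation
+ # gradient hooks: each parameter gets its own optimizer instance, and a dummy optimizer
+ # (LayerWiseDummyOptimizer) performs no-ops in the Trainer loop.
+ # See the implementation from @hiyouga:
+ # https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba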
+ if args.gradient_accumulation_steps != 1:
+ raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!")
+
+ optimizer_dict = {}
+ for param in non_target_params:
+ optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs)
+ for param in target_params:
+ optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs)
+
+ def optimizer_hook(param):
+ if param.grad is not None:
+ optimizer_dict[param].step()
+ optimizer_dict[param].zero_grad()
+
+ for param in model.parameters():
+ if param.requires_grad:
+ param.register_post_accumulate_grad_hook(optimizer_hook)
+
+ optimizer_cls = LayerWiseDummyOptimizer
+ optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+ optimizer_kwargs.update({"params": param_groups})
+ return optimizer_cls, optimizer_kwargs
+
if args.optim == OptimizerNames.ADAFACTOR:
optimizer_cls = Adafactor
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
@@ -1471,10 +1569,6 @@ def get_optimizer_cls_and_kwargs(
)
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
- is_layerwise = args.optim.lower().endswith("layerwise")
- if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
- raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
-
optimizer_mapping = {
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
@@ -1484,59 +1578,6 @@ def get_optimizer_cls_and_kwargs(
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
}
- optimizer_cls = optimizer_mapping[args.optim]
-
- if args.optim_target_modules is None:
- raise ValueError(
- "You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
- )
-
- if not isinstance(args.optim_target_modules, (list, str)):
- raise ValueError(
- f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
- )
-
- if model is None:
- raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
-
- logger.warning(
- "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
- )
-
- all_linear = (
- isinstance(args.optim_target_modules, str)
- and args.optim_target_modules.replace("_", "-") == "all-linear"
- )
-
- galore_params = []
- galore_params_names = []
- for module_name, module in model.named_modules():
- target_module_exists, is_regex = check_target_module_exists(
- args.optim_target_modules, module_name, return_is_regex=True
- )
-
- if not isinstance(module, nn.Linear):
- # Warn in case we match but it's not a linear layer
- if target_module_exists and not is_regex:
- logger.warning(
- f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
- )
-
- continue
-
- if not target_module_exists and not all_linear:
- continue
-
- galore_params.append(module.weight)
- galore_params_names.append(module_name + ".weight")
-
- if len(galore_params) == 0:
- raise ValueError(
- f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
- )
-
- non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
-
galore_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
@@ -1544,45 +1585,39 @@ def get_optimizer_cls_and_kwargs(
"proj_type": optim_args.pop("proj_type", "std"),
}
- # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
- param_groups = [
- {"params": non_galore_params},
- {"params": galore_params, **galore_optim_kwargs},
- ]
-
- if is_layerwise:
- # For layer-wise optimizers, the optimization step is done through post accumulation
- # gradient hooks. The trick is to first attach these hooks to the model parameters then
- # create a dummy optimizer that will perform no-ops in the Trainer.
- # See the original implementation or the nice implementation from @hiyouga
- # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
- if args.gradient_accumulation_steps != 1:
- raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
-
- optimizer_dict = {}
- for param in non_galore_params:
- param_groups = [{"params": [param]}]
- optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
- for param in galore_params:
- param_groups = [{"params": [param], **galore_optim_kwargs}]
- optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
-
- def optimizer_hook(param):
- if param.grad is not None:
- optimizer_dict[param].step()
- optimizer_dict[param].zero_grad()
-
- for param in model.parameters():
- if param.requires_grad:
- param.register_post_accumulate_grad_hook(optimizer_hook)
+ optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
+ args.optim, optimizer_mapping, galore_optim_kwargs
+ )
+ if args.optim == OptimizerNames.GALORE_ADAFACTOR:
+ optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+ elif args.optim in [
+ OptimizerNames.APOLLO_ADAMW,
+ OptimizerNames.APOLLO_ADAMW_LAYERWISE,
+ ]:
+ if not is_apollo_torch_available():
+ raise ImportError(
+ "You need to install `apollo_torch` in order to use APOLLO optimizers"
+ " install it with `pip install git+https://github.com/zhuhanqing/APOLLO`"
+ )
+ from apollo_torch import APOLLOAdamW
- optimizer_cls = LayerWiseDummyOptimizer
- optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+ optimizer_mapping = {
+ OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
+ OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
+ }
- optimizer_kwargs.update({"params": param_groups})
+ apollo_optim_kwargs = {
+ "rank": int(optim_args.pop("rank", 128)),
+ "proj": optim_args.pop("proj", "random"),
+ "scale_type": optim_args.pop("scale_type", "channel"),
+ "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+ "scale": float(optim_args.pop("scale", 1.0)),
+ "proj_type": optim_args.pop("proj_type", "std"),
+ }
- if args.optim == OptimizerNames.GALORE_ADAFACTOR:
- optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+ optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
+ args.optim, optimizer_mapping, apollo_optim_kwargs
+ )
elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
if not is_lomo_available():
raise ImportError(
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 5bc31b616003..36c2224b210e 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -184,6 +184,8 @@ class OptimizerNames(ExplicitEnum):
GROKADAMW = "grokadamw"
SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
SCHEDULE_FREE_SGD = "schedule_free_sgd"
+ APOLLO_ADAMW = "apollo_adamw"
+ APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
# Sometimes users will pass in a `str` repr of a dict in the CLI
@@ -789,11 +791,10 @@ class TrainingArguments:
[original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
`PeftModel` from peft. The original paper used values in the range [5.0, 15.0].
optim_target_modules (`Union[str, List[str]]`, *optional*):
- The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
- https://arxiv.org/abs/2403.03507
- See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe
- optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
- only.
+ The target modules to optimize, i.e. the module names that you would like to train.
+ Currently used for the GaLore algorithm (https://arxiv.org/abs/2403.03507) and the APOLLO algorithm (https://arxiv.org/abs/2412.05270).
+ See the GaLore implementation (https://github.com/jiaweizzhao/GaLore) and the APOLLO implementation (https://github.com/zhuhanqing/APOLLO) for more details.
+ You need to make sure to pass a valid GaLore or APOLLO optimizer, e.g., one of: "apollo_adamw", "galore_adamw", "galore_adamw_8bit", "galore_adafactor", and make sure that the target modules are `nn.Linear` modules only.
batch_eval_metrics (`Optional[bool]`, defaults to `False`):
If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index e5aedf5916fa..4b226ef4dc24 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -118,6 +118,7 @@
get_torch_version,
is_accelerate_available,
is_apex_available,
+ is_apollo_torch_available,
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ac07281b3d33..9c7f710482ab 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -98,6 +98,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
+_apollo_torch_available = _is_package_available("apollo_torch")
_aqlm_available = _is_package_available("aqlm")
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
_av_available = importlib.util.find_spec("av") is not None
@@ -402,6 +403,10 @@ def is_galore_torch_available():
return _galore_torch_available
+def is_apollo_torch_available():
+ return _apollo_torch_available
+
+
def is_lomo_available():
return _lomo_available
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 3de94511fb8e..b2872fc83af8 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -66,6 +66,7 @@
get_tests_dir,
is_staging_test,
require_accelerate,
+ require_apollo_torch,
require_bitsandbytes,
require_deepspeed,
require_galore_torch,
@@ -2235,6 +2236,168 @@ def test_galore_lr_display_with_scheduler(self):
# warm up steps << total steps
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
+ @require_apollo_torch
+ @require_torch_gpu
+ def test_apollo(self):
+ config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+ tiny_llama = LlamaForCausalLM(config)
+ x = torch.randint(0, 100, (128,))
+ train_dataset = RepeatDataset(x)
+
+ # Trainer without inf/nan filter
+ args = TrainingArguments(
+ self.get_auto_remove_tmp_dir(),
+ learning_rate=1e-9,
+ logging_steps=5,
+ optim="apollo_adamw",
+ optim_target_modules=[r".*attn.*", r".*mlp.*"],
+ )
+ trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+ # Check this works
+ _ = trainer.train()
+
+ @require_apollo_torch
+ @require_torch_gpu
+ def test_apollo_extra_args(self):
+ config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+ tiny_llama = LlamaForCausalLM(config)
+ x = torch.randint(0, 100, (128,))
+ train_dataset = RepeatDataset(x)
+
+ # Trainer without inf/nan filter
+ args = TrainingArguments(
+ self.get_auto_remove_tmp_dir(),
+ learning_rate=1e-9,
+ logging_steps=5,
+ optim="apollo_adamw",
+ optim_args="proj=random,scale_type=tensor,rank=1,update_proj_gap=100,scale=128.0",
+ optim_target_modules=[r".*attn.*", r".*mlp.*"],
+ )
+ trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+ # Check this works
+ _ = trainer.train()
+
+ @require_apollo_torch
+ @require_torch_gpu
+ def test_apollo_layerwise(self):
+ config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+ tiny_llama = LlamaForCausalLM(config)
+ x = torch.randint(0, 100, (128,))
+ train_dataset = RepeatDataset(x)
+
+ # Trainer without inf/nan filter
+ args = TrainingArguments(
+ self.get_auto_remove_tmp_dir(),
+ learning_rate=1e-9,
+ logging_steps=5,
+ optim="apollo_adamw_layerwise",
+ optim_target_modules=[r".*attn.*", r".*mlp.*"],
+ )
+ trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+ # Check this works
+ _ = trainer.train()
+
+ @require_apollo_torch
+ @require_torch_gpu
+ def test_apollo_layerwise_with_scheduler(self):
+ config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+ tiny_llama = LlamaForCausalLM(config)
+ x = torch.randint(0, 100, (128,))
+ train_dataset = RepeatDataset(x)
+
+ # Trainer without inf/nan filter
+ args = TrainingArguments(
+ self.get_auto_remove_tmp_dir(),
+ learning_rate=1e-9,
+ logging_steps=5,
+ optim="apollo_adamw_layerwise",
+ lr_scheduler_type="cosine",
+ optim_target_modules=[r".*attn.*", r".*mlp.*"],
+ )
+ trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+ # Check this works
+ _ = trainer.train()
+
+ @require_apollo_torch
+ @require_torch_gpu
+ def test_apollo_lr_display_without_scheduler(self):
+ config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+ tiny_llama = LlamaForCausalLM(config)
+ x = torch.randint(0, 100, (128,))
+ train_dataset = RepeatDataset(x)
+
+ learning_rate = 1e-9
+ num_steps = 10
+
+ # Trainer without inf/nan filter
+ args = TrainingArguments(
+ self.get_auto_remove_tmp_dir(),
+ learning_rate=learning_rate,
+ logging_steps=5,
+ optim="apollo_adamw",
+ optim_target_modules=[r".*attn.*", r".*mlp.*"],
+ )
+ trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+ trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
+
+ # reflects displayed lr in trainer
+ self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
+
+ @require_apollo_torch
+ @require_torch_gpu
+ def test_apollo_lr_display_with_scheduler(self):
+ config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+ tiny_llama = LlamaForCausalLM(config)
+ x = torch.randint(0, 100, (128,))
+ train_dataset = RepeatDataset(x)
+
+ learning_rate = 2e-4
+ num_train_epochs = 10
+ num_warmup_steps = 5
+
+ # Trainer without inf/nan filter
+ args = TrainingArguments(
+ self.get_auto_remove_tmp_dir(),
+ num_train_epochs=num_train_epochs,
+ learning_rate=learning_rate,
+ warmup_steps=num_warmup_steps,
+ lr_scheduler_type="cosine",
+ logging_steps=1,
+ optim="apollo_adamw",
+ optim_target_modules=[r".*attn.*", r".*mlp.*"],
+ )
+ trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+ # creating log history of trainer, results don't matter
+ trainer.train()
+ logs = trainer.state.log_history[1:][:-1]
+
+ # reach given learning rate peak and end with 0 lr
+ self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
+ self.assertTrue(logs[-1]["learning_rate"] == 0)
+
+ # increasing and decreasing pattern of lrs
+ increasing_lrs = [
+ logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
+ for i in range(len(logs))
+ if i < num_warmup_steps - 2
+ ]
+ decreasing_lrs = [
+ logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
+ for i in range(len(logs) - 1)
+ if i >= num_warmup_steps - 2
+ ]
+
+ self.assertTrue(all(increasing_lrs))
+ self.assertTrue(all(decreasing_lrs))
+
+ # warm up steps << total steps
+ self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
+
@require_torch_multi_accelerator
def test_data_is_not_parallelized_when_model_is_parallel(self):
model = RegressionModel()