Commit 25e6c5f

Add CAME Optimizer (axolotl-ai-cloud#2385)

1 parent 32f51bc

6 files changed (+66, -0 lines)

docs/config.qmd (1 addition, 0 deletions)

@@ -612,6 +612,7 @@ lr_div_factor: # Learning rate div factor
 # - optimi_adamw
 # - ao_adamw_8bit
 # - ao_adamw_fp8
+# - came_pytorch
 optimizer:
 # Dictionary of arguments to pass to the optimizer
 optim_args:
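
For orientation, a hedged sketch (not part of this commit) of the CAME-related configuration keys, written as a plain Python dict rather than the YAML file axolotl actually reads. The key names come from this diff; the values are the fallback defaults used in trainer_builder.py further down.

    # Hypothetical illustration only: CAME-related keys as a Python dict.
    # In practice these would live in the axolotl YAML config.
    came_config = {
        "optimizer": "came_pytorch",
        "adam_beta1": 0.9,       # default beta1
        "adam_beta2": 0.999,     # default beta2
        "adam_beta3": 0.9999,    # third CAME beta (new field in this commit)
        "adam_epsilon": 1e-30,   # first CAME epsilon
        "adam_epsilon2": 1e-16,  # second CAME epsilon (new field in this commit)
    }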

setup.py (1 addition, 0 deletions)

@@ -142,6 +142,7 @@ def get_package_version():
             "apollo-torch",
             "lomo-optim==0.1.1",
             "torch-optimi==0.2.1",
+            "came_pytorch==0.1.3",
         ],
         "ray": [
             "ray[train]",

src/axolotl/core/trainer_builder.py (14 additions, 0 deletions)

@@ -708,6 +708,20 @@ def build(self, total_num_steps):
             optimizer_cls = ADOPT
             adam_kwargs["decouple"] = True
             optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "came_pytorch":
+            from came_pytorch import CAME
+
+            optimizer_cls = CAME
+
+            beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
+            beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
+            beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
+            eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
+            eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
+            adam_kwargs["betas"] = (beta1, beta2, beta3)
+            adam_kwargs["eps"] = (eps1, eps2)
+
+            optimizer_kwargs.update(adam_kwargs)
 
         # Parse any additional optimizer args from config
         if self.cfg.optim_args:
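
To make the branch above concrete, here is a minimal standalone sketch of the optimizer it ends up building, assuming the CAME constructor in came_pytorch==0.1.3 accepts a three-element betas tuple and a two-element eps tuple (the same fallback defaults mirrored above). The linear module and learning rate are placeholders, not anything from this commit.

    # Hedged sketch: construct CAME directly with the fallback defaults used in
    # trainer_builder.py. Constructor signature assumed from came_pytorch==0.1.3.
    import torch
    from came_pytorch import CAME

    model = torch.nn.Linear(16, 4)  # stand-in for the fine-tuned model
    optimizer = CAME(
        model.parameters(),
        lr=1e-5,
        betas=(0.9, 0.999, 0.9999),  # beta1, beta2, beta3
        eps=(1e-30, 1e-16),          # eps1, eps2
    )

    # CAME behaves like any torch.optim.Optimizer: one toy update step.
    loss = model(torch.randn(2, 16)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()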

src/axolotl/utils/schemas/enums.py (1 addition, 0 deletions)

@@ -53,4 +53,5 @@ class CustomSupportedOptimizers(str, Enum):
     ao_adamw_8bit = "ao_adamw_8bit"  # pylint: disable=invalid-name
     ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
     adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
+    came_pytorch = "came_pytorch"  # pylint: disable=invalid-name
     muon = "muon"  # pylint: disable=invalid-name
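
For illustration, a trimmed stand-in (hypothetical class name, only a few members) of what the new enum member enables: "came_pytorch" now round-trips through the string enum used for optimizer validation.

    # Simplified stand-in for the real CustomSupportedOptimizers enum.
    from enum import Enum

    class SupportedOptimizers(str, Enum):
        adopt_adamw = "adopt_adamw"
        came_pytorch = "came_pytorch"  # member added by this commit
        muon = "muon"

    # Lookup by value succeeds once the member exists.
    assert SupportedOptimizers("came_pytorch") is SupportedOptimizers.came_pytorch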

src/axolotl/utils/schemas/training.py (2 additions, 0 deletions)

@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
     lr_groups: list[LrGroup] | None = None
 
     adam_epsilon: float | None = None
+    adam_epsilon2: float | None = None
     adam_beta1: float | None = None
     adam_beta2: float | None = None
+    adam_beta3: float | None = None
     max_grad_norm: float | None = None
     num_epochs: float = Field(default=1.0)
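
A small hedged sketch (hypothetical model name, only the relevant fields) of how the two new optional hyperparameters behave under pydantic validation; fields left unset stay None, so existing configs are unaffected.

    # Trimmed-down illustration; the real HyperparametersConfig has many more fields.
    from pydantic import BaseModel

    class CameHyperparams(BaseModel):
        adam_epsilon: float | None = None
        adam_epsilon2: float | None = None  # second CAME epsilon (new)
        adam_beta1: float | None = None
        adam_beta2: float | None = None
        adam_beta3: float | None = None     # third CAME beta (new)

    print(CameHyperparams(adam_beta3=0.9999, adam_epsilon2=1e-16))
    # -> adam_epsilon=None adam_epsilon2=1e-16 adam_beta1=None adam_beta2=None adam_beta3=0.9999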

tests/e2e/test_optimizers.py (47 additions, 0 deletions)

@@ -199,3 +199,50 @@ def test_fft_schedule_free_adamw(self, temp_dir):
 
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
+
+    @with_temp_dir
+    def test_came_pytorch(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "came_pytorch",
+                "adam_beta3": 0.9999,
+                "adam_epsilon2": 1e-16,
+                "max_steps": 5,
+                "lr_scheduler": "cosine",
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)
