
Commit 41a1ec0

alexdremov and winglian authored
Plugins create_lr_scheduler support (axolotl-ai-cloud#2584)
* lr_scheduler support
* fix
* Update scheduler.py
* Update scheduler.py
* cfg handling
* black
* remove debug
* remove adding the axolotl cfg to the scheduler mixin

---------

Co-authored-by: Wing Lian <[email protected]>
1 parent ecac731 commit 41a1ec0

4 files changed: +40 -18 lines changed

src/axolotl/core/trainers/mixins/scheduler.py

Lines changed: 15 additions & 5 deletions

@@ -3,9 +3,10 @@
 import logging
 
 import torch
-from torch.optim.lr_scheduler import OneCycleLR
+from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
 from transformers.trainer import Trainer
 
+from axolotl.integrations.base import PluginManager
 from axolotl.utils.schedulers import (
     RexLR,
     get_cosine_schedule_with_min_lr,
@@ -25,9 +26,9 @@ class SchedulerMixin(Trainer):
 
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
-    ):
+    ) -> LRScheduler:
         """
-        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
         passed as an argument.
 
         Args:
@@ -47,7 +48,16 @@ def create_scheduler(
         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
             # fmt: on
-            if self.args.alternate_lr_scheduler_type == "one_cycle":
+            plugin_manager = PluginManager.get_instance()
+            lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
+                trainer=self,
+                optimizer=optimizer,
+                num_training_steps=num_training_steps
+            )
+            if lr_scheduler is not None:
+                LOG.info(f"Using plugin-created lr_scheduler: {lr_scheduler}")
+                self.lr_scheduler = lr_scheduler
+            elif self.args.alternate_lr_scheduler_type == "one_cycle":
                 num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
                 pct_start = num_warmup_steps / num_training_steps
                 extra_lr_kwargs = {}
@@ -110,4 +120,4 @@ def create_scheduler(
         if use_cosine_min_lr:
             LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
 
-        return self.lr_scheduler
+        return self.lr_scheduler  # type: ignore
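
For illustration, the consumer side of this hook could look like the sketch below: a minimal plugin that supplies its own scheduler. It is not part of this commit; the class name, the use_cosine_plugin_scheduler config key, and the choice of CosineAnnealingLR are assumptions, and only the create_lr_scheduler signature follows the BasePlugin definition added here. Returning None hands control back to the trainer's built-in scheduler logic.

# Hypothetical plugin sketch; class name, config key, and scheduler choice are
# assumptions -- only the hook signature matches this commit.
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler

from axolotl.integrations.base import BasePlugin


class CosineSchedulerPlugin(BasePlugin):
    def create_lr_scheduler(
        self, cfg, trainer, optimizer, num_training_steps
    ) -> LRScheduler | None:
        # Only take over scheduling when explicitly enabled in the config;
        # returning None lets other plugins or the built-in schedulers run.
        if not cfg.get("use_cosine_plugin_scheduler"):
            return None
        return CosineAnnealingLR(optimizer, T_max=num_training_steps)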

src/axolotl/core/trainers/relora.py

Lines changed: 8 additions & 5 deletions

@@ -1,6 +1,7 @@
 """Module for ReLoRA trainer"""
 
 import torch
+from torch.optim.lr_scheduler import LRScheduler
 
 from axolotl.core.trainers.base import AxolotlTrainer
 from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -19,9 +20,11 @@ def create_scheduler(
         self,
         num_training_steps: int,
         optimizer: torch.optim.Optimizer | None = None,
-    ):
+    ) -> LRScheduler:
         optimizer = self.optimizer if optimizer is None else optimizer
-        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
+        lr_scheduler: LRScheduler = super().create_scheduler(
+            num_training_steps, optimizer
+        )
 
         if self.args.relora_steps:
             warmup_steps = (
@@ -30,14 +33,14 @@ def create_scheduler(
             anneal_steps = (
                 self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
             )
-            self.lr_scheduler = ReLoRAScheduler(
+            self.lr_scheduler = ReLoRAScheduler(  # type: ignore
                 optimizer,
                 lr_scheduler,
                 self.args.relora_steps,
                 anneal_steps,
                 warmup_steps,
            )
         else:
-            self.lr_scheduler = lr_scheduler
+            self.lr_scheduler = lr_scheduler  # type: ignore
 
-        return self.lr_scheduler
+        return self.lr_scheduler  # type: ignore

src/axolotl/integrations/base.py

Lines changed: 15 additions & 6 deletions

@@ -24,6 +24,7 @@
 from typing import OrderedDict
 
 import torch
+from torch.optim.lr_scheduler import LRScheduler
 
 
 class BasePlugin:
@@ -41,7 +42,7 @@ class BasePlugin:
    post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
    post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
    create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
-   create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
+   create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
    add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
    add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
    """
@@ -146,18 +147,19 @@ def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
         """
 
     def create_lr_scheduler(
-        self, cfg, trainer, optimizer
-    ):  # pylint: disable=unused-argument
+        self, cfg, trainer, optimizer, num_training_steps
+    ) -> LRScheduler | None:  # pylint: disable=unused-argument
         """
         Creates and returns a learning rate scheduler.
 
         Parameters:
            cfg (dict): The configuration for the plugin.
            trainer (object): The trainer object for training.
            optimizer (object): The optimizer for training.
+           num_training_steps (int): Total number of training steps
 
         Returns:
-           object: The created learning rate scheduler.
+           object (LRScheduler): The created learning rate scheduler.
         """
 
     def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
@@ -436,7 +438,9 @@ def create_optimizer(self, trainer):
                 return optimizer
         return None
 
-    def create_lr_scheduler(self, trainer, optimizer):
+    def create_lr_scheduler(
+        self, trainer, optimizer, num_training_steps
+    ) -> LRScheduler | None:
         """
         Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
 
@@ -448,7 +452,12 @@ def create_lr_scheduler(self, trainer, optimizer):
            object: The created learning rate scheduler, or None if none was found.
         """
         for plugin in self.plugins.values():
-            scheduler = plugin.create_lr_scheduler(self.cfg, trainer, optimizer)
+            scheduler: LRScheduler | None = plugin.create_lr_scheduler(
+                self.cfg,
+                trainer=trainer,
+                optimizer=optimizer,
+                num_training_steps=num_training_steps,
+            )
             if scheduler is not None:
                 return scheduler
         return None
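
A quick, self-contained way to see the dispatch rule described in the PluginManager docstring (the first plugin to return a non-None scheduler wins): the toy plugin classes and the standalone loop below are hypothetical stand-ins, not axolotl code, and only mirror the loop in create_lr_scheduler above.

# Toy illustration of "first non-None scheduler wins"; these plugin classes
# are made up and only mimic the PluginManager loop shown above.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR


class NoSchedulerPlugin:
    def create_lr_scheduler(self, cfg, trainer, optimizer, num_training_steps):
        return None  # this plugin declines, so the loop keeps going


class LinearSchedulerPlugin:
    def create_lr_scheduler(self, cfg, trainer, optimizer, num_training_steps):
        return LinearLR(optimizer, total_iters=num_training_steps)


optimizer = SGD([torch.zeros(1, requires_grad=True)], lr=0.1)

scheduler = None
for plugin in (NoSchedulerPlugin(), LinearSchedulerPlugin()):
    scheduler = plugin.create_lr_scheduler(
        cfg={}, trainer=None, optimizer=optimizer, num_training_steps=100
    )
    if scheduler is not None:
        break

print(type(scheduler).__name__)  # LinearLR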

tests/e2e/integrations/test_hooks.py

Lines changed: 2 additions & 2 deletions

@@ -72,7 +72,7 @@ def get_trainer_cls(self, cfg):  # pylint: disable=unused-argument
             f.write("get_trainer_cls\n")
 
     def create_lr_scheduler(
-        self, cfg, trainer, optimizer
+        self, cfg, trainer, optimizer, num_training_steps
     ):  # pylint: disable=unused-argument
         with open(
             self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -172,7 +172,7 @@ def test_plugin_hooks(self, temp_dir):
         assert "post_model_load" in file_contents
         # assert "create_optimizer" in file_contents  # not implemented yet
         assert "get_trainer_cls" in file_contents
-        # assert "create_lr_scheduler" in file_contents  # not implemented yet
+        assert "create_lr_scheduler" in file_contents
         assert "add_callbacks_pre_trainer" in file_contents
         assert "add_callbacks_post_trainer" in file_contents
         assert "post_train" in file_contents
