Commit a21b9cc

patch to convert LR from tensor to float when using DS (axolotl-ai-cloud#2595) [skip ci]
1 parent 41a1ec0 commit a21b9cc

2 files changed: 45 additions, 0 deletions

src/axolotl/core/trainer_builder.py

Lines changed: 3 additions & 0 deletions
@@ -60,6 +60,7 @@
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
+from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
 from axolotl.processing_strategies import get_processing_strategy
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
@@ -114,6 +115,8 @@ def __init__(self, cfg, model, tokenizer, processor=None):
         if hasattr(model, "add_model_tags"):
             model.add_model_tags(["axolotl"])
 
+        patch_trainer_get_lr()
+
     @property
     def model_ref(self):
         return self._model_ref
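
For context, a minimal sketch (not part of the commit, assumes axolotl and transformers are installed) of what this wiring achieves: calling patch_trainer_get_lr() once while the trainer builder is constructed swaps the method on the Trainer class itself, so every trainer instantiated afterwards picks up the patched _get_learning_rate.

# Sketch only; shows that the monkeypatch replaces the method on the
# Trainer class itself rather than on a single trainer instance.
from transformers.trainer import Trainer

from axolotl.monkeypatch.trainer.lr import _get_learning_rate, patch_trainer_get_lr

patch_trainer_get_lr()
assert Trainer._get_learning_rate is _get_learning_rate  # future trainers use the patched method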

src/axolotl/monkeypatch/trainer/lr.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+"""
+monkeypatch for Trainer _get_learning_rate method
+"""
+
+import logging
+
+import torch
+
+LOG = logging.getLogger(__name__)
+
+
+# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
+def _get_learning_rate(self):
+    if self.is_deepspeed_enabled:
+        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
+        # not run for the first few dozen steps while loss scale is too large, and thus during
+        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
+        try:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+        except AssertionError as e:
+            if "need to call step" in str(e):
+                LOG.warning(
+                    "tried to get lr value before scheduler/optimizer started stepping, returning lr=0"
+                )
+                last_lr = 0
+            else:
+                raise
+    else:
+        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            last_lr = self.optimizer.param_groups[0]["lr"]
+        else:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+
+    if torch.is_tensor(last_lr):
+        last_lr = last_lr.item()
+    return last_lr
+
+
+def patch_trainer_get_lr():
+    from transformers.trainer import Trainer
+
+    Trainer._get_learning_rate = _get_learning_rate  # pylint: disable=protected-access
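
As an illustration of the conversion this patch adds (a sketch with made-up values, not code from the commit): when DeepSpeed manages the schedule, the reported learning rate can come back as a 0-dim tensor, which plain-float consumers such as JSON log serialization reject; .item() turns it back into a Python float.

# Hypothetical repro of the symptom the patch fixes: a tensor LR breaks JSON logging.
import json

import torch

last_lr = torch.tensor(2e-5)  # LR as a 0-dim tensor, as can happen under DeepSpeed
if torch.is_tensor(last_lr):
    last_lr = last_lr.item()  # same conversion the patch applies -> plain Python float

print(json.dumps({"learning_rate": last_lr}))  # serializes cleanly once it is a float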
