
Commit 170cdb5

Authored by divyanshuaggarwal, winglian, and djsaunde
Add post_model_load, post_lora_load, post_train, post_train_unload function calls (axolotl-ai-cloud#2539)

* Update train.py: add post_model_load and post_lora_load model calls
* Update train.py: add post_train and post_train_unload function calls
* Update train.py
* Update base.py
* Update train.py
* chore: lint
* clarify plugin hooks
* Update src/axolotl/integrations/base.py (Co-authored-by: Dan Saunders <[email protected]>)
* Update src/axolotl/utils/models.py (Co-authored-by: Dan Saunders <[email protected]>)
* Update src/axolotl/utils/models.py (Co-authored-by: Dan Saunders <[email protected]>)
* Update src/axolotl/integrations/base.py (Co-authored-by: Dan Saunders <[email protected]>)
* Update models.py
* Update models.py
* remove extra call to post_model_load
* chore: lint
* add test for hooks and gc trainer
* disable duplicated code check for test
* fix the path and add better handling

Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: Dan Saunders <[email protected]>
1 parent 5d182a1 commit 170cdb5
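For orientation, plugins are enabled by listing their import path in the training config; the hooks added in this commit are then dispatched through PluginManager. A minimal, hypothetical config fragment mirroring the e2e test added below:

from axolotl.utils.dict import DictDefault

# hypothetical fragment; only the "plugins" key matters for hook dispatch
cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "plugins": [
            "my_package.my_plugins.MyPlugin",  # import path to a BasePlugin subclass
        ],
    }
)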

File tree

* src/axolotl/cli/train.py
* src/axolotl/integrations/base.py
* src/axolotl/train.py
* src/axolotl/utils/models.py
* tests/e2e/integrations/test_hooks.py
* tests/e2e/test_lora_llama.py

6 files changed: +245 −10 lines changed


src/axolotl/cli/train.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 """CLI to run training on a model."""

+import gc
 import logging
 import os
 from pathlib import Path
@@ -48,8 +49,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
     dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

     model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
+
     del model, tokenizer, trainer

+    gc.collect()
+
     plugin_manager = PluginManager.get_instance()
     plugin_manager.post_train_unload(cfg)
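The explicit gc.collect() means the deleted model, tokenizer, and trainer are actually reclaimed before post_train_unload fires, so a plugin can treat that hook as running in a mostly-unloaded process. A minimal sketch of such a plugin (the class and its cache-clearing behaviour are illustrative assumptions, not part of this change):

import torch

from axolotl.integrations.base import BasePlugin


class UnloadCleanupPlugin(BasePlugin):
    """Hypothetical plugin that tidies up once training artifacts are unloaded."""

    def post_train_unload(self, cfg):  # pylint: disable=unused-argument
        # do_train() has already deleted the model/trainer and called gc.collect(),
        # so emptying the CUDA cache here releases the freed blocks back to the driver.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()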

src/axolotl/integrations/base.py

Lines changed: 38 additions & 2 deletions
@@ -36,9 +36,10 @@ class BasePlugin:
     Methods:
         register(cfg): Registers the plugin with the given configuration.
         pre_model_load(cfg): Performs actions before the model is loaded.
-        post_model_load(cfg, model): Performs actions after the model is loaded.
+        post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
         pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
         post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
+        post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
         create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
         create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
         add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -77,6 +78,14 @@ def pre_model_load(self, cfg):  # pylint: disable=unused-argument
             None
         """

+    def post_model_build(self, cfg, model):  # pylint: disable=unused-argument
+        """
+        Performs actions after the model is built/loaded, but before any adapters are applied.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+        """
+
     def post_model_load(self, cfg, model):  # pylint: disable=unused-argument
         """
         Performs actions after the model is loaded.
@@ -329,9 +338,22 @@ def pre_model_load(self, cfg):
         for plugin in self.plugins.values():
             plugin.pre_model_load(cfg)

+    def post_model_build(self, cfg, model):
+        """
+        Calls the post_model_build method of all registered plugins after the model has been built/loaded,
+        but before any adapters have been applied.
+
+        Args:
+            cfg (dict): The configuration for the plugins.
+            model (object): The loaded model.
+        """
+        for plugin in self.plugins.values():
+            plugin.post_model_build(cfg, model)
+
     def post_model_load(self, cfg, model):
         """
-        Calls the post_model_load method of all registered plugins.
+        Calls the post_model_load method of all registered plugins after the model has been loaded,
+        inclusive of any adapters.

         Parameters:
             cfg (dict): The configuration for the plugins.
@@ -458,6 +480,20 @@ def add_callbacks_post_trainer(self, cfg, trainer):
             callbacks.extend(plugin_callbacks)
         return callbacks

+    def post_train(self, cfg, model):
+        """
+        Calls the post_train method of all registered plugins.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+        for plugin in self.plugins.values():
+            plugin.post_train(cfg, model)
+
     def post_train_unload(self, cfg):
         """
         Calls the post_train_unload method of all registered plugins.
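With these additions, BasePlugin exposes hooks across the full model lifecycle: build, adapter attach, load, train, and unload. Below is a minimal sketch of a downstream plugin overriding the new hooks; the class name and log messages are illustrative, only the hook signatures come from this commit.

import logging

from axolotl.integrations.base import BasePlugin

LOG = logging.getLogger(__name__)


class LifecycleLoggerPlugin(BasePlugin):
    """Hypothetical plugin that logs each of the newly added lifecycle hooks."""

    def post_model_build(self, cfg, model):
        # the base model exists here, but no LoRA adapters have been applied yet
        LOG.info("built %s", type(model).__name__)

    def post_lora_load(self, cfg, model):
        # adapters have just been attached
        LOG.info("LoRA adapters loaded")

    def post_model_load(self, cfg, model):
        # fully loaded model, adapters included
        LOG.info("model ready for training")

    def post_train(self, cfg, model):
        # training is done, but the model is still in memory
        LOG.info("training finished")

A plugin like this is enabled the same way as in the e2e test below, by listing its import path under plugins: in the training config.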

src/axolotl/train.py

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,7 @@
 from axolotl.core.trainers.mixins.sequence_parallel import (
     SequenceParallelContextManager,
 )
+from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
@@ -533,4 +534,7 @@ def train(
     if not cfg.use_ray:
         cleanup_distributed()

+    plugin_manager = PluginManager.get_instance()
+    plugin_manager.post_train(cfg, model)
+
     return model, tokenizer, trainer

src/axolotl/utils/models.py

Lines changed: 13 additions & 6 deletions
@@ -53,6 +53,7 @@
 )

 from axolotl.common.architectures import MOE_ARCH_BLOCK
+from axolotl.integrations.base import PluginManager
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -74,6 +75,7 @@
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant

 LOG = logging.getLogger(__name__)
+PLUGIN_MANAGER = PluginManager.get_instance()

 MULTIMODAL_AUTO_MODEL_MAPPING = {
     "mllama": MllamaForConditionalGeneration,
@@ -571,10 +573,8 @@ def apply_patches(self) -> None:
             patch_gemma3conditionalgeneration_forward()

         # load any patches from plugins
-        from axolotl.integrations.base import PluginManager

-        plugin_manager = PluginManager.get_instance()
-        plugin_manager.pre_model_load(self.cfg)
+        PLUGIN_MANAGER.pre_model_load(self.cfg)

         # monkey patch to allow additional Accelerator init kwargs
         if self.cfg.fp8:
@@ -1252,6 +1252,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:

         try:
             skip_move_to_device = self.build_model(qlora_fsdp)
+            PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
         except Exception as err:  # pylint: disable=broad-exception-caught
             LOG.exception(err)
             raise err
@@ -1331,6 +1332,8 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
             before_kbit_train_or_finetune=False,
         )

+        PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
+
         # ---------------------------------------------------------
         # load lora or adapter
         # ---------------------------------------------------------
@@ -1392,7 +1395,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
         gc.collect()
         torch.cuda.empty_cache()

-        # TODO resume_from_checkpoint handling
+        PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
         return self.model, lora_config


@@ -1427,9 +1430,13 @@ def load_adapter(model, cfg, adapter, inference=False):
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
     if adapter in ["lora", "qlora"]:
-        return load_lora(model, cfg, inference=inference)
+        model, lora_config = load_lora(model, cfg, inference=inference)
+        PLUGIN_MANAGER.post_lora_load(cfg, model)
+        return model, lora_config
     if adapter == "llama-adapter":
-        return load_llama_adapter(model, cfg)
+        model, lora_config = load_llama_adapter(model, cfg)
+        PLUGIN_MANAGER.post_lora_load(cfg, model)
+        return model, lora_config

     raise NotImplementedError(f"{adapter} peft adapter not available")
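Promoting the manager to a module-level PLUGIN_MANAGER assumes PluginManager.get_instance() returns a process-wide singleton, which is how every call site in this commit treats it. A short sketch of that assumption, plus the order in which the hooks now fire during ModelLoader.load_model():

from axolotl.integrations.base import PluginManager

# Assumption: get_instance() always hands back the same manager, so the constant
# in utils/models.py and the lookups in train.py / cli/train.py dispatch to the
# same set of registered plugins.
assert PluginManager.get_instance() is PluginManager.get_instance()

# Hook order during load_model(), per the diff above:
#   pre_model_load -> post_model_build -> pre_lora_load -> post_lora_load -> post_model_load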

tests/e2e/integrations/test_hooks.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@ (new file)

"""
e2e tests to make sure all the hooks are fired on the plugin
"""

import os
from pathlib import Path

from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists


class LogHooksPlugin(BasePlugin):
    """
    fixture to capture in a log file each hook that was fired
    """

    base_dir = Path("/tmp/axolotl-log-hooks")

    def __init__(self):
        self.base_dir.mkdir(parents=True, exist_ok=True)
        try:
            os.remove(self.base_dir.joinpath("plugin_hooks.log"))
        except FileNotFoundError:
            pass

    def pre_model_load(self, cfg):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("pre_model_load\n")

    def post_model_build(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_model_build\n")

    def pre_lora_load(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("pre_lora_load\n")

    def post_lora_load(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_lora_load\n")

    def post_model_load(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_model_load\n")

    def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("create_optimizer\n")

    def get_trainer_cls(self, cfg):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("get_trainer_cls\n")

    def create_lr_scheduler(
        self, cfg, trainer, optimizer
    ):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("create_lr_scheduler\n")

    def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("add_callbacks_pre_trainer\n")
        return []

    def add_callbacks_post_trainer(
        self, cfg, trainer
    ):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("add_callbacks_post_trainer\n")
        return []

    def post_train(self, cfg, model):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_train\n")

    def post_train_unload(self, cfg):  # pylint: disable=unused-argument
        with open(
            self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
        ) as f:
            f.write("post_train_unload\n")


class TestPluginHooks:
    """
    e2e tests to make sure all the hooks are fired during the training
    """

    def test_plugin_hooks(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "plugins": [
                    "tests.e2e.integrations.test_hooks.LogHooksPlugin",
                ],
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 1024,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.02,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "flash_attention": True,
                "bf16": "auto",
            }
        )

        cfg = validate_config(cfg)
        prepare_plugins(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)

        with open(
            "/tmp/axolotl-log-hooks" + "/plugin_hooks.log", "r", encoding="utf-8"
        ) as f:
            file_contents = f.readlines()
            file_contents = "\n".join(file_contents)
            assert "pre_model_load" in file_contents
            assert "post_model_build" in file_contents
            assert "pre_lora_load" in file_contents
            assert "post_lora_load" in file_contents
            assert "post_model_load" in file_contents
            # assert "create_optimizer" in file_contents  # not implemented yet
            assert "get_trainer_cls" in file_contents
            # assert "create_lr_scheduler" in file_contents  # not implemented yet
            assert "add_callbacks_pre_trainer" in file_contents
            assert "add_callbacks_post_trainer" in file_contents
            assert "post_train" in file_contents
            # assert "post_train_unload" in file_contents  # not called from test train call

        try:
            os.remove("/tmp/axolotl-log-hooks" + "/plugin_hooks.log")
        except FileNotFoundError:
            pass

tests/e2e/test_lora_llama.py

Lines changed: 2 additions & 2 deletions
@@ -48,13 +48,13 @@ def test_lora(self, temp_dir):
                 },
             ],
             "num_epochs": 1,
-            "micro_batch_size": 8,
+            "micro_batch_size": 2,
             "gradient_accumulation_steps": 1,
             "output_dir": temp_dir,
             "learning_rate": 0.00001,
             "optimizer": "adamw_torch_fused",
             "lr_scheduler": "cosine",
-            "max_steps": 20,
+            "max_steps": 5,
         }
     )
