Skip to content

Commit cd84325

Browse files
authored
allow plugins to return their own dataset (axolotl-ai-cloud#2617) [skip ci]
* allow plugins to return their own dataset * add post_trainer_create and wire up * add hook check * address PR feedback: * remove annotation causing circular import
1 parent 0b140fe commit cd84325

File tree

5 files changed

+129
-50
lines changed

5 files changed

+129
-50
lines changed

src/axolotl/cli/preprocess.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from axolotl.cli.config import load_cfg
1919
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
2020
from axolotl.common.datasets import load_datasets, load_preference_datasets
21+
from axolotl.integrations.base import PluginManager
2122
from axolotl.utils.dict import DictDefault
2223
from axolotl.utils.trainer import disable_datasets_caching
2324

@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
4748
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
4849

4950
with disable_datasets_caching():
50-
if cfg.rl:
51+
plugin_manager = PluginManager.get_instance()
52+
if plugin_manager.load_datasets(cfg, preprocess=True):
53+
pass
54+
elif cfg.rl:
5155
load_preference_datasets(cfg=cfg, cli_args=cli_args)
5256
else:
5357
load_datasets(cfg=cfg, cli_args=cli_args)

src/axolotl/cli/train.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
4343
if int(os.getenv("LOCAL_RANK", "0")) == 0:
4444
check_user_token()
4545

46-
if cfg.rl:
47-
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
48-
else:
49-
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
46+
plugin_manager = PluginManager.get_instance()
47+
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
48+
if not dataset_meta:
49+
if cfg.rl:
50+
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
51+
else:
52+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
5053

5154
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
5255

src/axolotl/integrations/base.py

Lines changed: 107 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import torch
2727
from torch.optim.lr_scheduler import LRScheduler
2828

29+
from axolotl.utils.dict import DictDefault
30+
2931

3032
class BasePlugin:
3133
"""
@@ -36,11 +38,13 @@ class BasePlugin:
3638
3739
Methods:
3840
register(cfg): Registers the plugin with the given configuration.
41+
load_datasets(cfg): Loads and preprocesses the dataset for training.
3942
pre_model_load(cfg): Performs actions before the model is loaded.
4043
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
4144
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
4245
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
4346
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
47+
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
4448
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
4549
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
4650
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -63,20 +67,32 @@ def register(self, cfg): # pylint: disable=unused-argument
6367
None
6468
"""
6569

66-
def get_input_args(self):
70+
def get_input_args(self) -> str | None:
6771
"""
6872
Returns a pydantic model for the plugin's input arguments.
6973
"""
7074

75+
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
76+
"""
77+
Loads and preprocesses the dataset for training.
78+
79+
Args:
80+
cfg: The configuration for the plugin.
81+
preprocess: Whether this is the preprocess step of the datasets.
82+
83+
Returns:
84+
dataset_meta: The metadata for the training dataset.
85+
"""
86+
7187
def pre_model_load(self, cfg): # pylint: disable=unused-argument
7288
"""
7389
Performs actions before the model is loaded.
7490
75-
Parameters:
76-
cfg (dict): The configuration for the plugin.
91+
Args:
92+
cfg (dict): The configuration for the plugin.
7793
7894
Returns:
79-
None
95+
None
8096
"""
8197

8298
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
@@ -91,59 +107,71 @@ def post_model_load(self, cfg, model): # pylint: disable=unused-argument
91107
"""
92108
Performs actions after the model is loaded.
93109
94-
Parameters:
95-
cfg (dict): The configuration for the plugin.
96-
model (object): The loaded model.
110+
Args:
111+
cfg (dict): The configuration for the plugin.
112+
model (object): The loaded model.
97113
98114
Returns:
99-
None
115+
None
100116
"""
101117

102118
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
103119
"""
104120
Performs actions before LoRA weights are loaded.
105121
106-
Parameters:
107-
cfg (dict): The configuration for the plugin.
108-
model (object): The loaded model.
122+
Args:
123+
cfg (dict): The configuration for the plugin.
124+
model (object): The loaded model.
109125
110126
Returns:
111-
None
127+
None
112128
"""
113129

114130
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
115131
"""
116132
Performs actions after LoRA weights are loaded.
117133
118-
Parameters:
119-
cfg (dict): The configuration for the plugin.
120-
model (object): The loaded model.
134+
Args:
135+
cfg (dict): The configuration for the plugin.
136+
model (object): The loaded model.
121137
122138
Returns:
123-
None
139+
None
124140
"""
125141

126142
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
127143
"""
128144
Returns a custom class for the trainer.
129145
130-
Parameters:
131-
cfg (dict): The global axolotl configuration.
146+
Args:
147+
cfg (dict): The global axolotl configuration.
148+
149+
Returns:
150+
class: The class for the trainer.
151+
"""
152+
153+
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
154+
"""
155+
Performs actions after the trainer is created.
156+
157+
Args:
158+
cfg (dict): The configuration for the plugin.
159+
trainer (object): The trainer object for training.
132160
133161
Returns:
134-
class: The class for the trainer.
162+
None
135163
"""
136164

137165
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
138166
"""
139167
Creates and returns an optimizer for training.
140168
141-
Parameters:
142-
cfg (dict): The configuration for the plugin.
143-
trainer (object): The trainer object for training.
169+
Args:
170+
cfg (dict): The configuration for the plugin.
171+
trainer (object): The trainer object for training.
144172
145173
Returns:
146-
object: The created optimizer.
174+
object: The created optimizer.
147175
"""
148176

149177
def create_lr_scheduler(
@@ -152,26 +180,26 @@ def create_lr_scheduler(
152180
"""
153181
Creates and returns a learning rate scheduler.
154182
155-
Parameters:
156-
cfg (dict): The configuration for the plugin.
157-
trainer (object): The trainer object for training.
158-
optimizer (object): The optimizer for training.
159-
num_training_steps (int): Total number of training steps
183+
Args:
184+
cfg (dict): The configuration for the plugin.
185+
trainer (object): The trainer object for training.
186+
optimizer (object): The optimizer for training.
187+
num_training_steps (int): Total number of training steps
160188
161189
Returns:
162-
object (LRScheduler): The created learning rate scheduler.
190+
object (LRScheduler): The created learning rate scheduler.
163191
"""
164192

165193
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
166194
"""
167195
setup callbacks before creating the trainer.
168196
169-
Parameters:
170-
cfg (dict): The configuration for the plugin.
171-
model (object): The loaded model.
197+
Args:
198+
cfg (dict): The configuration for the plugin.
199+
model (object): The loaded model.
172200
173201
Returns:
174-
List[callable]: A list of callback functions to be added to the TrainingArgs
202+
List[callable]: A list of callback functions to be added to the TrainingArgs
175203
"""
176204
return []
177205

@@ -182,36 +210,36 @@ def add_callbacks_post_trainer(
182210
Adds callbacks to the trainer after creating the trainer.
183211
This is useful for callbacks that require access to the model or trainer.
184212
185-
Parameters:
186-
cfg (dict): The configuration for the plugin.
187-
trainer (object): The trainer object for training.
213+
Args:
214+
cfg (dict): The configuration for the plugin.
215+
trainer (object): The trainer object for training.
188216
189217
Returns:
190-
List[callable]: A list of callback functions to be added
218+
List[callable]: A list of callback functions to be added
191219
"""
192220
return []
193221

194222
def post_train(self, cfg, model): # pylint: disable=unused-argument
195223
"""
196224
Performs actions after training is complete.
197225
198-
Parameters:
199-
cfg (dict): The axolotl configuration
200-
model (object): The loaded model.
226+
Args:
227+
cfg (dict): The axolotl configuration
228+
model (object): The loaded model.
201229
202230
Returns:
203-
None
231+
None
204232
"""
205233

206234
def post_train_unload(self, cfg): # pylint: disable=unused-argument
207235
"""
208236
Performs actions after training is complete and the model is unloaded.
209237
210-
Parameters:
211-
cfg (dict): The configuration for the plugin.
238+
Args:
239+
cfg (dict): The configuration for the plugin.
212240
213241
Returns:
214-
None
242+
None
215243
"""
216244

217245

@@ -338,6 +366,27 @@ def get_input_args(self):
338366
input_args.append(input_args_from_plugin)
339367
return input_args
340368

369+
def load_datasets(self, cfg, preprocess: bool = False):
370+
"""
371+
Calls the load_datasets method of each registered plugin.
372+
373+
Args:
374+
cfg: The configuration for the plugins.
375+
preprocess : Whether this is preprocess step of the datasets.
376+
377+
Returns:
378+
dataset_meta: The dataset metadata loaded from all registered plugins.
379+
"""
380+
return_ds_meta = None
381+
for plugin in self.plugins.values():
382+
dataset_meta = plugin.load_datasets(cfg, preprocess)
383+
if dataset_meta is not None:
384+
if return_ds_meta is None:
385+
return_ds_meta = dataset_meta
386+
else:
387+
raise RuntimeError("Multiple plugins loaded datasets")
388+
return return_ds_meta
389+
341390
def pre_model_load(self, cfg):
342391
"""
343392
Calls the pre_model_load method of all registered plugins.
@@ -422,6 +471,20 @@ def get_trainer_cls(self, cfg):
422471
return trainer_cls
423472
return None
424473

474+
def post_trainer_create(self, cfg, trainer):
475+
"""
476+
Calls the post_trainer_create method of all registered plugins.
477+
478+
Parameters:
479+
cfg (dict): The configuration for the plugins.
480+
trainer (object): The trainer object for training.
481+
482+
Returns:
483+
None
484+
"""
485+
for plugin in self.plugins.values():
486+
plugin.post_trainer_create(cfg, trainer)
487+
425488
def create_optimizer(self, trainer):
426489
"""
427490
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

src/axolotl/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,9 @@ def train(
528528
processor,
529529
) = setup_model_and_trainer(cfg, dataset_meta)
530530

531+
plugin_manager = PluginManager.get_instance()
532+
plugin_manager.post_trainer_create(cfg, trainer)
533+
531534
# Handle untrained tokens if configured
532535
safe_serialization = cfg.save_safetensors is True
533536
train_dataset = dataset_meta.train_dataset
@@ -550,7 +553,6 @@ def train(
550553
if not cfg.use_ray:
551554
cleanup_distributed()
552555

553-
plugin_manager = PluginManager.get_instance()
554556
plugin_manager.post_train(cfg, model)
555557

556558
return model, tokenizer, trainer

tests/e2e/integrations/test_hooks.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def __init__(self):
2929
except FileNotFoundError:
3030
pass
3131

32+
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
33+
with open(
34+
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
35+
) as f:
36+
f.write("post_trainer_create\n")
37+
3238
def pre_model_load(self, cfg): # pylint: disable=unused-argument
3339
with open(
3440
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -165,6 +171,7 @@ def test_plugin_hooks(self, temp_dir):
165171
) as f:
166172
file_contents = f.readlines()
167173
file_contents = "\n".join(file_contents)
174+
assert "post_trainer_create" in file_contents
168175
assert "pre_model_load" in file_contents
169176
assert "post_model_build" in file_contents
170177
assert "pre_lora_load" in file_contents

0 commit comments

Comments
 (0)