26
26
import torch
27
27
from torch .optim .lr_scheduler import LRScheduler
28
28
29
+ from axolotl .utils .dict import DictDefault
30
+
29
31
30
32
class BasePlugin :
31
33
"""
@@ -36,11 +38,13 @@ class BasePlugin:
36
38
37
39
Methods:
38
40
register(cfg): Registers the plugin with the given configuration.
41
+ load_datasets(cfg): Loads and preprocesses the dataset for training.
39
42
pre_model_load(cfg): Performs actions before the model is loaded.
40
43
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
41
44
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
42
45
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
43
46
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
47
+ post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
44
48
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
45
49
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
46
50
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -63,20 +67,32 @@ def register(self, cfg): # pylint: disable=unused-argument
63
67
None
64
68
"""
65
69
66
- def get_input_args (self ):
70
+ def get_input_args (self ) -> str | None :
67
71
"""
68
72
Returns a pydantic model for the plugin's input arguments.
69
73
"""
70
74
75
+ def load_datasets (self , cfg : DictDefault , preprocess : bool = False ):
76
+ """
77
+ Loads and preprocesses the dataset for training.
78
+
79
+ Args:
80
+ cfg: The configuration for the plugin.
81
+ preprocess: Whether this is the preprocess step of the datasets.
82
+
83
+ Returns:
84
+ dataset_meta: The metadata for the training dataset.
85
+ """
86
+
71
87
def pre_model_load (self , cfg ): # pylint: disable=unused-argument
72
88
"""
73
89
Performs actions before the model is loaded.
74
90
75
- Parameters :
76
- cfg (dict): The configuration for the plugin.
91
+ Args :
92
+ cfg (dict): The configuration for the plugin.
77
93
78
94
Returns:
79
- None
95
+ None
80
96
"""
81
97
82
98
def post_model_build (self , cfg , model ): # pylint: disable=unused-argument
@@ -91,59 +107,71 @@ def post_model_load(self, cfg, model): # pylint: disable=unused-argument
91
107
"""
92
108
Performs actions after the model is loaded.
93
109
94
- Parameters :
95
- cfg (dict): The configuration for the plugin.
96
- model (object): The loaded model.
110
+ Args :
111
+ cfg (dict): The configuration for the plugin.
112
+ model (object): The loaded model.
97
113
98
114
Returns:
99
- None
115
+ None
100
116
"""
101
117
102
118
def pre_lora_load (self , cfg , model ): # pylint: disable=unused-argument
103
119
"""
104
120
Performs actions before LoRA weights are loaded.
105
121
106
- Parameters :
107
- cfg (dict): The configuration for the plugin.
108
- model (object): The loaded model.
122
+ Args :
123
+ cfg (dict): The configuration for the plugin.
124
+ model (object): The loaded model.
109
125
110
126
Returns:
111
- None
127
+ None
112
128
"""
113
129
114
130
def post_lora_load (self , cfg , model ): # pylint: disable=unused-argument
115
131
"""
116
132
Performs actions after LoRA weights are loaded.
117
133
118
- Parameters :
119
- cfg (dict): The configuration for the plugin.
120
- model (object): The loaded model.
134
+ Args :
135
+ cfg (dict): The configuration for the plugin.
136
+ model (object): The loaded model.
121
137
122
138
Returns:
123
- None
139
+ None
124
140
"""
125
141
126
142
def get_trainer_cls (self , cfg ): # pylint: disable=unused-argument):
127
143
"""
128
144
Returns a custom class for the trainer.
129
145
130
- Parameters:
131
- cfg (dict): The global axolotl configuration.
146
+ Args:
147
+ cfg (dict): The global axolotl configuration.
148
+
149
+ Returns:
150
+ class: The class for the trainer.
151
+ """
152
+
153
+ def post_trainer_create (self , cfg , trainer ): # pylint: disable=unused-argument
154
+ """
155
+ Performs actions after the trainer is created.
156
+
157
+ Args:
158
+ cfg (dict): The configuration for the plugin.
159
+ trainer (object): The trainer object for training.
132
160
133
161
Returns:
134
- class: The class for the trainer.
162
+ None
135
163
"""
136
164
137
165
def create_optimizer (self , cfg , trainer ): # pylint: disable=unused-argument
138
166
"""
139
167
Creates and returns an optimizer for training.
140
168
141
- Parameters :
142
- cfg (dict): The configuration for the plugin.
143
- trainer (object): The trainer object for training.
169
+ Args :
170
+ cfg (dict): The configuration for the plugin.
171
+ trainer (object): The trainer object for training.
144
172
145
173
Returns:
146
- object: The created optimizer.
174
+ object: The created optimizer.
147
175
"""
148
176
149
177
def create_lr_scheduler (
@@ -152,26 +180,26 @@ def create_lr_scheduler(
152
180
"""
153
181
Creates and returns a learning rate scheduler.
154
182
155
- Parameters :
156
- cfg (dict): The configuration for the plugin.
157
- trainer (object): The trainer object for training.
158
- optimizer (object): The optimizer for training.
159
- num_training_steps (int): Total number of training steps
183
+ Args :
184
+ cfg (dict): The configuration for the plugin.
185
+ trainer (object): The trainer object for training.
186
+ optimizer (object): The optimizer for training.
187
+ num_training_steps (int): Total number of training steps
160
188
161
189
Returns:
162
- object (LRScheduler): The created learning rate scheduler.
190
+ object (LRScheduler): The created learning rate scheduler.
163
191
"""
164
192
165
193
def add_callbacks_pre_trainer (self , cfg , model ): # pylint: disable=unused-argument
166
194
"""
167
195
setup callbacks before creating the trainer.
168
196
169
- Parameters :
170
- cfg (dict): The configuration for the plugin.
171
- model (object): The loaded model.
197
+ Args :
198
+ cfg (dict): The configuration for the plugin.
199
+ model (object): The loaded model.
172
200
173
201
Returns:
174
- List[callable]: A list of callback functions to be added to the TrainingArgs
202
+ List[callable]: A list of callback functions to be added to the TrainingArgs
175
203
"""
176
204
return []
177
205
@@ -182,36 +210,36 @@ def add_callbacks_post_trainer(
182
210
Adds callbacks to the trainer after creating the trainer.
183
211
This is useful for callbacks that require access to the model or trainer.
184
212
185
- Parameters :
186
- cfg (dict): The configuration for the plugin.
187
- trainer (object): The trainer object for training.
213
+ Args :
214
+ cfg (dict): The configuration for the plugin.
215
+ trainer (object): The trainer object for training.
188
216
189
217
Returns:
190
- List[callable]: A list of callback functions to be added
218
+ List[callable]: A list of callback functions to be added
191
219
"""
192
220
return []
193
221
194
222
def post_train (self , cfg , model ): # pylint: disable=unused-argument
195
223
"""
196
224
Performs actions after training is complete.
197
225
198
- Parameters :
199
- cfg (dict): The axolotl configuration
200
- model (object): The loaded model.
226
+ Args :
227
+ cfg (dict): The axolotl configuration
228
+ model (object): The loaded model.
201
229
202
230
Returns:
203
- None
231
+ None
204
232
"""
205
233
206
234
def post_train_unload (self , cfg ): # pylint: disable=unused-argument
207
235
"""
208
236
Performs actions after training is complete and the model is unloaded.
209
237
210
- Parameters :
211
- cfg (dict): The configuration for the plugin.
238
+ Args :
239
+ cfg (dict): The configuration for the plugin.
212
240
213
241
Returns:
214
- None
242
+ None
215
243
"""
216
244
217
245
@@ -338,6 +366,27 @@ def get_input_args(self):
338
366
input_args .append (input_args_from_plugin )
339
367
return input_args
340
368
369
+ def load_datasets (self , cfg , preprocess : bool = False ):
370
+ """
371
+ Calls the load_datasets method of each registered plugin.
372
+
373
+ Args:
374
+ cfg: The configuration for the plugins.
375
+ preprocess : Whether this is preprocess step of the datasets.
376
+
377
+ Returns:
378
+ dataset_meta: The dataset metadata loaded from all registered plugins.
379
+ """
380
+ return_ds_meta = None
381
+ for plugin in self .plugins .values ():
382
+ dataset_meta = plugin .load_datasets (cfg , preprocess )
383
+ if dataset_meta is not None :
384
+ if return_ds_meta is None :
385
+ return_ds_meta = dataset_meta
386
+ else :
387
+ raise RuntimeError ("Multiple plugins loaded datasets" )
388
+ return return_ds_meta
389
+
341
390
def pre_model_load (self , cfg ):
342
391
"""
343
392
Calls the pre_model_load method of all registered plugins.
@@ -422,6 +471,20 @@ def get_trainer_cls(self, cfg):
422
471
return trainer_cls
423
472
return None
424
473
474
+ def post_trainer_create (self , cfg , trainer ):
475
+ """
476
+ Calls the post_trainer_create method of all registered plugins.
477
+
478
+ Parameters:
479
+ cfg (dict): The configuration for the plugins.
480
+ trainer (object): The trainer object for training.
481
+
482
+ Returns:
483
+ None
484
+ """
485
+ for plugin in self .plugins .values ():
486
+ plugin .post_trainer_create (cfg , trainer )
487
+
425
488
def create_optimizer (self , trainer ):
426
489
"""
427
490
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
0 commit comments