Skip to content

Commit 6565ae8

Browse files
authored
set config on the PluginManager for callback access (axolotl-ai-cloud#2587)
1 parent 80b4edb commit 6565ae8

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

src/axolotl/cli/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ def prepare_plugins(cfg: DictDefault):
152152
plugin_manager.register(plugin_name)
153153

154154

155+
def plugin_set_cfg(cfg: DictDefault):
156+
if cfg.get("plugins"):
157+
plugin_manager = PluginManager.get_instance()
158+
plugin_manager.cfg = cfg
159+
160+
155161
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
156162
"""
157163
Loads the `axolotl` configuration stored at `config`, validates it, and performs
@@ -213,5 +219,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
213219
setup_wandb_env_vars(cfg)
214220
setup_mlflow_env_vars(cfg)
215221
setup_comet_env_vars(cfg)
222+
plugin_set_cfg(cfg)
216223

217224
return cfg

src/axolotl/integrations/base.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,14 +270,17 @@ class PluginManager:
270270
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
271271

272272
_instance = None
273+
_cfg = None
273274

274275
def __new__(cls):
275276
"""
276277
Creates a new instance of PluginManager if it doesn't exist yet.
277278
"""
278279
if cls._instance is None:
279280
cls._instance = super(PluginManager, cls).__new__(cls)
280-
cls._instance.plugins = collections.OrderedDict()
281+
cls._instance.plugins: OrderedDict[str, BasePlugin] = (
282+
collections.OrderedDict()
283+
)
281284
return cls._instance
282285

283286
@staticmethod
@@ -290,6 +293,14 @@ def get_instance() -> "PluginManager":
290293
PluginManager()
291294
return PluginManager._instance # type: ignore
292295

296+
@property
297+
def cfg(self):
298+
return self._cfg
299+
300+
@cfg.setter
301+
def cfg(self, cfg):
302+
self._cfg = cfg
303+
293304
def register(self, plugin_name: str):
294305
"""
295306
Registers a new plugin by its name.
@@ -409,37 +420,35 @@ def get_trainer_cls(self, cfg):
409420
return trainer_cls
410421
return None
411422

412-
def create_optimizer(self, cfg, trainer):
423+
def create_optimizer(self, trainer):
413424
"""
414425
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
415426
416427
Parameters:
417-
cfg (dict): The configuration for the plugins.
418428
trainer (object): The trainer object for training.
419429
420430
Returns:
421431
object: The created optimizer, or None if none was found.
422432
"""
423433
for plugin in self.plugins.values():
424-
optimizer = plugin.create_optimizer(cfg, trainer)
434+
optimizer = plugin.create_optimizer(self.cfg, trainer)
425435
if optimizer is not None:
426436
return optimizer
427437
return None
428438

429-
def create_lr_scheduler(self, cfg, trainer, optimizer):
439+
def create_lr_scheduler(self, trainer, optimizer):
430440
"""
431441
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
432442
433443
Parameters:
434-
cfg (dict): The configuration for the plugins.
435444
trainer (object): The trainer object for training.
436445
optimizer (object): The optimizer for training.
437446
438447
Returns:
439448
object: The created learning rate scheduler, or None if none was found.
440449
"""
441450
for plugin in self.plugins.values():
442-
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
451+
scheduler = plugin.create_lr_scheduler(self.cfg, trainer, optimizer)
443452
if scheduler is not None:
444453
return scheduler
445454
return None

0 commit comments

Comments
 (0)