Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions src/axolotl/core/builders/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlCPOTrainer,
AxolotlDPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
)
Expand Down Expand Up @@ -36,33 +37,23 @@ def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks

def _get_trainer_cls(self, trainer_kwargs: dict):
"""
Returns trainer_cls and trainer_cls_args
"""
def _get_trainer_cls(self):
"""Returns trainer_cls"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
trainer_cls_args = [] # type: ignore

if trainer_cls is not None:
return trainer_cls, trainer_cls_args
return trainer_cls

trainer_cls = None
trainer_cls_args = [self.model]

if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))

trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))

elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args.append(self.model_ref)

trainer_cls = AxolotlDPOTrainer
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
elif self.cfg.rl is RLType.KTO:
Expand All @@ -72,7 +63,7 @@ def _get_trainer_cls(self, trainer_kwargs: dict):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")

return trainer_cls, trainer_cls_args
return trainer_cls

def _build_training_arguments(self, total_num_steps):
"""
Expand Down Expand Up @@ -182,17 +173,23 @@ def build(self, total_num_steps):
self.cfg.precompute_ref_log_probs
)

trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
trainer_cls = self._get_trainer_cls()
trainer_cls_args = [self.model]

if self.cfg.rl is RLType.GRPO:
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))

if self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls_args.append(self.model_ref)

sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
else:
trainer_kwargs["processing_class"] = self.tokenizer

if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
Expand Down
Loading