From d47093fcdddfcf93fd27e485cd67e928b6856097 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 22:29:56 +0700 Subject: [PATCH] fix: simplify fn same as sft and pass model to plugin --- src/axolotl/core/builders/rl.py | 35 +++++++++++++++------------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index c5f01dd418..f08513feac 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -6,6 +6,7 @@ from axolotl.core.builders.base import TrainerBuilderBase from axolotl.core.trainers import ( AxolotlCPOTrainer, + AxolotlDPOTrainer, AxolotlKTOTrainer, AxolotlORPOTrainer, ) @@ -36,33 +37,23 @@ def get_post_trainer_create_callbacks(self, trainer): callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) return callbacks - def _get_trainer_cls(self, trainer_kwargs: dict): - """ - Returns trainer_cls and trainer_cls_args - """ + def _get_trainer_cls(self): + """Returns trainer_cls""" if self.cfg.plugins: plugin_manager = PluginManager.get_instance() trainer_cls = plugin_manager.get_trainer_cls(self.cfg) - trainer_cls_args = [] # type: ignore if trainer_cls is not None: - return trainer_cls, trainer_cls_args + return trainer_cls trainer_cls = None - trainer_cls_args = [self.model] if self.cfg.rl is RLType.GRPO: trainer_cls = GRPOStrategy.get_trainer_class( sequence_parallel=self.cfg.sequence_parallel_degree > 1 ) - trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) - - trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) - elif self.cfg.rl in [RLType.DPO, RLType.IPO]: - trainer_cls = DPOStrategy.get_trainer_class() - trainer_cls_args.append(self.model_ref) - + trainer_cls = AxolotlDPOTrainer elif self.cfg.rl is RLType.ORPO: trainer_cls = AxolotlORPOTrainer elif self.cfg.rl is RLType.KTO: @@ -72,7 +63,7 @@ def _get_trainer_cls(self, trainer_kwargs: dict): else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") - return trainer_cls, trainer_cls_args + return trainer_cls def _build_training_arguments(self, total_num_steps): """ @@ -182,7 +173,15 @@ def build(self, total_num_steps): self.cfg.precompute_ref_log_probs ) - trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs) + trainer_cls = self._get_trainer_cls() + trainer_cls_args = [self.model] + + if self.cfg.rl is RLType.GRPO: + trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) + trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) + + if self.cfg.rl in [RLType.DPO, RLType.IPO]: + trainer_cls_args.append(self.model_ref) sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters: @@ -190,9 +189,7 @@ def build(self, total_num_steps): else: trainer_kwargs["processing_class"] = self.tokenizer - if self.cfg.datasets is not None and ( - trainer_cls is DPOStrategy.get_trainer_class() - ): + if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer): trainer_kwargs["dataset_tags"] = [ d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() ]