diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 54ee195361..a7552db9dd 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -49,6 +49,7 @@
     RewardTrainer,
 )
 from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
+from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn
 
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -424,7 +425,7 @@ def create_scheduler(
         return self.lr_scheduler
 
 
-class AxolotlTrainer(SchedulerMixin, Trainer):
+class AxolotlTrainer(SessionManagerMixIn, SchedulerMixin, Trainer):
     """
     Extend the base Trainer for axolotl helpers
     """
@@ -1309,11 +1310,12 @@ class TrainerBuilderBase(abc.ABC):
     _model_ref = None
     _peft_config = None
 
-    def __init__(self, cfg, model, tokenizer, processor=None):
+    def __init__(self, cfg, model, tokenizer, processor=None, teacher=None):
         self.cfg = cfg
         self.model = model
         self.tokenizer = tokenizer
         self.processor = processor
+        self.teacher = teacher
 
         # in case the model supports tagging, add the axolotl tag.
         # This makes sure the tag is correctly pushed even if a user calls
@@ -1950,6 +1952,11 @@ def build(self, total_num_steps):
             trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
+
+        if self.cfg.compressor:
+            trainer_kwargs["recipe"] = self.cfg.compressor.recipe
+            trainer_kwargs["recipe_args"] = self.cfg.compressor.recipe_args or {}
+
         trainer = trainer_cls(
             model=self.model,
             train_dataset=self.train_dataset,
@@ -1957,6 +1964,7 @@
             args=training_args,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
             callbacks=self.get_callbacks(),
+            teacher=self.teacher,
             **trainer_kwargs,
         )
         trainer = self.hook_post_create_trainer(trainer)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index dc7289b093..dc893391a2 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -25,7 +25,7 @@
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_layers_except
-from axolotl.utils.models import load_model, load_processor, load_tokenizer
+from axolotl.utils.models import load_model, load_processor, load_tokenizer, load_teacher
 from axolotl.utils.trainer import setup_trainer
 
 try:
@@ -98,6 +98,8 @@ def train(
     if model.generation_config is not None:
         model.generation_config.do_sample = True
 
+    teacher = load_teacher(cfg, tokenizer)
+
     model_ref = None
     if cfg.rl and cfg.rl != "orpo":
         if cfg.adapter and not cfg.rl_adapter_ref_model:
@@ -123,6 +125,7 @@
         tokenizer,
         processor,
         total_num_steps,
+        teacher,
     )
 
     if cfg.fix_untrained_tokens:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 523fd76feb..152af310b7 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1203,6 +1203,27 @@ def load_model(
     return loader.load_model()
 
 
+def load_teacher(
+    cfg: DictDefault,
+    tokenizer: PreTrainedTokenizerBase,
+) -> Optional[PreTrainedModel]:
+    """
+    Load a teacher model for a given configuration and tokenizer.
+    """
+    if not cfg.compressor or not cfg.compressor.teacher:
+        return None
+
+    loader = ModelLoader(
+        cfg.compressor.teacher,
+        tokenizer,
+        inference=True,
+        reference_model=True,
+    )
+    teacher, _ = loader.load_model()
+
+    return teacher
+
+
 def load_adapter(model, cfg, adapter, inference=False):
     # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 32e54c9a86..256d78419c 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -513,14 +513,14 @@ def prepare_opinionated_env(cfg):
 
 
 def setup_trainer(
-    cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
+    cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps, teacher
 ):
     if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
     else:
-        trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
+        trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor, teacher)
 
     trainer_builder.train_dataset = train_dataset
     trainer_builder.eval_dataset = eval_dataset
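For reference, a minimal usage sketch (not part of the diff) of the config surface this patch introduces: the `compressor.recipe`, `compressor.recipe_args`, and `compressor.teacher` keys mirror what the patch reads, while the model names, recipe path, and recipe arguments are hypothetical placeholders.

```python
# Sketch only: exercises load_teacher() as added above. The nested DictDefault
# supports the attribute access the patch relies on (cfg.compressor.teacher
# returns None when unset). Model names and recipe values are placeholders.
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_teacher, load_tokenizer

cfg = DictDefault(
    {
        "base_model": "NousResearch/Llama-2-7b-hf",  # hypothetical student model
        "compressor": {
            # llmcompressor recipe forwarded to the trainer via SessionManagerMixIn
            "recipe": "recipe.yaml",
            "recipe_args": {"sparsity": 0.5},  # hypothetical recipe overrides
            # teacher sub-config is handed to ModelLoader by load_teacher()
            "teacher": {"base_model": "NousResearch/Llama-2-13b-hf"},
        },
    }
)

tokenizer = load_tokenizer(cfg)
# Returns None when cfg.compressor or cfg.compressor.teacher is unset;
# otherwise loads the teacher with inference=True / reference_model=True.
teacher = load_teacher(cfg, tokenizer)
```

The teacher is loaded once in `train()`, carried through `setup_trainer()` into `HFCausalTrainerBuilder`, and finally passed to the trainer constructor, where llmcompressor's `SessionManagerMixIn` consumes the `teacher`, `recipe`, and `recipe_args` keyword arguments.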