From 541452886d9751079e835b23a355b7aa0cacadb2 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Fri, 13 Mar 2026 11:44:25 +0530 Subject: [PATCH 1/3] trl_num_labels=1 --- src/axolotl/utils/schemas/validation.py | 17 +++++++++++++++ tests/core/test_builders.py | 2 +- tests/patched/test_validation.py | 28 +++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 2ff57558f7..dfdcbebd87 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -254,6 +254,23 @@ def hint_reward_model_pad(cls, data): data["pad_to_sequence_len"] = True return data + @model_validator(mode="before") + @classmethod + def set_reward_model_defaults(cls, data): + if data.get("reward_model"): + if data.get("num_labels") is None: + data["num_labels"] = 1 + if not (data.get("type_of_model") or data.get("model_type")): + data["model_type"] = "AutoModelForSequenceClassification" + + if data.get("process_reward_model"): + if data.get("num_labels") is None: + data["num_labels"] = 2 + if not (data.get("type_of_model") or data.get("model_type")): + data["model_type"] = "AutoModelForTokenClassification" + + return data + @model_validator(mode="before") @classmethod def check_gas_bsz(cls, data): diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index ea3c4e6c4e..a241e85492 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -536,7 +536,7 @@ def test_training_arguments(self, sft_cfg, model, tokenizer): "cfg_string", [ "sft_cfg", - # "rm_cfg", # TODO fix for num_labels = 2 vs 1 + "rm_cfg", "prm_cfg", ], ) diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 21299ed980..d22927940d 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -277,6 +277,34 @@ def test_model_type_remap(self, minimal_cfg): new_cfg = validate_config(cfg) assert 
new_cfg.type_of_model == "AutoModelForCausalLM" + def test_reward_model_defaults(self, minimal_cfg): + cfg = ( + DictDefault( + { + "reward_model": True, + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg.num_labels == 1 + assert new_cfg.type_of_model == "AutoModelForSequenceClassification" + + def test_process_reward_model_defaults(self, minimal_cfg): + cfg = ( + DictDefault( + { + "process_reward_model": True, + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg.num_labels == 2 + assert new_cfg.type_of_model == "AutoModelForTokenClassification" + def test_model_revision_remap(self, minimal_cfg): cfg = ( DictDefault( From a67c0606cd4a8c704bf8e1f42b4f408fa03d71e5 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Fri, 13 Mar 2026 16:32:12 +0530 Subject: [PATCH 2/3] causal num_labels=1, reward model --- src/axolotl/core/builders/causal.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index c238cbbc3f..b062724966 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -421,6 +421,10 @@ def build(self, total_num_steps): trainer_kwargs["dataset_tags"] = [ d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() ] + # TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the + # config reflects this regardless of how the model was instantiated. 
+ if self.cfg.reward_model and getattr(self.model.config, "num_labels", None) != 1: + self.model.config.num_labels = 1 trainer = trainer_cls( model=self.model, train_dataset=self.train_dataset, From 1bb0fe05cce59bbe5c526659b316b8aa2994a727 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Fri, 13 Mar 2026 16:37:55 +0530 Subject: [PATCH 3/3] lint --- src/axolotl/core/builders/causal.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index b062724966..f26ef8969e 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -423,7 +423,10 @@ def build(self, total_num_steps): ] # TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the # config reflects this regardless of how the model was instantiated. - if self.cfg.reward_model and getattr(self.model.config, "num_labels", None) != 1: + if ( + self.cfg.reward_model + and getattr(self.model.config, "num_labels", None) != 1 + ): self.model.config.num_labels = 1 trainer = trainer_cls( model=self.model,