Skip to content

Commit 0641915

Browse files
committed
fix template file
1 parent d6e1245 commit 0641915

File tree

2 files changed

+4
-66
lines changed

2 files changed

+4
-66
lines changed

trl/trainer/base_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class BaseTrainer(Trainer):
2828
_tag_names = []
2929
_name = "Base"
3030
_paper = {}
31+
_template_file = None
3132

3233
def create_model_card(
3334
self,
@@ -78,6 +79,7 @@ def create_model_card(
7879
comet_url=get_comet_experiment_url(),
7980
trainer_name=self._name,
8081
trainer_citation=self._paper.get("citation"),
82+
template_file=self._template_file,
8183
paper_title=self._paper.get("title"),
8284
paper_id=self._paper.get("id"),
8385
)

trl/trainer/reward_trainer.py

Lines changed: 2 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,16 @@
3030
DataCollator,
3131
PreTrainedModel,
3232
PreTrainedTokenizerBase,
33-
ProcessorMixin,
3433
)
3534
from transformers.data.data_collator import DataCollatorMixin
3635
from transformers.trainer_callback import TrainerCallback
3736
from transformers.trainer_utils import EvalPrediction
3837
from transformers.utils import is_peft_available
3938

40-
from ..data_utils import maybe_apply_chat_template
4139
from ..models import prepare_peft_model
4240
from .base_trainer import BaseTrainer
4341
from .reward_config import RewardConfig
44-
from .utils import (
45-
RewardDataCollatorWithPadding,
46-
compute_accuracy,
47-
decode_and_strip_padding,
48-
disable_dropout_in_model,
49-
log_table_to_comet_experiment,
50-
print_rich_table,
51-
)
52-
42+
from .utils import disable_dropout_in_model
5343

5444
if is_peft_available():
5545
from peft import PeftConfig, PeftModel
@@ -260,6 +250,7 @@ class RewardTrainer(BaseTrainer):
260250

261251
_tag_names = ["trl", "reward-trainer"]
262252
_name = "Reward"
253+
_template_file = "rm_model_card.md"
263254

264255
def __init__(
265256
self,
@@ -600,58 +591,3 @@ def _save_checkpoint(self, model, trial):
600591
model_name = self.args.hub_model_id.split("/")[-1]
601592
self.create_model_card(model_name=model_name)
602593
super()._save_checkpoint(model, trial)
603-
604-
def create_model_card(
605-
self,
606-
model_name: Optional[str] = None,
607-
dataset_name: Optional[str] = None,
608-
tags: Union[str, list[str], None] = None,
609-
):
610-
"""
611-
Creates a draft of a model card using the information available to the `Trainer`.
612-
613-
Args:
614-
model_name (`str`, *optional*):
615-
Name of the model.
616-
dataset_name (`str`, *optional*):
617-
Name of the dataset used for training.
618-
tags (`str`, `list[str]`, *optional*):
619-
Tags to be associated with the model card.
620-
"""
621-
if not self.is_world_process_zero():
622-
return
623-
624-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
625-
base_model = self.model.config._name_or_path
626-
else:
627-
base_model = None
628-
629-
# normalize `tags` to a mutable set
630-
if tags is None:
631-
tags = set()
632-
elif isinstance(tags, str):
633-
tags = {tags}
634-
else:
635-
tags = set(tags)
636-
637-
if hasattr(self.model.config, "unsloth_version"):
638-
tags.add("unsloth")
639-
640-
if "JOB_ID" in os.environ:
641-
tags.add("hf_jobs")
642-
643-
tags.update(self._tag_names)
644-
645-
model_card = generate_model_card(
646-
base_model=base_model,
647-
model_name=model_name,
648-
hub_model_id=self.hub_model_id,
649-
dataset_name=dataset_name,
650-
tags=list(tags),
651-
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
652-
comet_url=get_comet_experiment_url(),
653-
trainer_name="Reward",
654-
template_file="rm_model_card.md",
655-
)
656-
657-
model_card.save(os.path.join(self.args.output_dir, "README.md"))

0 commit comments

Comments
 (0)