Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,26 +594,19 @@ def evaluation_loop(
def _load_best_model(self) -> None:
# Attempt to load the model from self.state.best_model_checkpoint
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
try:
dummy_model = self.model.__class__(
self.state.best_model_checkpoint,
trust_remote_code=self.model.trust_remote_code,
)
except Exception as exc:
logger.error(f"Could not load the best model from {self.state.best_model_checkpoint}. Error: {str(exc)}")
return

# Store the best model checkpoint in the model card
try:
if checkpoint := self.state.best_model_checkpoint:
step = checkpoint.rsplit("-", 1)[-1]
self.model.model_card_data.set_best_model_step(int(step))
except Exception:
pass

# Ideally, the only changes between self.model and the dummy model are the weights
# so we should be able to just copy the state dict
self.model.load_state_dict(dummy_model.state_dict())
try:
self._load_from_checkpoint(self.state.best_model_checkpoint)
except Exception as exc:
logger.error(f"Could not load the best model from {self.state.best_model_checkpoint}. Error: {str(exc)}")
return

def validate_column_names(self, dataset: Dataset, dataset_name: str | None = None) -> None:
if isinstance(dataset, dict):
Expand Down
Loading