Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,9 @@ def compute_metrices(self, model: SentenceTransformer) -> dict[str, dict[str, fl

output_scores[similarity_fn_name] = {
"accuracy": acc,
"accuracy_threshold": acc_threshold,
"accuracy_threshold": np.float64(acc_threshold),
"f1": f1,
"f1_threshold": f1_threshold,
"f1_threshold": np.float64(f1_threshold),
"precision": precision,
"recall": recall,
"ap": ap,
Expand Down
42 changes: 34 additions & 8 deletions sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ class SentenceTransformerTrainer(Trainer):
callbacks (List of [:class:`transformers.TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](callback).
extra_feature_keys (List[str], *optional*):
If you have a custom processor that adds extra features to the dataset, similar to "pixel_values" in CLIP,
you can specify the names of these features here. Each key is matched as a column-name suffix (a column ending in ``_<key>``) in the `collect_features` method.

If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
optimizers (`Tuple[:class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler.LambdaLR`]`, *optional*, defaults to `(None, None)`):
Expand Down Expand Up @@ -136,6 +139,7 @@ def __init__(
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
extra_feature_keys: list[str] | None = None,
) -> None:
if not is_training_available():
raise RuntimeError(
Expand Down Expand Up @@ -292,6 +296,10 @@ def __init__(
eval_dataset, args.prompts, dataset_name="eval"
)
self.add_model_card_callback(default_args_dict)
# Normalize `extra_feature_keys` to a list: a single string is wrapped, and
# None becomes an empty list. The attribute must be assigned on every path,
# because `collect_features` reads `self.extra_feature_keys` unconditionally.
if isinstance(extra_feature_keys, str):
    extra_feature_keys = [extra_feature_keys]
self.extra_feature_keys = extra_feature_keys or []

def add_model_card_callback(self, default_args_dict: dict[str, Any]) -> None:
"""
Expand Down Expand Up @@ -434,16 +442,34 @@ def collect_features(
# All inputs ending with `_input_ids` (Transformers), `_sentence_embedding` (BoW), `_pixel_values` (CLIPModel)
# are considered to correspond to a feature
features = []
extra_suffixes = {key: "_" + key for key in self.extra_feature_keys}
for column in inputs:
if column.endswith("_input_ids"):
prefix = column[: -len("input_ids")]
elif column.endswith("_sentence_embedding"):
prefix = column[: -len("sentence_embedding")]
elif column.endswith("_pixel_values"):
prefix = column[: -len("pixel_values")]
else:
prefix = None # Reset prefix for each column

# First, check extra feature keys
# Each key is matched with "_" in front of it (see `extra_suffixes` above)
for key, suffix in extra_suffixes.items():
if column.endswith(suffix):
prefix = column[: -len(key)]
break # Stop checking once we find a match

# If no match found, check predefined suffixes
if prefix is None:
if column.endswith("_input_ids"):
prefix = column[: -len("input_ids")]
elif column.endswith("_sentence_embedding"):
prefix = column[: -len("sentence_embedding")]
elif column.endswith("_pixel_values"):
prefix = column[: -len("pixel_values")]

# Skip columns with no matching suffix
if prefix is None:
continue
features.append({key[len(prefix) :]: value for key, value in inputs.items() if key.startswith(prefix)})

# Collect all features that start with the detected prefix
feature_dict = {key[len(prefix) :]: value for key, value in inputs.items() if key.startswith(prefix)}
features.append(feature_dict)

labels = inputs.get("label", None)
return features, labels

Expand Down
Loading