Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mteb/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_task_result_path(
Returns:
The path to the results of the task.
"""
results_folder = "results" if not remote else "remote"
results_folder = self.cache_path / "results" if not remote else self.cache_path / "remote" / "results"

if isinstance(model_name, ModelMeta):
if model_revision is not None:
Expand All @@ -74,7 +74,7 @@ def get_task_result_path(
elif isinstance(model_name, str):
model_name = model_name.replace("/", "__").replace(" ", "_")

model_path = self.cache_path / results_folder / model_name
model_path = results_folder / model_name

if model_revision is None:
logger.warning(
Expand Down
26 changes: 22 additions & 4 deletions mteb/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,20 @@ def _check_model_modalities(
logger.warning(msg)


def _requires_merge(task: AbsTask, existing_results: TaskResult) -> bool:
    """Check whether the existing results lack an evaluation the task requires.

    Every split in ``task.eval_splits`` must be present in
    ``existing_results.scores`` and must contain an entry for each subset
    listed in ``task.hf_subsets``.

    Args:
        task: The task whose ``eval_splits`` and ``hf_subsets`` define the
            required evaluations.
        existing_results: Previously computed results; ``scores`` is looked up
            by split name and each entry is read via its ``"hf_subset"`` key.

    Returns:
        True if at least one required split is absent or a present split is
        missing a required subset; False if the existing results cover
        everything.
    """
    # The required subsets are the same for every split, so build the set once
    # instead of rebuilding it on each iteration.
    required_subsets = set(task.hf_subsets)
    for split in task.eval_splits:
        split_scores = existing_results.scores.get(split)
        if split_scores is None:
            # The whole split has never been evaluated.
            return True
        covered_subsets = {entry["hf_subset"] for entry in split_scores}
        if not required_subsets <= covered_subsets:
            return True
    return False


def evaluate(
model: ModelMeta | MTEBModels | SentenceTransformer | CrossEncoder,
tasks: AbsTask | Iterable[AbsTask],
Expand Down Expand Up @@ -388,9 +402,12 @@ def evaluate(

if (
existing_results
and overwrite_strategy == "only-missing"
and overwrite_strategy == OverwriteStrategy.ONLY_MISSING
and existing_results.is_mergeable(task)
and overwrite_strategy
not in (OverwriteStrategy.ALWAYS, OverwriteStrategy.NEVER)
and (
not _requires_merge(task, existing_results)
or existing_results.is_mergeable(task)
)
):
missing_eval = existing_results.get_missing_evaluations(task)
else:
Expand All @@ -415,7 +432,8 @@ def evaluate(
OverwriteStrategy.ONLY_CACHE,
]:
raise ValueError(
f"overwrite_strategy is set to '{overwrite_strategy.value}' and the results file exists. However there are the following missing splits (and subsets): {missing_eval}. To rerun these set overwrite_strategy to 'only-missing'."
f"overwrite_strategy is set to '{overwrite_strategy.value}' and the results file exists for task {task.metadata.name}. "
+ f"However there are the following missing splits (and subsets): {missing_eval}. To rerun these set overwrite_strategy to 'only-missing'."
)

if existing_results:
Expand Down
22 changes: 13 additions & 9 deletions mteb/results/task_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,27 +698,31 @@ def is_mergeable(
name = result.metadata.name
revision = result.metadata.revision
else:
msg = "result must be a TaskResult or AbsTask object"
if raise_error:
raise ValueError(msg)
logger.debug(msg)
return False

if self.task_name != name:
msg = f"Cannot merge TaskResult objects as they are derived from different tasks ({self.task_name} and {name})"
if raise_error:
raise ValueError(
f"Cannot merge TaskResult objects as they are derived from different tasks ({self.task_name} and {name})"
)
raise ValueError(msg)
logger.debug(msg)
return False

if Criteria.MTEB_VERSION in criteria and self.mteb_version != mteb_version:
msg = f"Cannot merge TaskResult objects as they are derived from different MTEB versions ({self.mteb_version} (loaded) and {mteb_version} (current))"
if raise_error:
raise ValueError(
f"Cannot merge TaskResult objects as they are derived from different MTEB versions ({self.mteb_version} and {mteb_version})"
)
raise ValueError(msg)
logger.debug(msg)
return False

if Criteria.DATASET_REVISION in criteria and self.dataset_revision != revision:
msg = f"Cannot merge TaskResult objects as they are derived from different dataset revisions ({self.dataset_revision} and {revision})"
if raise_error:
raise ValueError(
f"Cannot merge TaskResult objects as they are derived from different dataset revisions ({self.dataset_revision} and {revision})"
)
raise ValueError(msg)
logger.debug(msg)
return False

return True
Expand Down
Loading