Skip to content

Commit 9cbabb5

Browse files
committed
Merge branch 'main' into add_active_parameter
2 parents 7896319 + b968433 commit 9cbabb5

File tree

164 files changed

+2179
-786
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

164 files changed

+2179
-786
lines changed

docs/overview/create_available_benchmarks.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Updates the available benchmarks markdown file."""
22

33
from pathlib import Path
4-
from typing import cast
4+
from typing import TYPE_CHECKING, cast
55

66
from prettify_list import pretty_long_list
77
from slugify import slugify_anchor
88

99
import mteb
10-
from mteb.get_tasks import MTEBTasks
10+
11+
if TYPE_CHECKING:
12+
from mteb.get_tasks import MTEBTasks
1113

1214
benchmark_entry = """
1315
#### {benchmark_name}
@@ -38,7 +40,7 @@
3840
def create_table(benchmark: mteb.Benchmark) -> str:
3941
"""Create a markdown table of tasks in the benchmark."""
4042
tasks = benchmark.tasks
41-
tasks = cast(MTEBTasks, tasks)
43+
tasks = cast("MTEBTasks", tasks)
4244
df = tasks.to_dataframe(["name", "type", "modalities", "languages"])
4345

4446
# add links to task names:

docs/overview/create_available_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
"""Updates the available models markdown files."""
22

3+
from __future__ import annotations
4+
35
from pathlib import Path
6+
from typing import TYPE_CHECKING
47

58
from prettify_list import pretty_long_list
69

710
import mteb
8-
from mteb.models import ModelMeta
11+
12+
if TYPE_CHECKING:
13+
from mteb.models import ModelMeta
914

1015
model_entry = """
1116
#### {model_name_w_link}

docs/overview/create_available_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def task_category_to_string(category: str) -> str:
8383

8484

8585
def create_aggregate_table(task: AbsTaskAggregate) -> str:
86-
tasks = cast(MTEBTasks, MTEBTasks(task.metadata.tasks))
86+
tasks = cast("MTEBTasks", MTEBTasks(task.metadata.tasks))
8787
df = tasks.to_dataframe(["name", "type", "modalities", "languages"])
8888
df["name"] = df.apply(
8989
lambda row: (

docs/usage/loading_results.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@ For instance, if you are selecting the best model for semantic text similarity (
77

88
```python
99
import mteb
10-
from mteb.cache import ResultCache
1110

1211
tasks = mteb.get_tasks(tasks=["STS12"])
1312
model_names = ["intfloat/multilingual-e5-large"]
1413

15-
cache = ResultCache("~/.cache/mteb")
14+
cache = mteb.ResultCache("~/.cache/mteb")
1615
results = cache.load_results(models=model_names, tasks=tasks)
1716
```
1817

@@ -36,16 +35,16 @@ All previously submitted results are available in the results [repository](https://gith
3635
You can download this using:
3736

3837
```python
39-
from mteb.cache import ResultCache
38+
import mteb
4039

41-
cache = ResultCache()
40+
cache = mteb.ResultCache()
4241
cache.download_from_remote() # download results from the remote repository
4342
```
4443

4544
From here, you can work with the cache as usual. For instance, if you are selecting the best model for your French and English retrieval task on legal documents, you could fetch the relevant tasks and create a dataframe of the results using the following code:
4645

4746
```python
48-
from mteb.cache import ResultCache
47+
import mteb
4948

5049
# select your tasks
5150
tasks = mteb.get_tasks(task_types=["Retrieval"], languages=["eng", "fra"], domains=["Legal"])
@@ -56,7 +55,7 @@ model_names = [
5655
]
5756

5857

59-
cache = ResultCache()
58+
cache = mteb.ResultCache()
6059
cache.download_from_remote() # download results from the remote repository. Might take a while the first time.
6160

6261
results = cache.load_results(
@@ -88,11 +87,10 @@ If you loaded results for a specific benchmark, you can get the aggregated bench
8887

8988
```python
9089
import mteb
91-
from mteb.cache import ResultCache
9290

9391
# Load results for a specific benchmark
9492
benchmark = mteb.get_benchmark("MTEB(eng, v2)")
95-
cache = ResultCache()
93+
cache = mteb.ResultCache()
9694
cache.download_from_remote() # download results from the remote repository
9795
results = cache.load_results(
9896
models=["intfloat/e5-small", "intfloat/multilingual-e5-small"],

docs/whats_new.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ results = mteb.evaluate(model, tasks)
3737
### Better local and online caching
3838
The new [`mteb.ResultCache`][mteb.cache.ResultCache] makes managing the cache notably easier:
3939
```py
40-
from mteb.cache import ResultCache
40+
import mteb
4141

4242
model = ...
4343
tasks = ...
4444

45-
cache = ResultCache(cache_path="~/.cache/mteb") # default
45+
cache = mteb.ResultCache(cache_path="~/.cache/mteb") # default
4646

4747
# simple evaluate with cache
4848
results = mteb.evaluate(model, tasks, cache=cache) # only runs if results not in cache
@@ -169,9 +169,9 @@ We've added a lot of new documentation to make it easier to get started with MTE
169169
The new `ResultCache` also makes it easier to load, inspect and compare both local and online results:
170170

171171
```py
172-
from mteb.cache import ResultCache
172+
import mteb
173173

174-
cache = ResultCache(cache_path="~/.cache/mteb") # default
174+
cache = mteb.ResultCache(cache_path="~/.cache/mteb") # default
175175
cache.download_from_remote() # download the latest results from the remote repository
176176

177177
# load both local and online results

mteb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from mteb import types
44
from mteb.abstasks import AbsTask
55
from mteb.abstasks.task_metadata import TaskMetadata
6+
from mteb.cache import ResultCache
67
from mteb.deprecated_evaluator import MTEB
78
from mteb.evaluate import evaluate
89
from mteb.filter_tasks import filter_tasks
@@ -33,6 +34,7 @@
3334
"CrossEncoderProtocol",
3435
"EncoderProtocol",
3536
"IndexEncoderSearchProtocol",
37+
"ResultCache",
3638
"SearchProtocol",
3739
"SentenceTransformerEncoderWrapper",
3840
"TaskMetadata",

mteb/_create_dataloaders.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,28 @@
1+
from __future__ import annotations
2+
13
import logging
24
import warnings
3-
from collections.abc import Callable
4-
from typing import Any, cast
5+
from typing import TYPE_CHECKING, Any, cast
56

67
import torch
78
from datasets import Dataset, Image
89
from torch.utils.data import DataLoader, default_collate
910

10-
from mteb.abstasks.task_metadata import TaskMetadata
1111
from mteb.types import (
12-
BatchedInput,
13-
Conversation,
1412
ConversationTurn,
1513
PromptType,
16-
QueryDatasetType,
1714
)
18-
from mteb.types._encoder_io import CorpusInput, ImageInput, QueryInput, TextInput
15+
16+
if TYPE_CHECKING:
17+
from collections.abc import Callable
18+
19+
from mteb.abstasks.task_metadata import TaskMetadata
20+
from mteb.types import (
21+
BatchedInput,
22+
Conversation,
23+
QueryDatasetType,
24+
)
25+
from mteb.types._encoder_io import CorpusInput, ImageInput, QueryInput, TextInput
1926

2027
logger = logging.getLogger(__name__)
2128

@@ -128,7 +135,7 @@ def _convert_conv_history_to_query(
128135
conversation = row["text"]
129136
# if it's a list of strings, just join them
130137
if isinstance(conversation, list) and isinstance(conversation[0], str):
131-
conversation_ = cast(list[str], conversation)
138+
conversation_ = cast("list[str]", conversation)
132139
conv_str = "; ".join(conversation_)
133140
current_conversation = [
134141
ConversationTurn(role="user", content=message) for message in conversation_
@@ -173,7 +180,7 @@ def _convert_conv_history_to_query(
173180

174181
row["text"] = conv_str
175182
row["conversation"] = current_conversation
176-
return cast(dict[str, str | list[ConversationTurn]], row)
183+
return cast("dict[str, str | list[ConversationTurn]]", row)
177184

178185

179186
def _create_dataloader_for_queries_conversation(

mteb/_evaluators/any_sts_evaluator.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
1+
from __future__ import annotations
2+
13
import logging
2-
from typing import TypedDict
4+
from typing import TYPE_CHECKING, TypedDict
35

4-
from datasets import Dataset
56
from sklearn.metrics.pairwise import (
67
paired_cosine_distances,
78
paired_euclidean_distances,
89
paired_manhattan_distances,
910
)
1011

1112
from mteb._create_dataloaders import create_dataloader
12-
from mteb.abstasks.task_metadata import TaskMetadata
13-
from mteb.models import EncoderProtocol
1413
from mteb.similarity_functions import compute_pairwise_similarity
15-
from mteb.types import EncodeKwargs, PromptType
1614

1715
from .evaluator import Evaluator
1816

17+
if TYPE_CHECKING:
18+
from datasets import Dataset
19+
20+
from mteb.abstasks.task_metadata import TaskMetadata
21+
from mteb.models import EncoderProtocol
22+
from mteb.types import EncodeKwargs, PromptType
23+
1924
logger = logging.getLogger(__name__)
2025

2126

mteb/_evaluators/clustering_evaluator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1+
from __future__ import annotations
2+
13
import logging
4+
from typing import TYPE_CHECKING
25

3-
from datasets import Dataset
46
from sklearn import cluster
57

68
from mteb._create_dataloaders import create_dataloader
7-
from mteb.abstasks.task_metadata import TaskMetadata
8-
from mteb.models import EncoderProtocol
9-
from mteb.types import EncodeKwargs
109

1110
from .evaluator import Evaluator
1211

12+
if TYPE_CHECKING:
13+
from datasets import Dataset
14+
15+
from mteb.abstasks.task_metadata import TaskMetadata
16+
from mteb.models import EncoderProtocol
17+
from mteb.types import EncodeKwargs
18+
1319
logger = logging.getLogger(__name__)
1420

1521

mteb/_evaluators/evaluator.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
2-
from collections.abc import Iterable, Mapping
3-
from typing import Any
4+
from typing import TYPE_CHECKING, Any
45

56
from mteb.abstasks.abstask import _set_seed
6-
from mteb.models import EncoderProtocol
7-
from mteb.types import EncodeKwargs
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Iterable, Mapping
10+
11+
from mteb.models import EncoderProtocol
12+
from mteb.types import EncodeKwargs
813

914

1015
class Evaluator(ABC):

0 commit comments

Comments
 (0)