Skip to content

Commit 3d3ab36

Browse files
authored
[New Model]: Snowflake Arctic Embed (Family) (vllm-project#16649)
1 parent 686623c commit 3d3ab36

File tree

7 files changed

+312
-26
lines changed

7 files changed

+312
-26
lines changed

tests/entrypoints/openai/test_embedding_dimensions.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,17 @@
33
Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
44
"""
55

6-
from typing import NamedTuple
7-
86
import openai
97
import pytest
108

119
from vllm.entrypoints.openai.protocol import EmbeddingResponse
1210

11+
from ...models.embedding.utils import EmbedModelInfo
1312
from ...utils import RemoteOpenAIServer
1413

15-
16-
class ModelInfo(NamedTuple):
17-
name: str
18-
is_matryoshka: bool
19-
20-
2114
MODELS = [
22-
ModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
23-
ModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
15+
EmbedModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
16+
EmbedModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
2417
]
2518

2619
input_texts = [
@@ -30,7 +23,7 @@ class ModelInfo(NamedTuple):
3023

3124
@pytest.mark.asyncio
3225
@pytest.mark.parametrize("model", MODELS)
33-
async def test_validating_dimensions(model: ModelInfo):
26+
async def test_validating_dimensions(model: EmbedModelInfo):
3427
args = [
3528
"--task",
3629
"embed",
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Compare the embedding outputs of HF and vLLM models.
3+
4+
Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`.
5+
"""
6+
import pytest
7+
8+
from tests.models.embedding.utils import EmbedModelInfo
9+
10+
from ..utils import check_embeddings_close
11+
12+
EMBEDDING_PROMPTS = [
13+
'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!',
14+
'Mexico City of Course!'
15+
]
16+
17+
MODELS = [
18+
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
19+
is_matryoshka=False,
20+
architecture="BertModel",
21+
enable_test=True),
22+
EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
23+
is_matryoshka=False,
24+
architecture="BertModel",
25+
enable_test=False),
26+
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
27+
is_matryoshka=False,
28+
architecture="BertModel",
29+
enable_test=False),
30+
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
31+
is_matryoshka=False,
32+
architecture="NomicBertModel",
33+
enable_test=True),
34+
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
35+
is_matryoshka=False,
36+
architecture="BertModel",
37+
enable_test=False),
38+
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
39+
is_matryoshka=True,
40+
architecture="BertModel",
41+
enable_test=True),
42+
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
43+
is_matryoshka=True,
44+
architecture="XLMRobertaModel",
45+
enable_test=True),
46+
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
47+
is_matryoshka=True,
48+
architecture="GteModel",
49+
enable_test=True),
50+
]
51+
52+
53+
@pytest.mark.parametrize("model_info", MODELS)
54+
@pytest.mark.parametrize("dtype", ["half"])
55+
def test_models(
56+
hf_runner,
57+
vllm_runner,
58+
example_prompts,
59+
model_info: EmbedModelInfo,
60+
dtype: str,
61+
monkeypatch,
62+
) -> None:
63+
if not model_info.enable_test:
64+
# A model family has many models with the same architecture,
65+
# and we don't need to test each one.
66+
pytest.skip("Skipping test.")
67+
68+
example_prompts = example_prompts + EMBEDDING_PROMPTS
69+
70+
vllm_extra_kwargs = {
71+
"hf_overrides": {
72+
"is_matryoshka": model_info.is_matryoshka
73+
}
74+
}
75+
76+
with hf_runner(model_info.name, dtype=dtype,
77+
is_sentence_transformer=True) as hf_model:
78+
hf_outputs = hf_model.encode(example_prompts)
79+
80+
with vllm_runner(model_info.name,
81+
task="embed",
82+
dtype=dtype,
83+
max_model_len=None,
84+
**vllm_extra_kwargs) as vllm_model:
85+
86+
assert (vllm_model.model.llm_engine.model_config.is_matryoshka ==
87+
model_info.is_matryoshka)
88+
89+
if model_info.architecture:
90+
assert (model_info.architecture
91+
in vllm_model.model.llm_engine.model_config.architectures)
92+
93+
vllm_outputs = vllm_model.encode(example_prompts)
94+
95+
check_embeddings_close(
96+
embeddings_0_lst=hf_outputs,
97+
embeddings_1_lst=vllm_outputs,
98+
name_0="hf",
99+
name_1="vllm",
100+
tol=1e-2,
101+
)

tests/models/embedding/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
from collections.abc import Sequence
4+
from typing import NamedTuple
45

56
import torch
67
import torch.nn.functional as F
@@ -37,3 +38,10 @@ def matryoshka_fy(tensor, dimensions):
3738
tensor = tensor[..., :dimensions]
3839
tensor = F.normalize(tensor, p=2, dim=1)
3940
return tensor
41+
42+
43+
class EmbedModelInfo(NamedTuple):
44+
name: str
45+
is_matryoshka: bool
46+
architecture: str = ""
47+
enable_test: bool = True

tests/models/registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,11 +247,15 @@ def check_available_online(
247247
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
248248
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
249249
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
250+
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
251+
trust_remote_code=True),
250252
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
251253
trust_remote_code=True),
252254
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
253255
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
254256
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
257+
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
258+
trust_remote_code=True),
255259
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
256260
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
257261
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),

vllm/model_executor/layers/activation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
354354
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
355355
"gelu": lambda: GeluAndMul(),
356356
"silu": lambda: SiluAndMul(),
357+
"gelu_and_mul": lambda: GeluAndMul(),
357358
})
358359

359360

0 commit comments

Comments
 (0)