Skip to content

Commit 35cf32d

Browse files

Authored: "Improve the output precision of embedding models" (vllm-project#19092)
1 parent: 8711bc5 · commit: 35cf32d

File tree

8 files changed: +69 additions, −28 deletions

tests/models/language/pooling/embed_utils.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -56,14 +56,10 @@ def correctness_test_embed_models(hf_runner,
5656
max_model_len=None,
5757
**vllm_extra_kwargs) as vllm_model:
5858
vllm_outputs = vllm_model.encode(example_prompts)
59-
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
60-
model_dtype = getattr(
61-
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
62-
vllm_dtype)
6359

6460
with hf_runner(
6561
model_info.name,
66-
dtype=model_dtype,
62+
dtype="float32",
6763
is_sentence_transformer=True,
6864
) as hf_model:
6965

tests/models/language/pooling/mteb_utils.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import pytest
88

99
from tests.models.utils import EmbedModelInfo
10-
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
1110

1211
# Most models on the STS12 task (See #17175):
1312
# - Model implementation and minor changes in tensor dtype
@@ -104,17 +103,18 @@ def mteb_test_embed_models(hf_runner,
104103
MTEB_EMBED_TASKS)
105104
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
106105

107-
with set_default_torch_dtype(vllm_dtype) and hf_runner(
108-
model_info.name, is_sentence_transformer=True,
109-
dtype=vllm_dtype) as hf_model:
106+
with hf_runner(model_info.name,
107+
is_sentence_transformer=True,
108+
dtype="float32") as hf_model:
110109

111110
if hf_model_callback is not None:
112111
hf_model_callback(hf_model)
113112

114113
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
114+
st_dtype = next(hf_model.model.parameters()).dtype
115115

116-
print("VLLM:", vllm_main_score)
117-
print("SentenceTransformers:", st_main_score)
116+
print("VLLM:", vllm_dtype, vllm_main_score)
117+
print("SentenceTransformers:", st_dtype, st_main_score)
118118
print("Difference:", st_main_score - vllm_main_score)
119119

120120
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

tests/models/language/pooling/test_gte.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,27 +11,21 @@
1111
########## BertModel
1212
EmbedModelInfo("thenlper/gte-large",
1313
architecture="BertModel",
14-
dtype="float32",
1514
enable_test=True),
1615
EmbedModelInfo("thenlper/gte-base",
1716
architecture="BertModel",
18-
dtype="float32",
1917
enable_test=False),
2018
EmbedModelInfo("thenlper/gte-small",
2119
architecture="BertModel",
22-
dtype="float32",
2320
enable_test=False),
2421
EmbedModelInfo("thenlper/gte-large-zh",
2522
architecture="BertModel",
26-
dtype="float32",
2723
enable_test=False),
2824
EmbedModelInfo("thenlper/gte-base-zh",
2925
architecture="BertModel",
30-
dtype="float32",
3126
enable_test=False),
3227
EmbedModelInfo("thenlper/gte-small-zh",
3328
architecture="BertModel",
34-
dtype="float32",
3529
enable_test=False),
3630
########### NewModel
3731
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
@@ -46,7 +40,6 @@
4640
########### Qwen2ForCausalLM
4741
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
4842
architecture="Qwen2ForCausalLM",
49-
dtype="float32",
5043
enable_test=True),
5144
########## ModernBertModel
5245
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,46 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import pytest
3+
4+
from ...utils import EmbedModelInfo
5+
from .embed_utils import correctness_test_embed_models
6+
from .mteb_utils import mteb_test_embed_models
7+
8+
MODELS = [
9+
########## BertModel
10+
EmbedModelInfo("intfloat/e5-small",
11+
architecture="BertModel",
12+
enable_test=True),
13+
EmbedModelInfo("intfloat/e5-base",
14+
architecture="BertModel",
15+
enable_test=False),
16+
EmbedModelInfo("intfloat/e5-large",
17+
architecture="BertModel",
18+
enable_test=False),
19+
EmbedModelInfo("intfloat/multilingual-e5-small",
20+
architecture="BertModel",
21+
enable_test=False),
22+
########## XLMRobertaModel
23+
EmbedModelInfo("intfloat/multilingual-e5-base",
24+
architecture="XLMRobertaModel",
25+
enable_test=True),
26+
EmbedModelInfo("intfloat/multilingual-e5-large",
27+
architecture="XLMRobertaModel",
28+
enable_test=False),
29+
EmbedModelInfo("intfloat/multilingual-e5-large-instruct",
30+
architecture="XLMRobertaModel",
31+
enable_test=False),
32+
]
33+
34+
35+
@pytest.mark.parametrize("model_info", MODELS)
36+
def test_embed_models_mteb(hf_runner, vllm_runner,
37+
model_info: EmbedModelInfo) -> None:
38+
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
39+
40+
41+
@pytest.mark.parametrize("model_info", MODELS)
42+
def test_embed_models_correctness(hf_runner, vllm_runner,
43+
model_info: EmbedModelInfo,
44+
example_prompts) -> None:
45+
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
46+
example_prompts)

tests/models/language/pooling/test_jina.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,7 @@
3232
EMBEDDING_MODELS = [
3333
EmbedModelInfo("jinaai/jina-embeddings-v3",
3434
architecture="XLMRobertaModel",
35-
is_matryoshka=True,
36-
dtype="float32")
35+
is_matryoshka=True)
3736
]
3837

3938

tests/models/language/pooling/test_nomic.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,18 +9,15 @@
99
MODELS = [
1010
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
1111
architecture="NomicBertModel",
12-
dtype="float32",
1312
enable_test=True),
1413
EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
1514
architecture="NomicBertModel",
16-
dtype="float32",
1715
enable_test=False),
1816
EmbedModelInfo("nomic-ai/CodeRankEmbed",
1917
architecture="NomicBertModel",
2018
enable_test=False),
2119
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
2220
architecture="NomicBertModel",
23-
dtype="float32",
2421
enable_test=True)
2522
]
2623

vllm/model_executor/models/bert.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -414,10 +414,15 @@ def forward(
414414
intermediate_tensors: Optional[IntermediateTensors] = None,
415415
inputs_embeds: Optional[torch.Tensor] = None,
416416
) -> torch.Tensor:
417-
return self.model(input_ids=input_ids,
418-
position_ids=positions,
419-
inputs_embeds=inputs_embeds,
420-
intermediate_tensors=intermediate_tensors)
417+
hidden_states = self.model(input_ids=input_ids,
418+
position_ids=positions,
419+
inputs_embeds=inputs_embeds,
420+
intermediate_tensors=intermediate_tensors)
421+
422+
# convert the embedding output to float32,
423+
# otherwise precision will be lost significantly
424+
hidden_states = hidden_states.to(torch.float32)
425+
return hidden_states
421426

422427
def pooler(
423428
self,

vllm/model_executor/models/bert_with_rope.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -432,7 +432,12 @@ def forward(
432432
else:
433433
hidden_states = self.embeddings(input_ids=input_ids,
434434
token_type_ids=token_type_ids)
435-
return self.encoder(positions, hidden_states)
435+
hidden_states = self.encoder(positions, hidden_states)
436+
437+
# convert the embedding output to float32,
438+
# otherwise precision will be lost significantly
439+
hidden_states = hidden_states.to(torch.float32)
440+
return hidden_states
436441

437442
def load_weights(self, weights: Iterable[tuple[str,
438443
torch.Tensor]]) -> set[str]:

0 commit comments

Comments (0)