99import numpy as np
1010import pytest
1111import requests
12+ import torch
1213
1314from tests .models .utils import (EmbedModelInfo , RerankModelInfo ,
1415 check_embeddings_close )
@@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
165166 vllm_extra_kwargs = None ,
166167 hf_model_callback = None ,
167168 atol = MTEB_EMBED_TOL ):
169+ # A model family has many models with the same architecture,
170+ # and we don't need to test each one.
168171 if not model_info .enable_test :
169- # A model family has many models with the same architecture,
170- # and we don't need to test each one.
171172 pytest .skip ("Skipping test." )
172173
173- example_prompts = ["The chef prepared a delicious meal." ]
174+ # Test embed_dims, isnan and whether to use normalize
175+ example_prompts = ["The chef prepared a delicious meal." * 1000 ]
174176
177+ # Allow vllm to test using the given dtype, such as float32
175178 vllm_extra_kwargs = vllm_extra_kwargs or {}
176179 vllm_extra_kwargs ["dtype" ] = model_info .dtype
177180
181+ # Allow vllm to test using hf_overrides
178182 if model_info .hf_overrides is not None :
179183 vllm_extra_kwargs ["hf_overrides" ] = model_info .hf_overrides
180184
@@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
186190
187191 model_config = vllm_model .llm .llm_engine .model_config
188192
193+ # Confirm whether vllm is using the correct architecture
189194 if model_info .architecture :
190195 assert model_info .architecture in model_config .architectures
196+
197+ # Confirm whether vllm uses the correct default_pooling_type, which
198+ # relates to whether chunked prefill and prefix caching are enabled
191199 assert (model_config ._model_info .default_pooling_type ==
192200 model_info .default_pooling_type )
193201
194202 vllm_main_score = run_mteb_embed_task (VllmMtebEncoder (vllm_model ),
195203 MTEB_EMBED_TASKS )
196204 vllm_dtype = vllm_model .llm .llm_engine .model_config .dtype
197- vllm_outputs = vllm_model .embed (example_prompts )
198205
206+ # Test embed_dims, isnan and whether to use normalize
207+ vllm_outputs = vllm_model .embed (example_prompts ,
208+ truncate_prompt_tokens = - 1 )
209+ assert not torch .any (torch .isnan (torch .tensor (vllm_outputs )))
210+
211+ # Accelerate mteb test by setting
212+ # SentenceTransformers mteb score to a constant
199213 if model_info .mteb_score is None :
200214 with hf_runner (model_info .name ,
201215 is_sentence_transformer = True ,
202216 dtype = "float32" ) as hf_model :
203217
218+ # e.g. setting default parameters for the encode method of hf_runner
204219 if hf_model_callback is not None :
205220 hf_model_callback (hf_model )
206221
@@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
299314 hf_model_callback = None ,
300315 vllm_mteb_encoder = VllmMtebEncoder ,
301316 atol = MTEB_RERANK_TOL ):
317+ # A model family has many models with the same architecture,
318+ # and we don't need to test each one.
302319 if not model_info .enable_test :
303- # A model family has many models with the same architecture,
304- # and we don't need to test each one.
305320 pytest .skip ("Skipping test." )
306321
322+ # Allow vllm to test using the given dtype, such as float32
307323 vllm_extra_kwargs = vllm_extra_kwargs or {}
308324 vllm_extra_kwargs ["dtype" ] = model_info .dtype
309325
326+ # Allow vllm to test using hf_overrides
310327 if model_info .hf_overrides is not None :
311328 vllm_extra_kwargs ["hf_overrides" ] = model_info .hf_overrides
312329
@@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
319336
320337 model_config = vllm_model .llm .llm_engine .model_config
321338
339+ # Confirm whether vllm is using the correct architecture
322340 if model_info .architecture :
323341 assert (model_info .architecture in model_config .architectures )
342+
343+ # Score API is only enabled for num_labels == 1
324344 assert model_config .hf_config .num_labels == 1
345+
346+ # Confirm whether vllm uses the correct default_pooling_type, which
347+ # relates to whether chunked prefill and prefix caching are enabled
325348 assert (model_config ._model_info .default_pooling_type ==
326349 model_info .default_pooling_type )
327350
@@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
330353 languages = MTEB_RERANK_LANGS )
331354 vllm_dtype = model_config .dtype
332355
356+ # Accelerate mteb test by setting
357+ # SentenceTransformers mteb score to a constant
333358 if model_info .mteb_score is None :
334359 st_main_score , st_dtype = mteb_test_rerank_models_hf (
335360 hf_runner , model_info .name , hf_model_callback )
0 commit comments