
Commit 29523fd

Enable headless models for pooling in the Transformers backend (vllm-project#21767)

hmellor authored and epwalsh committed
Signed-off-by: Harry Mellor <[email protected]>
1 parent 4a6adca · commit 29523fd

File tree: 5 files changed, +44 −9 lines

tests/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -525,6 +525,7 @@ def check_available_online(
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
+    "TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
     "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
     "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
 }

tests/models/test_transformers.py

Lines changed: 22 additions & 6 deletions

@@ -34,8 +34,7 @@ def check_implementation(
 
     with runner_test(model, **kwargs_test, **kwargs) as model_test:
         model_config = model_test.llm.llm_engine.model_config
-        assert model_config.architecture == (
-            model_config._get_transformers_backend_cls())
+        assert model_config.using_transformers_backend()
 
         outputs_test = model_test.generate_greedy_logprobs(*args)
 
@@ -135,8 +134,7 @@ def test_quantization(
                     enforce_eager=True,
                     **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
         model_config = vllm_model.llm.llm_engine.model_config
-        assert model_config.architecture == (
-            model_config._get_transformers_backend_cls())
+        assert model_config.using_transformers_backend()
 
         transformers_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
@@ -149,6 +147,25 @@ def test_quantization(
     )
 
 
+@pytest.mark.parametrize(
+    "model",
+    [
+        # Layers live in `layers`
+        "Qwen/Qwen3-Embedding-0.6B",
+        # Layers live in `model.layers`
+        "meta-llama/Llama-3.2-1B-Instruct"
+    ],
+)
+def test_embed_loading(vllm_runner, model):
+    with vllm_runner(model,
+                     max_model_len=1024,
+                     enforce_eager=True,
+                     runner="pooling",
+                     model_impl="transformers") as model_test:
+        model_config = model_test.llm.llm_engine.model_config
+        assert model_config.using_transformers_backend()
+
+
 @pytest.mark.parametrize(
     "model",
     ["jason9693/Qwen2.5-1.5B-apeach"],
@@ -169,8 +186,7 @@ def test_classify(
             dtype=dtype,
             model_impl="transformers") as vllm_model:
         model_config = vllm_model.llm.llm_engine.model_config
-        assert model_config.architecture == (
-            model_config._get_transformers_backend_cls())
+        assert model_config.using_transformers_backend()
 
         vllm_outputs = vllm_model.classify(example_prompts)
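The new `test_embed_loading` test exercises the user-facing path this commit enables: loading an embedding checkpoint through the Transformers backend with the pooling runner. Below is a minimal offline sketch of that flow, assuming the `runner` and `model_impl` keyword arguments shown in the test are accepted by `LLM` and that `LLM.embed` is available in your vLLM build; argument names can vary between versions.

```python
from vllm import LLM

# Load an embedding checkpoint through the Transformers backend in pooling
# mode, mirroring the arguments used by test_embed_loading above.
llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",
    runner="pooling",           # pooling runner -> TransformersModel backend
    model_impl="transformers",  # force the Transformers backend
    max_model_len=1024,
    enforce_eager=True,
)

# Pooling models return embeddings instead of generated text.
outputs = llm.embed(["Headless models can now be served for pooling."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```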

vllm/config.py

Lines changed: 7 additions & 2 deletions

@@ -812,12 +812,17 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
+        if getattr(self, "runner_type", self.runner) == "pooling":
+            return "TransformersModel"
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is the same as 'hf_config'. If not, it is
             # probably a composite config, i.e. multimodal
             return "TransformersForMultimodalLM"
-        else:
-            return "TransformersForCausalLM"
+        return "TransformersForCausalLM"
+
+    def using_transformers_backend(self) -> bool:
+        """Check if the model is using the Transformers backend class."""
+        return self.architecture == self._get_transformers_backend_cls()
 
     @property
     def registry(self):
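To make the selection order in `_get_transformers_backend_cls` explicit, here is a standalone sketch (plain Python, not vLLM code) of the same decision: a pooling runner always resolves to `TransformersModel`, a composite (multimodal) config resolves to `TransformersForMultimodalLM`, and everything else falls back to `TransformersForCausalLM`.

```python
# Hypothetical helper mirroring the decision order added above; the real code
# reads runner_type and compares hf_config with hf_text_config on ModelConfig.
def pick_backend_cls(runner_type: str, is_composite_config: bool) -> str:
    if runner_type == "pooling":
        return "TransformersModel"
    if is_composite_config:
        return "TransformersForMultimodalLM"
    return "TransformersForCausalLM"

assert pick_backend_cls("pooling", False) == "TransformersModel"
assert pick_backend_cls("generate", True) == "TransformersForMultimodalLM"
assert pick_backend_cls("generate", False) == "TransformersForCausalLM"
```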

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 1 deletion

@@ -273,8 +273,9 @@
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
+    "TransformersModel": ("transformers", "TransformersModel"),
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
+    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
 }
 # yapf: enable
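Each entry in this registry maps an architecture name to the module under `vllm/model_executor/models/` and the class to load from it. A toy illustration of that resolution (an assumed flow, not the actual `ModelRegistry` API, and requiring vLLM to be installed):

```python
import importlib

# (module, class) tuple as registered above for the new architecture name.
module_name, class_name = ("transformers", "TransformersModel")
module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
model_cls = getattr(module, class_name)
print(model_cls.__name__)  # TransformersModel
```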

vllm/model_executor/models/transformers.py

Lines changed: 12 additions & 0 deletions

@@ -651,6 +651,18 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
+@support_torch_compile
+class TransformersModel(TransformersBase):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # Add `model.` prefix for base model checkpoints
+            "": "model.",
+            # Remove `model.` from places it should not be
+            "model.model.": "model.",
+            "model.score": "score",
+        })
+
+
 @support_torch_compile
 class TransformersForCausalLM(TransformersBase):
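The main thing `TransformersModel` adds over `TransformersBase` is this checkpoint-key remapping. The sketch below reimplements the three prefix rules in plain Python (an illustration, not vLLM's `WeightsMapper`, and it assumes each matching rule is applied once, in order): keys from base-model checkpoints whose layers live under `layers` gain a `model.` prefix, keys that already start with `model.` end up with a single prefix, and score heads stay at the top level.

```python
# Prefix rules from TransformersModel.hf_to_vllm_mapper, applied in order.
PREFIX_RULES = [
    ("", "model."),              # add `model.` prefix for base model checkpoints
    ("model.model.", "model."),  # collapse the doubled prefix
    ("model.score", "score"),    # keep score heads outside the wrapped model
]

def remap(name: str) -> str:
    for old, new in PREFIX_RULES:
        if name.startswith(old):
            name = new + name[len(old):]
    return name

# Checkpoints whose layers live in `layers` gain the prefix...
assert remap("layers.0.mlp.up_proj.weight") == "model.layers.0.mlp.up_proj.weight"
# ...while checkpoints whose layers live in `model.layers` keep a single prefix.
assert remap("model.layers.0.mlp.up_proj.weight") == "model.layers.0.mlp.up_proj.weight"
# Classification/score heads are left at the top level.
assert remap("score.weight") == "score.weight"
```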
