Commit a8eb118

[CI][Models] Add VLM Support for Sequence Classification Conversion (vllm-project#32885)
Signed-off-by: Andreas Karatzas <[email protected]>
Parent: fa6e599

3 files changed, +155 −39 lines

vllm/model_executor/layers/layernorm.py

Lines changed: 39 additions & 15 deletions
```diff
@@ -278,21 +278,35 @@ def __init__(
         self.variance_epsilon = eps
 
     @staticmethod
-    def forward_static(
+    def _forward_static_no_residual(
         weight: torch.Tensor,
         variance_epsilon: float,
         x: torch.Tensor,
-        residual: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        """PyTorch-native implementation equivalent to forward()."""
+    ) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward() without residual."""
         orig_dtype = x.dtype
-        if residual is not None:
-            x = (
-                x.float() + residual.float()
-                if orig_dtype == torch.float16
-                else x + residual
-            )
-            residual = x
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + variance_epsilon)
+        x = x * (1.0 + weight.float())
+        x = x.to(orig_dtype)
+        return x
+
+    @staticmethod
+    def _forward_static_with_residual(
+        weight: torch.Tensor,
+        variance_epsilon: float,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward() with residual."""
+        orig_dtype = x.dtype
+        x = (
+            x.float() + residual.float()
+            if orig_dtype == torch.float16
+            else x + residual
+        )
+        residual = x
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
@@ -301,15 +315,22 @@ def forward_static(
         # See https://github.com/huggingface/transformers/pull/29402
         x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
+        return x, residual
 
     def forward_native(
         self,
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
-        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)
+        if residual is None:
+            return self._forward_static_no_residual(
+                self.weight.data, self.variance_epsilon, x
+            )
+        else:
+            return self._forward_static_with_residual(
+                self.weight.data, self.variance_epsilon, x, residual
+            )
 
     def forward_cuda(
         self,
@@ -320,8 +341,11 @@ def forward_cuda(
             return self.forward_native(x, residual)
 
         if not getattr(self, "_is_compiled", False):
-            self.forward_static = torch.compile(  # type: ignore
-                self.forward_static
+            self._forward_static_no_residual = torch.compile(  # type: ignore
+                self._forward_static_no_residual
+            )
+            self._forward_static_with_residual = torch.compile(  # type: ignore
+                self._forward_static_with_residual
             )
             self._is_compiled = True
         return self.forward_native(x, residual)
```
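Worth noting why the static path is split in two: `torch.compile` specializes on argument types, so a single function with an optional `residual` argument can re-trace (or recompile) whenever the None/Tensor case flips. A minimal standalone sketch of the pattern, assuming Gemma-style RMSNorm as in the diff; the function names here are illustrative, not vLLM API:

```python
# A minimal sketch (not vLLM API): two residual-specialized RMSNorm functions,
# each compiled on its own so the compiler never guards on an Optional argument.
import torch


def rmsnorm_no_residual(
    weight: torch.Tensor, eps: float, x: torch.Tensor
) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma-style scale (1 + weight); see the HF PR linked in the diff.
    return (x * (1.0 + weight.float())).to(orig_dtype)


def rmsnorm_with_residual(
    weight: torch.Tensor, eps: float, x: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    orig_dtype = x.dtype
    # Accumulate in fp32 only for fp16 inputs, mirroring the diff.
    x = x.float() + residual.float() if orig_dtype == torch.float16 else x + residual
    residual = x
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * (1.0 + weight.float())).to(orig_dtype), residual


# Two independent compiled graphs; no None-vs-Tensor recompilation.
rmsnorm_no_residual = torch.compile(rmsnorm_no_residual)
rmsnorm_with_residual = torch.compile(rmsnorm_with_residual)

w, x = torch.zeros(64), torch.randn(2, 64)
print(rmsnorm_no_residual(w, 1e-6, x).shape)  # torch.Size([2, 64])
out, res = rmsnorm_with_residual(w, 1e-6, x, x)
print(out.shape, res.shape)                   # torch.Size([2, 64]) twice
```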

vllm/model_executor/models/adapters.py

Lines changed: 113 additions & 23 deletions
```diff
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, TypeVar, cast
 
 import torch
@@ -373,6 +374,76 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         text_config.use_sep_token = use_sep_token
 
 
+def _get_language_model_for_seq_cls(model) -> nn.Module:
+    """
+    Get the language model component for sequence classification conversion.
+    For VLMs, returns the inner language model. For standard LLMs, returns model itself.
+    """
+    if supports_multimodal(model):
+        try:
+            lm = model.get_language_model()
+            if lm is not model:
+                return lm
+        except Exception:
+            pass
+
+    for attr_name in ("language_model", "lm", "text_model"):
+        if hasattr(model, attr_name):
+            candidate = getattr(model, attr_name)
+            if (
+                isinstance(candidate, nn.Module)
+                and candidate is not model
+                and hasattr(candidate, "model")
+            ):
+                return candidate
+
+    for name, child in model.named_children():
+        child_name = type(child).__name__
+        if ("ForCausalLM" in child_name or "LMHead" in child_name) and hasattr(
+            child, "model"
+        ):
+            return child
+
+    return model
+
+
+@contextmanager
+def _disable_seq_cls_loading_on_inner_model(language_model, is_vlm: bool):
+    """
+    Context manager to temporarily disable sequence classification loading
+    on inner VLM models to prevent recursive seq_cls_model_loader calls.
+    """
+    if not is_vlm:
+        yield
+        return
+
+    inner_hf_config = getattr(language_model, "config", None)
+    if inner_hf_config is None:
+        yield
+        return
+
+    inner_text_config = inner_hf_config.get_text_config()
+    original_method = getattr(inner_text_config, "method", None)
+    original_tokens = getattr(inner_text_config, "classifier_from_token", None)
+    original_hf_tokens = getattr(inner_hf_config, "classifier_from_token", None)
+
+    try:
+        if original_method is not None:
+            inner_text_config.method = None
+        if original_tokens is not None:
+            inner_text_config.classifier_from_token = None
+        if original_hf_tokens is not None:
+            inner_hf_config.classifier_from_token = None
+        yield
+    finally:
+        if original_method is not None:
+            inner_text_config.method = original_method
+        if original_tokens is not None:
+            inner_text_config.classifier_from_token = original_tokens
+        if original_hf_tokens is not None:
+            inner_hf_config.classifier_from_token = original_hf_tokens
+
+
 def load_weights_using_from_2_way_softmax(
     model, weights: Iterable[tuple[str, torch.Tensor]]
 ):
```
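The helper's fallback chain is easiest to see on a toy module tree. Below is a sketch of the attribute-probing branch only, using hypothetical stand-in classes — none of these are vLLM models:

```python
# Hypothetical stand-ins (not vLLM classes) for a VLM wrapper and a plain LLM.
import torch.nn as nn


class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)


class ToyCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyBackbone()  # having `.model` marks it as an inner LM


class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(4, 4)
        self.language_model = ToyCausalLM()


def resolve_language_model(model: nn.Module) -> nn.Module:
    """Mirrors the attribute-probing branch of _get_language_model_for_seq_cls."""
    for attr_name in ("language_model", "lm", "text_model"):
        candidate = getattr(model, attr_name, None)
        if (
            isinstance(candidate, nn.Module)
            and candidate is not model
            and hasattr(candidate, "model")
        ):
            return candidate
    return model


vlm, llm = ToyVLM(), ToyCausalLM()
print(resolve_language_model(vlm) is vlm.language_model)  # True: unwraps the VLM
print(resolve_language_model(llm) is llm)                 # True: plain LLM is itself
```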
```diff
@@ -393,9 +464,9 @@ def load_weights_using_from_2_way_softmax(
     tokens = cast(list[int], tokens)
     assert len(tokens) == 2
 
-    language_model = (
-        model.get_language_model() if hasattr(model, "get_language_model") else model
-    )
+    language_model = _get_language_model_for_seq_cls(model)
+    is_vlm = language_model is not model
+
     language_model.lm_head = ParallelLMHead(
         text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
     )
@@ -411,12 +482,13 @@
     )
     language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)
 
-    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
-    # function, so we need use this hacky method to obtain it.
-    pooling_model_cls = next(
-        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
-    )
-    loaded_weights = pooling_model_cls.load_weights(model, weights)
+    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
+        # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
+        # function, so we need use this hacky method to obtain it.
+        pooling_model_cls = next(
+            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+        )
+        loaded_weights = pooling_model_cls.load_weights(model, weights)
 
     from vllm.tokenizers import get_tokenizer
 
@@ -434,12 +506,15 @@
         torch.float32
     ) - lm_head_weight.data[[false_id]].to(torch.float32)
 
-    param = model.score.weight
+    score_layer = language_model.score if is_vlm else model.score
+    param = score_layer.weight
     weight_loader = getattr(param, "weight_loader", default_weight_loader)
     weight_loader(param, score_weight)
 
     del language_model.lm_head
-    loaded_weights.add("score.weight")
+
+    score_weight_name = "language_model.score.weight" if is_vlm else "score.weight"
+    loaded_weights.add(score_weight_name)
 
     lm_head_name = "lm_head.weight"
     if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
```
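The `score_weight` subtraction above is the core of the 2-way-softmax conversion: for two candidate tokens, the softmax probability of the "true" token equals a sigmoid of a single difference score, so one classifier row replaces the full lm_head. A quick numeric check (sizes and token ids are arbitrary):

```python
# Numeric check of the 2-way-softmax collapse: softmax([a, b])[0] == sigmoid(a - b),
# so the row (w_true - w_false) reproduces the "true" probability on its own.
import torch

hidden_size, vocab_size = 16, 32
true_id, false_id = 5, 9

lm_head = torch.randn(vocab_size, hidden_size)
h = torch.randn(hidden_size)  # final hidden state of the sequence

# Two-way softmax over the "true"/"false" token logits.
logits = lm_head @ h
p_true = torch.softmax(logits[[true_id, false_id]], dim=-1)[0]

# Single-row classifier built exactly as in the diff.
score_weight = lm_head[[true_id]] - lm_head[[false_id]]
p_true_collapsed = torch.sigmoid(score_weight @ h)[0]

print(torch.allclose(p_true, p_true_collapsed, atol=1e-6))  # True
```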
```diff
@@ -460,22 +535,30 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
     tokens = cast(list[int], tokens)
     assert len(tokens) > 0
 
-    model.lm_head = ParallelLMHead(
+    language_model = _get_language_model_for_seq_cls(model)
+    is_vlm = language_model is not model
+
+    language_model.lm_head = ParallelLMHead(
         text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
     )
     if text_config.tie_word_embeddings:
         # embed_tokens is the assumed name for input embeddings. If the model does not
         # have this attribute, we fall back to get_input_embeddings(), which is used by
         # the Transformers modeling backend.
+        text_backbone = language_model.model
         embed_tokens = (
-            model.model.embed_tokens
-            if hasattr(model.model, "embed_tokens")
-            else model.model.get_input_embeddings()
+            text_backbone.embed_tokens
+            if hasattr(text_backbone, "embed_tokens")
+            else text_backbone.get_input_embeddings()
         )
-        model.lm_head = model.lm_head.tie_weights(embed_tokens)
+        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)
 
-    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
-    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
+    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
+        pooling_model_cls = next(
+            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+        )
+        # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
+        loaded_weights = pooling_model_cls.load_weights(model, weights)
 
     from vllm.tokenizers import get_tokenizer
```
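The switch from `__mro__[1]` to a name-based scan matters because `ModelForPooling` is created dynamically inside a factory, so its position in the MRO is not fixed; indexing position 1 silently picks the wrong base if another class is mixed in. A toy reproduction of the lookup, with hypothetical classes imitating the dynamically created vLLM wrappers:

```python
# Hypothetical classes (not vLLM's) showing why the MRO is scanned by name.
class ToyBaseModel:
    @classmethod
    def load_weights(cls, model, weights):
        return {name for name, _ in weights}


def create_pooling_model_cls(base):
    # Defined inside a factory, like _create_pooling_model_cls in adapters.py.
    class ModelForPooling(base):
        pass

    return ModelForPooling


class ExtraMixin:
    pass


def create_seq_cls_model_cls(base):
    class ModelForSequenceClassification(ExtraMixin, create_pooling_model_cls(base)):
        pass

    return ModelForSequenceClassification


SeqClsModel = create_seq_cls_model_cls(ToyBaseModel)

# __mro__[1] is ExtraMixin here, not ModelForPooling, hence the scan by name.
print(SeqClsModel.__mro__[1].__name__)  # ExtraMixin
pooling_cls = next(x for x in SeqClsModel.__mro__ if x.__name__ == "ModelForPooling")
print(pooling_cls.load_weights(None, [("w", 0)]))  # {'w'}
```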

```diff
@@ -487,15 +570,22 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
     )
 
     token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
-    score_weight = model.lm_head.weight.data[token_ids]
+    score_weight = language_model.lm_head.weight.data[token_ids]
 
-    param = model.score.weight
+    score_layer = language_model.score if is_vlm else model.score
+    param = score_layer.weight
     weight_loader = getattr(param, "weight_loader", default_weight_loader)
     weight_loader(param, score_weight)
 
-    del model.lm_head
-    loaded_weights.add("score.weight")
-    loaded_weights.discard("lm_head.weight")
+    del language_model.lm_head
+
+    score_weight_name = "language_model.score.weight" if is_vlm else "score.weight"
+    loaded_weights.add(score_weight_name)
+
+    lm_head_name = "lm_head.weight"
+    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
+    loaded_weights.discard(lm_head_name)
     return loaded_weights
```
vllm/v1/attention/backends/triton_attn.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -107,7 +107,9 @@ def mm_prefix_range_tensor(self) -> torch.Tensor | None:
             for r in range_lists
         ]
 
-        return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
+        return torch.nested.nested_tensor(
+            range_tensors, layout=torch.jagged
+        ).to_padded_tensor(0)
 
 
 class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
```
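A small sketch of what the updated call produces. `layout=torch.jagged` selects PyTorch's jagged (NJT) representation for the ragged per-request dimension before padding; the range values below are made up, since the real `range_tensors` come from multimodal prefix ranges:

```python
import torch

# Made-up (start, end) ranges; request 0 has two ranges, request 1 has one.
range_tensors = [
    torch.tensor([[0, 4], [7, 9]]),
    torch.tensor([[2, 3]]),
]

padded = torch.nested.nested_tensor(
    range_tensors, layout=torch.jagged
).to_padded_tensor(0)

print(padded.shape)  # torch.Size([2, 2, 2]): padded to the longest entry
print(padded[1])     # tensor([[2, 3], [0, 0]]): shorter entry zero-padded
```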
