# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
+from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, TypeVar, cast

import torch
@@ -373,6 +374,76 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
    text_config.use_sep_token = use_sep_token


+def _get_language_model_for_seq_cls(model) -> nn.Module:
+    """
+    Get the language model component for sequence classification conversion.
+    For VLMs, returns the inner language model. For standard LLMs, returns the model itself.
+    """
+    if supports_multimodal(model):
+        try:
+            lm = model.get_language_model()
+            if lm is not model:
+                return lm
+        except Exception:
+            pass
+
+        for attr_name in ("language_model", "lm", "text_model"):
+            if hasattr(model, attr_name):
+                candidate = getattr(model, attr_name)
+                if (
+                    isinstance(candidate, nn.Module)
+                    and candidate is not model
+                    and hasattr(candidate, "model")
+                ):
+                    return candidate
+
+        for name, child in model.named_children():
+            child_name = type(child).__name__
+            if ("ForCausalLM" in child_name or "LMHead" in child_name) and hasattr(
+                child, "model"
+            ):
+                return child
+
+    return model
+
+
+@contextmanager
+def _disable_seq_cls_loading_on_inner_model(language_model, is_vlm: bool):
+    """
+    Context manager to temporarily disable sequence classification loading
+    on inner VLM models to prevent recursive seq_cls_model_loader calls.
+    """
+    if not is_vlm:
+        yield
+        return
+
+    inner_hf_config = getattr(language_model, "config", None)
+    if inner_hf_config is None:
+        yield
+        return
+
+    inner_text_config = inner_hf_config.get_text_config()
+    original_method = getattr(inner_text_config, "method", None)
+    original_tokens = getattr(inner_text_config, "classifier_from_token", None)
+    original_hf_tokens = getattr(inner_hf_config, "classifier_from_token", None)
+
+    try:
+        if original_method is not None:
+            inner_text_config.method = None
+        if original_tokens is not None:
+            inner_text_config.classifier_from_token = None
+        if original_hf_tokens is not None:
+            inner_hf_config.classifier_from_token = None
+        yield
+    finally:
+        if original_method is not None:
+            inner_text_config.method = original_method
+        if original_tokens is not None:
+            inner_text_config.classifier_from_token = original_tokens
+        if original_hf_tokens is not None:
+            inner_hf_config.classifier_from_token = original_hf_tokens
+
+
def load_weights_using_from_2_way_softmax(
    model, weights: Iterable[tuple[str, torch.Tensor]]
):
@@ -393,9 +464,9 @@ def load_weights_using_from_2_way_softmax(
    tokens = cast(list[int], tokens)
    assert len(tokens) == 2

-    language_model = (
-        model.get_language_model() if hasattr(model, "get_language_model") else model
-    )
+    language_model = _get_language_model_for_seq_cls(model)
+    is_vlm = language_model is not model
+
    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
    )
@@ -411,12 +482,13 @@ def load_weights_using_from_2_way_softmax(
        )
        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

-    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
-    # function, so we need use this hacky method to obtain it.
-    pooling_model_cls = next(
-        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
-    )
-    loaded_weights = pooling_model_cls.load_weights(model, weights)
+    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
+        # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
+        # function, so we need to use this hacky method to obtain it.
+        pooling_model_cls = next(
+            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+        )
+        loaded_weights = pooling_model_cls.load_weights(model, weights)

    from vllm.tokenizers import get_tokenizer

@@ -434,12 +506,15 @@ def load_weights_using_from_2_way_softmax(
        torch.float32
    ) - lm_head_weight.data[[false_id]].to(torch.float32)

-    param = model.score.weight
+    score_layer = language_model.score if is_vlm else model.score
+    param = score_layer.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

    del language_model.lm_head
-    loaded_weights.add("score.weight")
+
+    score_weight_name = "language_model.score.weight" if is_vlm else "score.weight"
+    loaded_weights.add(score_weight_name)

    lm_head_name = "lm_head.weight"
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
@@ -460,22 +535,30 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
    tokens = cast(list[int], tokens)
    assert len(tokens) > 0

-    model.lm_head = ParallelLMHead(
+    language_model = _get_language_model_for_seq_cls(model)
+    is_vlm = language_model is not model
+
+    language_model.lm_head = ParallelLMHead(
        text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
    )
    if text_config.tie_word_embeddings:
        # embed_tokens is the assumed name for input embeddings. If the model does not
        # have this attribute, we fall back to get_input_embeddings(), which is used by
        # the Transformers modeling backend.
+        text_backbone = language_model.model
        embed_tokens = (
-            model.model.embed_tokens
-            if hasattr(model.model, "embed_tokens")
-            else model.model.get_input_embeddings()
+            text_backbone.embed_tokens
+            if hasattr(text_backbone, "embed_tokens")
+            else text_backbone.get_input_embeddings()
        )
-        model.lm_head = model.lm_head.tie_weights(embed_tokens)
+        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)

-    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
-    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
+    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
+        pooling_model_cls = next(
+            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+        )
+        # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
+        loaded_weights = pooling_model_cls.load_weights(model, weights)

    from vllm.tokenizers import get_tokenizer

@@ -487,15 +570,22 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
    )

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
-    score_weight = model.lm_head.weight.data[token_ids]
+    score_weight = language_model.lm_head.weight.data[token_ids]

-    param = model.score.weight
+    score_layer = language_model.score if is_vlm else model.score
+    param = score_layer.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)

-    del model.lm_head
-    loaded_weights.add("score.weight")
-    loaded_weights.discard("lm_head.weight")
+    del language_model.lm_head
+
+    score_weight_name = "language_model.score.weight" if is_vlm else "score.weight"
+    loaded_weights.add(score_weight_name)
+
+    lm_head_name = "lm_head.weight"
+    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
+    loaded_weights.discard(lm_head_name)
    return loaded_weights


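A minimal sketch (not part of the commit) of the dispatch pattern the loaders above rely on: resolve the text backbone, record whether the wrapper is a VLM, and pick the checkpoint name for the score weight accordingly. FakeLM and FakeVLM are made-up stand-ins; in vLLM the resolution is done by _get_language_model_for_seq_cls with the extra fallbacks shown in the diff.

import torch.nn as nn


class FakeLM(nn.Module):
    """Stand-in for a plain *ForCausalLM converted to sequence classification."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(8, 8)             # stands in for the decoder stack
        self.score = nn.Linear(8, 1, bias=False)


class FakeVLM(nn.Module):
    """Stand-in for a VLM wrapper that holds an inner language model."""

    def __init__(self):
        super().__init__()
        self.language_model = FakeLM()


def resolve(model: nn.Module) -> tuple[nn.Module, bool]:
    # Simplified resolution: prefer an inner `language_model`, else the model itself.
    lm = getattr(model, "language_model", model)
    return lm, lm is not model


for m in (FakeLM(), FakeVLM()):
    lm, is_vlm = resolve(m)
    name = "language_model.score.weight" if is_vlm else "score.weight"
    print(type(m).__name__, "->", name)  # FakeLM -> score.weight, FakeVLM -> language_model.score.weight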
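A hedged sketch (also not from the commit) of the save/patch/restore idiom behind _disable_seq_cls_loading_on_inner_model: the seq-cls attributes are cleared only for the duration of the with block and restored even if loading raises. FakeTextConfig is a made-up stand-in for the inner HF text config.

from contextlib import contextmanager


class FakeTextConfig:
    method = "from_2_way_softmax"
    classifier_from_token = ["no", "yes"]


@contextmanager
def temporarily_cleared(cfg):
    # Save the attributes, null them out, and always restore them on exit.
    saved = (cfg.method, cfg.classifier_from_token)
    try:
        cfg.method = None
        cfg.classifier_from_token = None
        yield cfg
    finally:
        cfg.method, cfg.classifier_from_token = saved


cfg = FakeTextConfig()
with temporarily_cleared(cfg):
    assert cfg.method is None                  # inner loader sees no seq-cls config
assert cfg.method == "from_2_way_softmax"      # restored after the block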