 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import (AutoConfig, AutoModel, AutoProcessor, AutoTokenizer,
-                          LlavaNextConfig, PretrainedConfig, PreTrainedModel)
-from transformers.modeling_utils import load_sharded_checkpoint
+from transformers import (AutoProcessor, AutoTokenizer, LlavaNextConfig,
+                          PretrainedConfig, PreTrainedModel)
 from transformers.models.llava_next.modeling_llava_next import (
     LlavaNextMultiModalProjector, get_anyres_image_grid_shape,
     image_size_to_num_patches, unpad_image)

+from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
+    BaseWeightMapper
+from tensorrt_llm._torch.models.checkpoints.hf.llava_next_weight_mapper import \
+    LlavaNextHfWeightMapper
 from tensorrt_llm.inputs.multimodal import MultimodalParams

 from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
                        InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor,
                        support_multimodal_disaggregated)
-from ...llmapi.utils import download_hf_model
 from ...logger import logger
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from .modeling_clip import CLIPVisionModel
 from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
                                         get_multimodal_embeddings)
-from .modeling_utils import (filter_weights, register_auto_model,
-                             register_vision_encoder)
+from .modeling_utils import register_auto_model, register_vision_encoder

 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'

@@ -295,62 +296,36 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__()
         self.model_config = model_config
         self.pretrained_config = model_config.pretrained_config
-        # TODO: use config.mapping.get_local_rank() instead
-        self.device = f"cuda:{torch.cuda.current_device()}"
-        model_path = self.pretrained_config._name_or_path

-        # Determine the actual local path for model files
-        if os.path.isdir(model_path):
-            local_model_path = model_path
-        else:
-            local_model_path = download_hf_model(model_path)
-
-        # Partially load the model to reduce memory usage (vision tower and multi-modal projector)
-        hf_model_config = AutoConfig.from_pretrained(local_model_path)
-        self.dtype = hf_model_config.text_config.torch_dtype
-        module_dict = nn.ModuleDict({
-            "vision_tower":
-            AutoModel.from_config(hf_model_config.vision_config),
-            "multi_modal_projector":
-            LlavaNextMultiModalProjector(hf_model_config)
-        })
-        module_dict.register_parameter(
-            "image_newline",
-            nn.Parameter(torch.empty(hf_model_config.text_config.hidden_size)))
-
-        missing_keys, _ = load_sharded_checkpoint(module_dict,
-                                                  local_model_path,
-                                                  strict=False)
-        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
-        hf_vision_tower = module_dict["vision_tower"].to(self.dtype)
-        hf_mm_projector = module_dict["multi_modal_projector"].to(
-            self.dtype).to(self.device)
-        hf_image_newline = module_dict.image_newline.to(self.dtype).to(
-            self.device)
-
-        # For A100 GPU, fallback to HF vision tower due to accuracy issue in TRT-LLM CLIPAttention
-        # Otherwise, use TRTLLM vision tower (CLIPVisionModel)
-        prop = torch.cuda.get_device_properties(0)
-        sm_version = prop.major * 10 + prop.minor
-        self.use_hf_vision_tower = sm_version == 80
-        if self.use_hf_vision_tower:
-            self.vision_tower = hf_vision_tower.to(self.device)
-        else:
-            vision_model_config = ModelConfig(
-                pretrained_config=self.pretrained_config.vision_config,
-                attn_backend="TRTLLM")
-            self.vision_tower = CLIPVisionModel(vision_model_config).to(
-                self.device).to(self.dtype)
-            self.vision_tower.load_weights(hf_vision_tower.state_dict())
-
-        # Use HF multi-modal projector
-        self.mm_projector = hf_mm_projector
-        self.image_newline = hf_image_newline
+        clip_model_config = copy.deepcopy(self.model_config)
+        clip_model_config.pretrained_config = self.model_config.pretrained_config.vision_config
+        self.dtype = self.model_config.pretrained_config.text_config.torch_dtype
+        self.vision_model = CLIPVisionModel(clip_model_config).to(self.dtype)
+        self.mm_projector = LlavaNextMultiModalProjector(
+            self.pretrained_config).to(self.dtype)
+        self.image_newline = nn.Parameter(torch.empty(
+            self.pretrained_config.text_config.hidden_size),
+                                          requires_grad=False).to(self.dtype)
         self.vision_feature_select_strategy = getattr(
             self.pretrained_config, "vision_feature_select_strategy", "default")
-
         self.post_config()

+    def load_weights(self, weights):
+
+        def filter_weights(prefix, weights: Dict):
+            # Keep only the entries under `prefix`, with the prefix stripped,
+            # so each sub-module sees its own parameter names.
+            result = {}
+            for key, weight in weights.items():
+                if key.startswith(prefix):
+                    new_key = key[len(prefix):]
+                    result[new_key] = weight
+            return result
+
+        visual_model_weights = filter_weights("vision_tower.", weights)
+        self.vision_model.load_weights(visual_model_weights)
+        mm_projector_weights = filter_weights("multi_modal_projector.", weights)
+        self.mm_projector.load_state_dict(mm_projector_weights, strict=True)
+        self.image_newline.data.copy_(weights["image_newline"])
+
     def post_config(self):
         self.config = self.pretrained_config.vision_config

@@ -464,7 +439,6 @@ def forward(self, multimodal_params: List[MultimodalParams]):
             for multimodal_param in multimodal_params
         ]
         pixel_values = self._pad_for_batching(pixel_values)
-
         pixel_values = torch.cat(pixel_values, dim=0)
         image_sizes = torch.cat(image_sizes, dim=0)

@@ -484,23 +458,18 @@ def forward(self, multimodal_params: List[MultimodalParams]):
             ]
             pixel_values = torch.cat(_pixel_values_list, dim=0)

-        if self.use_hf_vision_tower:
-            image_features = self.vision_tower(
-                pixel_values, output_hidden_states=True).hidden_states
-        else:
-            attn_metadata = self.vision_tower.prepare_attn_metadata(
-                pixel_values.shape[0])
-            image_features = self.vision_tower(
-                pixel_values,
-                attn_metadata=attn_metadata,
-            )
+        attn_metadata = self.vision_model.prepare_attn_metadata(
+            pixel_values.shape[0])
+        image_features = self.vision_model(
+            pixel_values,
+            attn_metadata=attn_metadata,
+        )
         selected_image_feature = image_features[-2][:, 1:]
         image_features = self.mm_projector(selected_image_feature)
-
         image_features = torch.split(image_features, image_num_patches, dim=0)

-        # NOTE: 'pack_image_features' is directly copied from the HF's code
-        image_features, feature_lens = self.pack_image_features(
+        # NOTE: 'pack_image_features' is copied from HF's code
+        image_features, _ = self.pack_image_features(
             image_features,
             image_sizes,
             vision_feature_select_strategy=self.vision_feature_select_strategy,
@@ -526,6 +495,7 @@ class LlavaNextModel(PreTrainedModel):
     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs) -> None:
         config = model_config.pretrained_config
+        self._supports_sdpa = True
         super().__init__(config)
        if hasattr(self, "llm"):
            return
@@ -543,16 +513,29 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         self.llm = AutoModelForCausalLM.from_config(llm_model_config)

         self.model_config = model_config
-        self.model_dtype = getattr(config.text_config, "torch_dtype",
-                                   torch.float16)
-        logger.info(f"{self.dtype=} {self.model_dtype=}")
-
         self.post_config()
-        self.is_loaded = True

-    def load_weights(self, weights):
-        weights = filter_weights("language_model", weights)
-        self.llm.load_weights(weights)
+    def load_weights(self, weights, weight_mapper: BaseWeightMapper):
+        if isinstance(weight_mapper, LlavaNextHfWeightMapper):
+            weights = weight_mapper.preprocess_weights(weights)
+
+        self.mm_encoder.load_weights(weights)
+
+        def filter_weights(weights: Dict):
+            # Keep only language-model weights, remapping the HF
+            # "language_model." prefix to "model." when the HF mapper is used.
+            transformed_weights = {}
+            for key, weight in weights.items():
+                if key.startswith("language_model."):
+                    if isinstance(weight_mapper, LlavaNextHfWeightMapper):
+                        new_key = "model." + key[len("language_model."):]
+                    else:
+                        new_key = key[len("language_model."):]
+                    transformed_weights[new_key] = weight
+                elif key.startswith("lm_head."):
+                    transformed_weights[key] = weight
+            return transformed_weights
+
+        language_model_weights = filter_weights(weights)
+        self.llm.load_weights(language_model_weights)

     def post_config(self):
         self.config = self.llm.config
@@ -590,7 +573,6 @@ def forward(
             mm_embeds, multimodal_params[:num_context_requests])
         input_ids, inputs_embeds = fuse_input_embeds(
             self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs)
-
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
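
For reference, here is a minimal standalone sketch of the checkpoint-key routing the two new load_weights paths perform: the vision encoder takes the "vision_tower.", "multi_modal_projector.", and "image_newline" entries, while the language model receives "language_model.*" keys (remapped under "model." when the HF weight mapper is in use) plus "lm_head.*". The example checkpoint keys and tensor shapes below are made up for illustration; only the prefixes match the code in this diff.

from typing import Dict

import torch


def split_by_prefix(weights: Dict[str, torch.Tensor],
                    prefix: str) -> Dict[str, torch.Tensor]:
    """Return the entries under `prefix` with the prefix stripped."""
    return {
        key[len(prefix):]: value
        for key, value in weights.items() if key.startswith(prefix)
    }


def route_language_model_keys(weights: Dict[str, torch.Tensor],
                              use_hf_mapper: bool) -> Dict[str, torch.Tensor]:
    """Mirror the key transform done in LlavaNextModel.load_weights above."""
    routed = {}
    for key, value in weights.items():
        if key.startswith("language_model."):
            suffix = key[len("language_model."):]
            routed[("model." + suffix) if use_hf_mapper else suffix] = value
        elif key.startswith("lm_head."):
            routed[key] = value
    return routed


if __name__ == "__main__":
    # Toy checkpoint: keys follow the prefixes used in the diff, values are dummies.
    checkpoint = {
        "vision_tower.embeddings.patch_embedding.weight": torch.zeros(2, 2),
        "multi_modal_projector.linear_1.weight": torch.zeros(2, 2),
        "image_newline": torch.zeros(2),
        "language_model.embed_tokens.weight": torch.zeros(4, 2),
        "lm_head.weight": torch.zeros(4, 2),
    }
    print(sorted(split_by_prefix(checkpoint, "vision_tower.")))
    print(sorted(route_language_model_keys(checkpoint, use_hf_mapper=True)))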