 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
-from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
-                       InputProcessor, MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalDummyInputsBuilder,
+                       BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
@@ -564,33 +565,54 @@ def build_mlp(
     return nn.Sequential(*layers)
 
 
-class HCXVisionInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
+class HCXVisionInputProcessor(BaseMultimodalDummyInputsBuilder,
+                              BaseMultimodalInputProcessor):
 
     def __init__(self,
                  model_path: str,
-                 model_config: PretrainedConfig,
+                 config: PretrainedConfig,
                  tokenizer: AutoTokenizer,
                  trust_remote_code: bool = True):
-
-        self.pretrained_config = model_config
-        self.tokenizer = tokenizer
-        self.use_fast = True
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_path,
-                trust_remote_code=trust_remote_code,
-                use_fast=self.use_fast)
-        self.processor = AutoProcessor.from_pretrained(
+        super().__init__()
+        self._config = config
+        self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
+            model_path,
+            trust_remote_code=trust_remote_code,
+            use_fast=self.use_fast)
+        self._processor = AutoProcessor.from_pretrained(
             model_path,
             trust_remote_code=trust_remote_code,
             use_fast=self.use_fast)
-        self.tllm_multimodal_token_id = self.pretrained_config.language_config[
+        self._model_path = model_path
+        self._dtype = self.config.torch_dtype
+
+        self.tllm_multimodal_token_id = self.config.language_config[
             "vocab_size"] + 1
         self.vision_query_lengths = None
         self._vision_query_generator = None
 
+    @property
+    def config(self) -> PretrainedConfig:
+        return self._config
+
+    @property
+    def tokenizer(self) -> AutoTokenizer:
+        return self._tokenizer
+
+    @property
+    def model_path(self) -> str:
+        return self._model_path
+
+    @property
+    def processor(self) -> AutoProcessor:
+        return self._processor
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self._dtype
+
     def get_vocab_size(self):
-        return self.pretrained_config.language_config["vocab_size"]
+        return self.config.language_config["vocab_size"]
 
     def get_num_tokens_per_image(
         self,
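The hunk above swaps directly assigned public attributes (`self.tokenizer`, `self.processor`, and friends) for private backing fields exposed through read-only properties, giving the `BaseMultimodalDummyInputsBuilder` and `BaseMultimodalInputProcessor` bases a stable accessor interface. A minimal, self-contained sketch of that pattern (class and field names here are illustrative, not part of the PR):

```python
class ExampleProcessor:
    """Toy stand-in illustrating the property-backed accessor pattern."""

    def __init__(self, config: dict, tokenizer=None):
        self._config = config        # private backing field, set once
        self._tokenizer = tokenizer  # supplied by the caller or loaded lazily

    @property
    def config(self) -> dict:
        # Read-only view: callers (and base-class mixins) can read it
        # but cannot rebind it by accident.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer


proc = ExampleProcessor(config={"vocab_size": 32000})
print(proc.config["vocab_size"])  # -> 32000
try:
    proc.config = {}              # properties without setters reject rebinding
except AttributeError as err:
    print(f"rejected: {err}")
```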
@@ -656,8 +678,7 @@ def _post_process(self,
         vision_query_lengths = preprocessed_image.get("vision_query_lengths",
                                                       None)
         non_vision_query_lengths = determine_non_vision_query_lengths(
-            input_ids, self.tokenizer.pad_token_id,
-            self.pretrained_config.img_start_id)
+            input_ids, self.tokenizer.pad_token_id, self.config.img_start_id)
         batch_size = input_ids.size(0)
 
         len_inputs_embeds = max([
@@ -666,19 +687,18 @@ def _post_process(self,
                 non_vision_query_lengths, vision_query_lengths)
         ])
 
-        len_inputs_embeds = min(self.pretrained_config.decoder_max_length,
+        len_inputs_embeds = min(self.config.decoder_max_length,
                                 len_inputs_embeds)
 
-        image_cnts = (input_ids == self.pretrained_config.img_start_id).sum(
-            dim=1).tolist()
+        image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist()
 
         fused_input_ids = torch.zeros([batch_size, len_inputs_embeds],
                                       dtype=input_ids.dtype)
         for batch_idx, sample in enumerate(input_ids):
             non_vision_query_length = non_vision_query_lengths[batch_idx]
             sample = sample[:non_vision_query_length + image_cnts[batch_idx]]
 
-            mask = (sample == self.pretrained_config.img_start_id)
+            mask = (sample == self.config.img_start_id)
             img_start_ids = mask.nonzero()
             input_start, temp_start = 0, 0
 
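The reworked `image_cnts` line relies on a standard PyTorch idiom: an elementwise comparison yields a boolean tensor, and summing it along `dim=1` counts matches per batch row. A small sketch with invented token IDs (7 stands in for `img_start_id`):

```python
import torch

# Invented values for illustration only.
input_ids = torch.tensor([[1, 7, 3, 7, 0],
                          [7, 2, 2, 0, 0]])
image_cnts = (input_ids == 7).sum(dim=1).tolist()
print(image_cnts)  # [2, 1]: per-row count of image-start tokens
```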
@@ -779,32 +799,30 @@ class HCXVisionModel(nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         super().__init__()
         self.model_config = model_config
-        self.pretrained_config = model_config.pretrained_config
+        self.config = model_config.pretrained_config
         siglip_model_config = copy.deepcopy(self.model_config)
         siglip_model_config.pretrained_config = self.model_config.pretrained_config.vision_config
         self.visual_token_idx = 0 if "siglip" in self.model_config.pretrained_config.vision_config.model_type else 1
         self.dtype = self.model_config.pretrained_config.vision_config.torch_dtype
         self.vision_model = SiglipVisionModel(siglip_model_config).to(
             self.dtype)
         self.mm_projector = HCXVisionCAbstractor(
-            num_queries=self.pretrained_config.num_queries_vis_abstractor,
-            num_input_tokens=(
-                self.pretrained_config.vision_config.image_size //
-                self.pretrained_config.vision_config.patch_size)**2,
-            encoder_hidden_size=self.pretrained_config.vision_config.
-            hidden_size,
-            hidden_size=self.pretrained_config.vision_config.hidden_size,
-            output_hidden_size=self.pretrained_config.hidden_size,
-            pos_emb=self.pretrained_config.proj_pos_emb,
-            prenorm=self.pretrained_config.proj_prenorm,
+            num_queries=self.config.num_queries_vis_abstractor,
+            num_input_tokens=(self.config.vision_config.image_size //
+                              self.config.vision_config.patch_size)**2,
+            encoder_hidden_size=self.config.vision_config.hidden_size,
+            hidden_size=self.config.vision_config.hidden_size,
+            output_hidden_size=self.config.hidden_size,
+            pos_emb=self.config.proj_pos_emb,
+            prenorm=self.config.proj_prenorm,
         ).to(self.dtype)
         self.image_newline = nn.Parameter(torch.empty(
-            self.pretrained_config.hidden_size, ),
+            self.config.hidden_size, ),
                                           requires_grad=False).to(self.dtype)
 
-        self.unpad = self.pretrained_config.unpad
-        self.use_nth_layer = self.pretrained_config.use_nth_layer
-        self.anyres = self.pretrained_config.anyres
+        self.unpad = self.config.unpad
+        self.use_nth_layer = self.config.use_nth_layer
+        self.anyres = self.config.anyres
         self.possible_resolutions = self._init_possible_resolutions()
         self.post_config()
 
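The reflowed `num_input_tokens` expression still computes the squared side length of the patch grid. Worked through with SigLIP-style values (assumed here purely for illustration; the real numbers come from `vision_config`):

```python
# Illustrative values only; the actual ones come from vision_config.
image_size, patch_size = 384, 14
num_input_tokens = (image_size // patch_size) ** 2
print(num_input_tokens)  # (384 // 14) ** 2 = 27 ** 2 = 729 patch tokens
```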
@@ -814,18 +832,18 @@ def post_config(self):
 
     def _init_possible_resolutions(self):
         possible_resolutions = []
-        if self.pretrained_config.anyres:
-            assert self.pretrained_config.max_num_grids > 0
-            for i in range(1, self.pretrained_config.max_num_grids + 1):
-                for j in range(1, self.pretrained_config.max_num_grids + 1):
-                    if i == 1 and j == 1 and not self.pretrained_config.use_1x1_grid:
+        if self.config.anyres:
+            assert self.config.max_num_grids > 0
+            for i in range(1, self.config.max_num_grids + 1):
+                for j in range(1, self.config.max_num_grids + 1):
+                    if i == 1 and j == 1 and not self.config.use_1x1_grid:
                         continue
-                    if i * j <= self.pretrained_config.max_num_grids:
+                    if i * j <= self.config.max_num_grids:
                         possible_resolutions.append([i, j])
 
         possible_resolutions = [[
-            ys * self.pretrained_config.vision_config.image_size,
-            xs * self.pretrained_config.vision_config.image_size
+            ys * self.config.vision_config.image_size,
+            xs * self.config.vision_config.image_size
         ] for ys, xs in possible_resolutions]
         return possible_resolutions
 
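Behaviorally, `_init_possible_resolutions` is a pure function of three config fields: it enumerates grid layouts `[i, j]` with `i * j <= max_num_grids` (optionally skipping 1x1) and scales them to pixel resolutions. A standalone sketch, with config values invented for the example, shows what it yields:

```python
def possible_resolutions(max_num_grids: int, use_1x1_grid: bool,
                         image_size: int) -> list[list[int]]:
    """Mirror of the enumeration above, decoupled from the config object."""
    grids = []
    for i in range(1, max_num_grids + 1):
        for j in range(1, max_num_grids + 1):
            if i == 1 and j == 1 and not use_1x1_grid:
                continue
            if i * j <= max_num_grids:
                grids.append([i, j])
    # Scale grid counts to pixel resolutions (height, width).
    return [[ys * image_size, xs * image_size] for ys, xs in grids]


# With max_num_grids=3, use_1x1_grid=False, image_size=378 (invented values),
# the grids are [[1, 2], [1, 3], [2, 1], [3, 1]]:
print(possible_resolutions(3, False, 378))
# [[378, 756], [378, 1134], [756, 378], [1134, 378]]
```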