
Commit 67208f1

[None][fix] InputProcessor config naming convention fix (NVIDIA#8705)
Signed-off-by: yechank <[email protected]>
1 parent 4fe47fa commit 67208f1
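
In short, this commit renames the `model_config` constructor argument of the multimodal input processors to `config` and moves them onto the BaseMultimodalInputProcessor / BaseMultimodalDummyInputsBuilder base classes, exposing model metadata through read-only properties instead of ad-hoc attributes. The sketch below illustrates the convention the diff converges on; the class name MyModelInputProcessor is hypothetical, and the absolute import path and the base-class-provided `use_fast` flag are assumptions inferred from the changes that follow.

from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig
import torch

# Assumed absolute path for the relative `...inputs` import used in the diff.
from tensorrt_llm.inputs import (BaseMultimodalDummyInputsBuilder,
                                 BaseMultimodalInputProcessor)


class MyModelInputProcessor(BaseMultimodalInputProcessor,
                            BaseMultimodalDummyInputsBuilder):
    """Hypothetical processor following the naming convention in this commit."""

    def __init__(self,
                 model_path: str,
                 config: PretrainedConfig,  # named `model_config` before this fix
                 tokenizer: AutoTokenizer,
                 trust_remote_code: bool = True):
        super().__init__()
        self._config = config
        self._tokenizer = tokenizer
        self._model_path = model_path
        self._processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=trust_remote_code,
            # assumed to be supplied by the base class, as in the refactored processors
            use_fast=self.use_fast)
        self._dtype = self._config.torch_dtype

    # Read-only accessors replacing the old public attributes.
    @property
    def config(self) -> PretrainedConfig:
        return self._config

    @property
    def tokenizer(self) -> AutoTokenizer:
        return self._tokenizer

    @property
    def model_path(self) -> str:
        return self._model_path

    @property
    def processor(self) -> AutoProcessor:
        return self._processor

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

Call sites that previously passed model_config= now pass config=, which is the one-line change in the first file below.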

15 files changed: +606 −355 lines changed


tensorrt_llm/_torch/auto_deploy/llm.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ class ADInputProcessor(DefaultInputProcessor):
     """
 
     def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None):
-        super().__init__(model_path=None, model_config=None, tokenizer=tokenizer)
+        super().__init__(model_path=None, config=None, tokenizer=tokenizer)
         # NOTE: HF's tokenizer/processor that has the apply_chat_template method
         self.processor = processor or getattr(tokenizer, "tokenizer", None)

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 40 additions & 10 deletions
@@ -4,13 +4,15 @@
 from typing import List, Optional, Tuple
 
 import torch
-from transformers import AutoProcessor, Gemma3Config, PreTrainedModel
+from transformers import (AutoProcessor, AutoTokenizer, Gemma3Config,
+                          PretrainedConfig, PreTrainedModel)
 
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 
 from ..._utils import nvtx_range
-from ...inputs import (ExtraProcessedInputs, InputProcessor,
+from ...inputs import (BaseMultimodalDummyInputsBuilder,
+                       BaseMultimodalInputProcessor, ExtraProcessedInputs,
                        MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
@@ -33,15 +35,43 @@ def _is_disagg() -> bool:
     return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1"
 
 
-class Gemma3InputProcessor(InputProcessor):
+class Gemma3InputProcessor(BaseMultimodalInputProcessor,
+                           BaseMultimodalDummyInputsBuilder):
 
-    def __init__(self, model_path, model_config, tokenizer, trust_remote_code):
+    def __init__(self,
+                 model_path: str,
+                 config: PretrainedConfig,
+                 tokenizer: AutoTokenizer,
+                 trust_remote_code: bool = True):
+        super().__init__()
+        self._config = config
+        self._tokenizer = tokenizer
+        self._model_path = model_path
+        self._processor = AutoProcessor.from_pretrained(
+            model_path,
+            trust_remote_code=trust_remote_code,
+            use_fast=self.use_fast)
+        self._dtype = self.config.torch_dtype
+
+    @property
+    def config(self) -> PretrainedConfig:
+        return self._config
+
+    @property
+    def tokenizer(self) -> AutoTokenizer:
+        return self._tokenizer
 
-        self.tokenizer = tokenizer
-        self.processor = AutoProcessor.from_pretrained(
-            model_path, trust_remote_code=trust_remote_code, use_fast=True)
-        self.model_config = model_config
-        self.device = 'cuda'
+    @property
+    def model_path(self) -> str:
+        return self._model_path
+
+    @property
+    def processor(self) -> AutoProcessor:
+        return self._processor
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self._dtype
 
     @nvtx_range("[Vision] preprocess")
     def _preprocess(self, inputs):
@@ -59,7 +89,7 @@ def _preprocess(self, inputs):
             images=images,
             do_rescale=do_rescale,
             return_tensors="pt",
-            device=self.device).to(dtype=torch.bfloat16)
+        ).to(dtype=self.dtype)
 
         input_ids = processor_output["input_ids"]
         pixel_values = processor_output.get("pixel_values")

tensorrt_llm/_torch/models/modeling_hyperclovax.py

Lines changed: 63 additions & 45 deletions
@@ -15,8 +15,9 @@
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
-from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
-                       InputProcessor, MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalDummyInputsBuilder,
+                       BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
@@ -564,33 +565,54 @@ def build_mlp(
     return nn.Sequential(*layers)
 
 
-class HCXVisionInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
+class HCXVisionInputProcessor(BaseMultimodalDummyInputsBuilder,
+                              BaseMultimodalInputProcessor):
 
     def __init__(self,
                  model_path: str,
-                 model_config: PretrainedConfig,
+                 config: PretrainedConfig,
                  tokenizer: AutoTokenizer,
                  trust_remote_code: bool = True):
-
-        self.pretrained_config = model_config
-        self.tokenizer = tokenizer
-        self.use_fast = True
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_path,
-                trust_remote_code=trust_remote_code,
-                use_fast=self.use_fast)
-        self.processor = AutoProcessor.from_pretrained(
+        super().__init__()
+        self._config = config
+        self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
+            model_path,
+            trust_remote_code=trust_remote_code,
+            use_fast=self.use_fast)
+        self._processor = AutoProcessor.from_pretrained(
             model_path,
             trust_remote_code=trust_remote_code,
             use_fast=self.use_fast)
-        self.tllm_multimodal_token_id = self.pretrained_config.language_config[
+        self._model_path = model_path
+        self._dtype = self.config.torch_dtype
+
+        self.tllm_multimodal_token_id = self.config.language_config[
             "vocab_size"] + 1
         self.vision_query_lengths = None
         self._vision_query_generator = None
 
+    @property
+    def config(self) -> PretrainedConfig:
+        return self._config
+
+    @property
+    def tokenizer(self) -> AutoTokenizer:
+        return self._tokenizer
+
+    @property
+    def model_path(self) -> str:
+        return self._model_path
+
+    @property
+    def processor(self) -> AutoProcessor:
+        return self._processor
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self._dtype
+
     def get_vocab_size(self):
-        return self.pretrained_config.language_config["vocab_size"]
+        return self.config.language_config["vocab_size"]
 
     def get_num_tokens_per_image(
         self,
@@ -656,8 +678,7 @@ def _post_process(self,
         vision_query_lengths = preprocessed_image.get("vision_query_lengths",
                                                       None)
         non_vision_query_lengths = determine_non_vision_query_lengths(
-            input_ids, self.tokenizer.pad_token_id,
-            self.pretrained_config.img_start_id)
+            input_ids, self.tokenizer.pad_token_id, self.config.img_start_id)
         batch_size = input_ids.size(0)
 
         len_inputs_embeds = max([
@@ -666,19 +687,18 @@ def _post_process(self,
             non_vision_query_lengths, vision_query_lengths)
         ])
 
-        len_inputs_embeds = min(self.pretrained_config.decoder_max_length,
+        len_inputs_embeds = min(self.config.decoder_max_length,
                                 len_inputs_embeds)
 
-        image_cnts = (input_ids == self.pretrained_config.img_start_id).sum(
-            dim=1).tolist()
+        image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist()
 
         fused_input_ids = torch.zeros([batch_size, len_inputs_embeds],
                                       dtype=input_ids.dtype)
         for batch_idx, sample in enumerate(input_ids):
            non_vision_query_length = non_vision_query_lengths[batch_idx]
            sample = sample[:non_vision_query_length + image_cnts[batch_idx]]
 
-            mask = (sample == self.pretrained_config.img_start_id)
+            mask = (sample == self.config.img_start_id)
             img_start_ids = mask.nonzero()
             input_start, temp_start = 0, 0
 
@@ -779,32 +799,30 @@ class HCXVisionModel(nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         super().__init__()
         self.model_config = model_config
-        self.pretrained_config = model_config.pretrained_config
+        self.config = model_config.pretrained_config
         siglip_model_config = copy.deepcopy(self.model_config)
         siglip_model_config.pretrained_config = self.model_config.pretrained_config.vision_config
         self.visual_token_idx = 0 if "siglip" in self.model_config.pretrained_config.vision_config.model_type else 1
         self.dtype = self.model_config.pretrained_config.vision_config.torch_dtype
         self.vision_model = SiglipVisionModel(siglip_model_config).to(
             self.dtype)
         self.mm_projector = HCXVisionCAbstractor(
-            num_queries=self.pretrained_config.num_queries_vis_abstractor,
-            num_input_tokens=(
-                self.pretrained_config.vision_config.image_size //
-                self.pretrained_config.vision_config.patch_size)**2,
-            encoder_hidden_size=self.pretrained_config.vision_config.
-            hidden_size,
-            hidden_size=self.pretrained_config.vision_config.hidden_size,
-            output_hidden_size=self.pretrained_config.hidden_size,
-            pos_emb=self.pretrained_config.proj_pos_emb,
-            prenorm=self.pretrained_config.proj_prenorm,
+            num_queries=self.config.num_queries_vis_abstractor,
+            num_input_tokens=(self.config.vision_config.image_size //
+                              self.config.vision_config.patch_size)**2,
+            encoder_hidden_size=self.config.vision_config.hidden_size,
+            hidden_size=self.config.vision_config.hidden_size,
+            output_hidden_size=self.config.hidden_size,
+            pos_emb=self.config.proj_pos_emb,
+            prenorm=self.config.proj_prenorm,
         ).to(self.dtype)
         self.image_newline = nn.Parameter(torch.empty(
-            self.pretrained_config.hidden_size, ),
+            self.config.hidden_size, ),
                                           requires_grad=False).to(self.dtype)
 
-        self.unpad = self.pretrained_config.unpad
-        self.use_nth_layer = self.pretrained_config.use_nth_layer
-        self.anyres = self.pretrained_config.anyres
+        self.unpad = self.config.unpad
+        self.use_nth_layer = self.config.use_nth_layer
+        self.anyres = self.config.anyres
         self.possible_resolutions = self._init_possible_resolutions()
         self.post_config()
 
@@ -814,18 +832,18 @@ def post_config(self):
 
     def _init_possible_resolutions(self):
         possible_resolutions = []
-        if self.pretrained_config.anyres:
-            assert self.pretrained_config.max_num_grids > 0
-            for i in range(1, self.pretrained_config.max_num_grids + 1):
-                for j in range(1, self.pretrained_config.max_num_grids + 1):
-                    if i == 1 and j == 1 and not self.pretrained_config.use_1x1_grid:
+        if self.config.anyres:
+            assert self.config.max_num_grids > 0
+            for i in range(1, self.config.max_num_grids + 1):
+                for j in range(1, self.config.max_num_grids + 1):
+                    if i == 1 and j == 1 and not self.config.use_1x1_grid:
                         continue
-                    if i * j <= self.pretrained_config.max_num_grids:
+                    if i * j <= self.config.max_num_grids:
                         possible_resolutions.append([i, j])
 
             possible_resolutions = [[
-                ys * self.pretrained_config.vision_config.image_size,
-                xs * self.pretrained_config.vision_config.image_size
+                ys * self.config.vision_config.image_size,
+                xs * self.config.vision_config.image_size
             ] for ys, xs in possible_resolutions]
         return possible_resolutions
 
tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 47 additions & 18 deletions
@@ -5,8 +5,8 @@
 import torch
 from PIL.Image import Image
 from torch import nn
-from transformers import (AutoProcessor, Llama4Config, Llama4VisionModel,
-                          LlamaConfig)
+from transformers import (AutoProcessor, AutoTokenizer, Llama4Config,
+                          Llama4VisionModel, LlamaConfig, PretrainedConfig)
 from transformers.modeling_utils import load_sharded_checkpoint
 from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
@@ -21,7 +21,8 @@
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
 
-from ...inputs import (ExtraProcessedInputs, InputProcessor,
+from ...inputs import (BaseMultimodalDummyInputsBuilder,
+                       BaseMultimodalInputProcessor, ExtraProcessedInputs,
                        MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
@@ -1042,26 +1043,54 @@ def forward(self, multimodal_params: List[MultimodalParams]):
         return [image_features]
 
 
-class Llama4InputProcessor(InputProcessor):
+from transformers import AutoTokenizer, PretrainedConfig
+
+
+class Llama4InputProcessor(BaseMultimodalInputProcessor,
+                           BaseMultimodalDummyInputsBuilder):
 
     def __init__(self,
-                 model_path,
-                 model_config,
-                 tokenizer,
+                 model_path: str,
+                 config: PretrainedConfig,
+                 tokenizer: AutoTokenizer,
                  trust_remote_code: bool = True):
-        self.use_fast = True
-        self.processor = AutoProcessor.from_pretrained(
+        super().__init__()
+        self._config = config
+        self._dtype = self._config.torch_dtype
+        self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
+            model_path)
+        self._model_path = model_path
+        self._processor = AutoProcessor.from_pretrained(
             model_path,
-            trust_remote_code=trust_remote_code,
-            use_fast=self.use_fast)
-        self.model_config = model_config
-        self.tokenizer = tokenizer
-        self.vocab_size = model_config.text_config.vocab_size
-        self.image_token_index = model_config.image_token_index
+            use_fast=self.use_fast,
+            trust_remote_code=trust_remote_code)
+
+        self.vocab_size = self.config.text_config.vocab_size
+        self.image_token_index = self.config.image_token_index
         self.fake_image_token = self.processor.fake_image_token
         self.image_token = self.processor.img_patch_token
-        self.image_token_start_index = self.model_config.boi_token_index
-        self.image_token_end_index = self.model_config.eoi_token_index
+        self.image_token_start_index = self.config.boi_token_index
+        self.image_token_end_index = self.config.eoi_token_index
+
+    @property
+    def config(self) -> PretrainedConfig:
+        return self._config
+
+    @property
+    def tokenizer(self) -> AutoTokenizer:
+        return self._tokenizer
+
+    @property
+    def model_path(self) -> str:
+        return self._model_path
+
+    @property
+    def processor(self) -> AutoProcessor:
+        return self._processor
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self._dtype
 
     def attach_multimodal_embeddings(
             self, inputs: TextPrompt, multimodal_embedding: Dict[str,
@@ -1121,7 +1150,7 @@ def attach_multimodal_embeddings(
                 f"Missing required key in multimodal embedding: {e}")
 
         # Validate embedding dimensions
-        model_hidden_size = self.model_config.text_config.hidden_size
+        model_hidden_size = self.config.text_config.hidden_size
         for i, embedding in enumerate(mm_embeddings):
             if embedding.shape[-1] != model_hidden_size:
                 raise ValueError(
