
Commit f48968b

[TRTLLM-6928][fix] Refactor multimodal unittest (#8453)
Signed-off-by: yechank <[email protected]>
1 parent 14bc857 commit f48968b

22 files changed, +1244 -808 lines changed

tensorrt_llm/_torch/models/checkpoints/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 from .hf.config_loader import HfConfigLoader
 from .hf.gemma3_weight_mapper import Gemma3HfWeightMapper
 from .hf.llama4_weight_mapper import Llama4HfWeightMapper
+from .hf.llava_next_weight_mapper import LlavaNextHfWeightMapper
 from .hf.mixtral_weight_mapper import MixtralHfWeightMapper
 from .hf.nemotron_h_weight_mapper import NemotronHHfWeightMapper
 from .hf.qwen2_moe_weight_mapper import Qwen2MoeHfWeightMapper
@@ -17,5 +18,5 @@
     "BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper",
     "Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper",
     "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
-    "Qwen3NextHfWeightMapper"
+    "Qwen3NextHfWeightMapper", "LlavaNextHfWeightMapper"
 ]
tensorrt_llm/_torch/models/checkpoints/hf/llava_next_weight_mapper.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
+from tensorrt_llm._torch.models.modeling_utils import register_mapper
+
+
+@register_mapper("HF", "LlavaNextForConditionalGeneration")
+class LlavaNextHfWeightMapper(HfWeightMapper):
+
+    def preprocess_weights(self, weights: dict) -> dict:
+        transformed_weights = {}
+        for key, value in weights.items():
+            if key.startswith("model."):
+                new_key = key[len("model.") :]
+                transformed_weights[new_key] = value
+            else:
+                transformed_weights[key] = value
+        return transformed_weights
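For context, the net effect of preprocess_weights is to drop the leading "model." segment from checkpoint keys while leaving keys such as lm_head.* untouched. Below is a minimal, self-contained sketch of that remapping; the checkpoint keys are invented for illustration and are not taken from a real LLaVA-NeXT checkpoint.

# Illustrative sketch only: reproduces the key remapping performed by
# LlavaNextHfWeightMapper.preprocess_weights, using made-up keys.
def strip_model_prefix(weights: dict) -> dict:
    remapped = {}
    for key, value in weights.items():
        if key.startswith("model."):
            remapped[key[len("model."):]] = value
        else:
            remapped[key] = value
    return remapped


fake_weights = {
    "model.vision_tower.embeddings.patch_embedding.weight": None,
    "model.language_model.embed_tokens.weight": None,
    "lm_head.weight": None,
}
print(list(strip_model_prefix(fake_weights)))
# ['vision_tower.embeddings.patch_embedding.weight',
#  'language_model.embed_tokens.weight', 'lm_head.weight']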

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 62 additions & 80 deletions
@@ -5,21 +5,23 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import (AutoConfig, AutoModel, AutoProcessor, AutoTokenizer,
-                          LlavaNextConfig, PretrainedConfig, PreTrainedModel)
-from transformers.modeling_utils import load_sharded_checkpoint
+from transformers import (AutoProcessor, AutoTokenizer, LlavaNextConfig,
+                          PretrainedConfig, PreTrainedModel)
 from transformers.models.llava_next.modeling_llava_next import (
     LlavaNextMultiModalProjector, get_anyres_image_grid_shape,
     image_size_to_num_patches, unpad_image)
 
+from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
+    BaseWeightMapper
+from tensorrt_llm._torch.models.checkpoints.hf.llava_next_weight_mapper import \
+    LlavaNextHfWeightMapper
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
 from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
                        InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor,
                        support_multimodal_disaggregated)
-from ...llmapi.utils import download_hf_model
 from ...logger import logger
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
@@ -28,8 +30,7 @@
 from .modeling_clip import CLIPVisionModel
 from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
                                         get_multimodal_embeddings)
-from .modeling_utils import (filter_weights, register_auto_model,
-                             register_vision_encoder)
+from .modeling_utils import register_auto_model, register_vision_encoder
 
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
 
@@ -295,62 +296,36 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__()
         self.model_config = model_config
         self.pretrained_config = model_config.pretrained_config
-        # TODO: use config.mapping.get_local_rank() instead
-        self.device = f"cuda:{torch.cuda.current_device()}"
-        model_path = self.pretrained_config._name_or_path
 
-        # Determine the actual local path for model files
-        if os.path.isdir(model_path):
-            local_model_path = model_path
-        else:
-            local_model_path = download_hf_model(model_path)
-
-        # Partially load the model to reduce memory usage(Vision tower and multi-modal projector)
-        hf_model_config = AutoConfig.from_pretrained(local_model_path)
-        self.dtype = hf_model_config.text_config.torch_dtype
-        module_dict = nn.ModuleDict({
-            "vision_tower":
-            AutoModel.from_config(hf_model_config.vision_config),
-            "multi_modal_projector":
-            LlavaNextMultiModalProjector(hf_model_config)
-        })
-        module_dict.register_parameter(
-            "image_newline",
-            nn.Parameter(torch.empty(hf_model_config.text_config.hidden_size)))
-
-        missing_keys, _ = load_sharded_checkpoint(module_dict,
-                                                  local_model_path,
-                                                  strict=False)
-        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
-        hf_vision_tower = module_dict["vision_tower"].to(self.dtype)
-        hf_mm_projector = module_dict["multi_modal_projector"].to(
-            self.dtype).to(self.device)
-        hf_image_newline = module_dict.image_newline.to(self.dtype).to(
-            self.device)
-
-        # For A100 GPU, fallback to HF vision tower due to accuracy issue in TRT-LLM CLIPAttention
-        # Otherwise, use TRTLLM vision tower(CLIPVisionModel)
-        prop = torch.cuda.get_device_properties(0)
-        sm_version = prop.major * 10 + prop.minor
-        self.use_hf_vision_tower = sm_version == 80
-        if self.use_hf_vision_tower:
-            self.vision_tower = hf_vision_tower.to(self.device)
-        else:
-            vision_model_config = ModelConfig(
-                pretrained_config=self.pretrained_config.vision_config,
-                attn_backend="TRTLLM")
-            self.vision_tower = CLIPVisionModel(vision_model_config).to(
-                self.device).to(self.dtype)
-            self.vision_tower.load_weights(hf_vision_tower.state_dict())
-
-        # Use HF multi-modal projector
-        self.mm_projector = hf_mm_projector
-        self.image_newline = hf_image_newline
+        clip_model_config = copy.deepcopy(self.model_config)
+        clip_model_config.pretrained_config = self.model_config.pretrained_config.vision_config
+        self.dtype = self.model_config.pretrained_config.text_config.torch_dtype
+        self.vision_model = CLIPVisionModel(clip_model_config).to(self.dtype)
+        self.mm_projector = LlavaNextMultiModalProjector(
+            self.pretrained_config).to(self.dtype)
+        self.image_newline = nn.Parameter(torch.empty(
+            self.pretrained_config.text_config.hidden_size),
+                                          requires_grad=False).to(self.dtype)
         self.vision_feature_select_strategy = getattr(
             self.pretrained_config, "vision_feature_select_strategy", "default")
-
         self.post_config()
 
+    def load_weights(self, weights):
+
+        def filter_weights(prefix, weights: Dict):
+            result = {}
+            for key, weight in weights.items():
+                if key.startswith(prefix):
+                    new_key = key[len(prefix):]
+                    result[new_key] = weight
+            return result
+
+        visual_model_weights = filter_weights("vision_tower.", weights)
+        self.vision_model.load_weights(visual_model_weights)
+        mm_projector_weights = filter_weights("multi_modal_projector.", weights)
+        self.mm_projector.load_state_dict(mm_projector_weights, strict=True)
+        self.image_newline.data.copy_(weights["image_newline"])
+
     def post_config(self):
         self.config = self.pretrained_config.vision_config
 
@@ -464,7 +439,6 @@ def forward(self, multimodal_params: List[MultimodalParams]):
             for multimodal_param in multimodal_params
         ]
         pixel_values = self._pad_for_batching(pixel_values)
-
        pixel_values = torch.cat(pixel_values, dim=0)
         image_sizes = torch.cat(image_sizes, dim=0)
 
@@ -484,23 +458,18 @@ def forward(self, multimodal_params: List[MultimodalParams]):
         ]
         pixel_values = torch.cat(_pixel_values_list, dim=0)
 
-        if self.use_hf_vision_tower:
-            image_features = self.vision_tower(
-                pixel_values, output_hidden_states=True).hidden_states
-        else:
-            attn_metadata = self.vision_tower.prepare_attn_metadata(
-                pixel_values.shape[0])
-            image_features = self.vision_tower(
-                pixel_values,
-                attn_metadata=attn_metadata,
-            )
+        attn_metadata = self.vision_model.prepare_attn_metadata(
+            pixel_values.shape[0])
+        image_features = self.vision_model(
+            pixel_values,
+            attn_metadata=attn_metadata,
+        )
         selected_image_feature = image_features[-2][:, 1:]
         image_features = self.mm_projector(selected_image_feature)
-
         image_features = torch.split(image_features, image_num_patches, dim=0)
 
-        # NOTE: 'pack_image_features' is directly copied from the HF's code
-        image_features, feature_lens = self.pack_image_features(
+        # NOTE: 'pack_image_features' is from the HF's code
+        image_features, _ = self.pack_image_features(
             image_features,
             image_sizes,
             vision_feature_select_strategy=self.vision_feature_select_strategy,
@@ -526,6 +495,7 @@ class LlavaNextModel(PreTrainedModel):
     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs) -> None:
         config = model_config.pretrained_config
+        self._supports_sdpa = True
         super().__init__(config)
         if hasattr(self, "llm"):
             return
@@ -543,16 +513,29 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         self.llm = AutoModelForCausalLM.from_config(llm_model_config)
 
         self.model_config = model_config
-        self.model_dtype = getattr(config.text_config, "torch_dtype",
-                                   torch.float16)
-        logger.info(f"{self.dtype=} {self.model_dtype=}")
-
         self.post_config()
-        self.is_loaded = True
 
-    def load_weights(self, weights):
-        weights = filter_weights("language_model", weights)
-        self.llm.load_weights(weights)
+    def load_weights(self, weights, weight_mapper: BaseWeightMapper):
+        if isinstance(weight_mapper, LlavaNextHfWeightMapper):
+            weights = weight_mapper.preprocess_weights(weights)
+
+        self.mm_encoder.load_weights(weights)
+
+        def filter_weights(weights: Dict):
+            transformed_weights = {}
+            for key, weight in weights.items():
+                if key.startswith("language_model."):
+                    if isinstance(weight_mapper, LlavaNextHfWeightMapper):
+                        new_key = "model." + key[len("language_model."):]
+                    else:
+                        new_key = key[len("language_model."):]
+                    transformed_weights[new_key] = weight
+                elif key.startswith("lm_head."):
+                    transformed_weights[key] = weight
+            return transformed_weights
+
+        language_model_weights = filter_weights(weights)
+        self.llm.load_weights(language_model_weights)
 
     def post_config(self):
         self.config = self.llm.config
@@ -590,7 +573,6 @@ def forward(
             mm_embeds, multimodal_params[:num_context_requests])
         input_ids, inputs_embeds = fuse_input_embeds(
             self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs)
-
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
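Taken together, the refactor routes all weights through in-memory dicts instead of reloading HF shards from disk: LlavaNextHfWeightMapper first strips the top-level "model." prefix, the vision encoder consumes the vision_tower. / multi_modal_projector. / image_newline entries, and the remaining language_model. keys are re-prefixed for the TRT-LLM decoder while lm_head.* passes through. The sketch below approximates that routing on invented keys; it is not the actual implementation, and the key names are illustrative only.

# Illustrative sketch of the post-refactor weight routing in
# modeling_llava_next.py; the keys are invented and nothing is really loaded.
def route_llava_next_weights(weights: dict):
    # Step 1: the mapper strips the leading "model." (LlavaNextHfWeightMapper).
    weights = {
        (k[len("model."):] if k.startswith("model.") else k): v
        for k, v in weights.items()
    }
    # Step 2: the vision encoder picks up its own prefixes
    # (see LlavaNextVisionModel.load_weights above).
    vision = {k[len("vision_tower."):]: v for k, v in weights.items()
              if k.startswith("vision_tower.")}
    projector = {k[len("multi_modal_projector."):]: v for k, v in weights.items()
                 if k.startswith("multi_modal_projector.")}
    # Step 3: language-model keys are re-prefixed with "model." for the
    # TRT-LLM decoder; lm_head.* is kept as-is.
    llm = {}
    for k, v in weights.items():
        if k.startswith("language_model."):
            llm["model." + k[len("language_model."):]] = v
        elif k.startswith("lm_head."):
            llm[k] = v
    return vision, projector, llm


demo = {
    "model.vision_tower.encoder.layers.0.mlp.fc1.weight": 0,
    "model.multi_modal_projector.linear_1.weight": 1,
    "model.language_model.layers.0.self_attn.q_proj.weight": 2,
    "lm_head.weight": 3,
}
vision, projector, llm = route_llava_next_weights(demo)
print(sorted(llm))
# ['lm_head.weight', 'model.layers.0.self_attn.q_proj.weight']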

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 5 additions & 5 deletions
@@ -840,8 +840,13 @@ def __init__(
         model_config.pretrained_config.disable_fuse_rope = disabble_fuse_rope
         model_config.pretrained_config.rope_scaling['type'] = 'mrope'
         config = model_config.pretrained_config
+
+        self._supports_sdpa = True
         super().__init__(config)
 
+        if not disabble_fuse_rope:
+            self.init_mrope_embedding(model_config)
+
         self.model_config = model_config
         self.config = model_config.pretrained_config
 
@@ -947,14 +952,10 @@ def forward(
         VLM forward logic with inflight batching support.
         """
         num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
-        logger.debug(
-            f"num_context_requests: {num_context_requests}, num_generation_requests: {num_generation_requests}"
-        )
 
         multimodal_params = kwargs.get("multimodal_params", [])
         mm_embeds = []
         mrope_config = {}
-
         if len(multimodal_params) > 0:
             if not DISAGG:
                 mm_embeds = get_multimodal_embeddings(
@@ -965,7 +966,6 @@ def forward(
                 "Qwen2VLModel does not support disaggregated inference yet. Please unset "
                 f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
             )
-
         mm_embeds = find_input_mm_embeds(
             mm_embeds, multimodal_params[:num_context_requests])
         if not self.model_config.pretrained_config.disable_fuse_rope:

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 2 deletions
@@ -279,7 +279,6 @@ def create_py_executor(
     if mm_encoder_only:
         # TODO(qijun): clean up pytorch_backend_config later
         pytorch_backend_config.mm_encoder_only = True
-        pytorch_backend_config.load_format = LoadFormat.VISION_ONLY
         # Disable overlap scheduler for multimodal encoder-only mode
         logger.warning(
             "Disabling overlap scheduler for multimodal encoder-only mode. "
@@ -288,7 +287,6 @@ def create_py_executor(
         pytorch_backend_config.disable_overlap_scheduler = True
 
         llm_args.mm_encoder_only = True
-        llm_args.load_format = LoadFormat.VISION_ONLY
         llm_args.disable_overlap_scheduler = True
 
     mapping = _get_mapping(llm_args.parallel_config.to_mapping())

tensorrt_llm/evaluate/lm_eval.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def apply_chat_template(self,
         output = trtllm_apply_chat_template(
             model_type=self.model_type,
             tokenizer=self.llm.tokenizer,
-            processor=self.llm.input_processor.processor,
+            processor=getattr(self.llm.input_processor, 'processor', None),
             conversation=chat_history,
             add_generation_prompt=add_generation_prompt,
             mm_placeholder_counts=mm_placeholder_counts,
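The getattr change appears to make the chat-template path tolerant of input processors that do not expose a processor attribute. A condensed illustration of the difference, using a made-up class:

# Made-up class, only to show why getattr(..., 'processor', None) is safer
# than direct attribute access when no HF processor is attached.
class DummyInputProcessor:
    """An input processor that wraps only a tokenizer, no `processor`."""


ip = DummyInputProcessor()
# ip.processor                                # would raise AttributeError
processor = getattr(ip, 'processor', None)    # -> None, caller can fall back
print(processor)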

tensorrt_llm/inputs/utils.py

Lines changed: 3 additions & 3 deletions
@@ -545,7 +545,7 @@ def apply_chat_template(
     processor: ProcessorMixin,
     conversation: list[ConversationMessage],
     add_generation_prompt: bool,
-    mm_placeholder_counts: dict[str, int],
+    mm_placeholder_counts: list[dict[str, int]],
     tools: Optional[list[dict[str, Any]]] = None,
     documents: Optional[list[dict[str, str]]] = None,
     chat_template: Optional[str] = None,
@@ -567,7 +567,7 @@ def apply_chat_template(
     if model_type in PLACEHOLDER_EXCEPTIONS:
         # flattened content do not work for these models, so go back to other formats as needed
         conversation = handle_placeholder_exceptions(model_type, conversation,
-                                                     [mm_placeholder_counts])
+                                                     mm_placeholder_counts)
 
     return tokenizer.apply_chat_template(
         conversation=conversation,
@@ -732,7 +732,7 @@ def convert_to_conversation_message(
         processor=processor,
         conversation=[conv],
         add_generation_prompt=True,
-        mm_placeholder_counts=mm_placeholder_counts)
+        mm_placeholder_counts=[mm_placeholder_counts])
     input = {"prompt": prompt}
     if mm_placeholder_counts:
         if mm_embeddings is not None:
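With this change, apply_chat_template takes one placeholder-count dict per message rather than a single dict, so the single-message caller above now wraps its dict in a list. A small sketch of the new shape, with invented placeholder names and counts:

# Illustrative only: the new expected shape of mm_placeholder_counts,
# one {placeholder: count} dict per message in the conversation.
single_message_counts = {"<image>": 2}           # e.g. two images in one message
mm_placeholder_counts = [single_message_counts]  # list[dict[str, int]]

# A two-message conversation would pass one dict per message:
mm_placeholder_counts_multi = [{"<image>": 1}, {"<video>": 1}]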
Lines changed: 8 additions & 0 deletions
@@ -1,4 +1,12 @@
 Qwen/Qwen2-VL-7B-Instruct:
   - accuracy: 48.44
+Qwen/Qwen2.5-VL-7B-Instruct:
+  - accuracy: 51.22
 nvidia/Nano-v2-VLM:
   - accuracy: 43.78
+llava-hf/llava-v1.6-mistral-7b:
+  - accuracy: 35.33
+Efficient-Large-Model/NVILA-8B:
+  - accuracy: 47.77
+Efficient-Large-Model/VILA1.5-3b:
+  - accuracy: 32.33
