
Commit 362a827

feat: llama4 input processor (NVIDIA#3383)
Signed-off-by: Alexandre Milesi <[email protected]>
Signed-off-by: Haohang Huang <[email protected]>
Co-authored-by: Alexandre Milesi <[email protected]>
Co-authored-by: Haohang Huang <[email protected]>
1 parent d747223 commit 362a827

13 files changed: +147 -33 lines changed

examples/pytorch/README.md

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
 | `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
 | `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
 | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
-| `Llama4ForConditionalGeneration` | Llama 4 Scout/Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L |
+| `Llama4ForConditionalGeneration` | Llama 4 Scout/Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L + V |
 | `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` | L |
 | `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
 | `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
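With vision now listed for Llama 4, an image+text prompt is expected to flow through the input processor registered in modeling_llama.py further down in this commit. A rough sketch of the call pattern from the LLM API side, inferred from the TextPrompt fields the processor reads ("prompt" and "multi_modal_data"); the entry point, argument names, and backend setup here are assumptions, not commands taken from this commit:

    from PIL import Image
    from tensorrt_llm import LLM

    # Hypothetical checkpoint and image; backend/engine configuration is omitted.
    llm = LLM(model="meta-llama/Llama-4-Scout-17B-16E-Instruct")
    image = Image.open("example.png")

    # The registered Llama4InputProcessor consumes a TextPrompt-style dict with
    # "prompt" and "multi_modal_data" keys, as shown in modeling_llama.py below.
    outputs = llm.generate([{
        "prompt": "Describe this image.",
        "multi_modal_data": {"image": [image]},
    }])
    print(outputs[0].outputs[0].text)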

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ torchvision
 nvidia-modelopt[torch]~=0.27.0
 nvidia-nccl-cu12
 nvidia-cuda-nvrtc-cu12
-transformers==4.51.0
+transformers~=4.51.1
 pydantic>=2.9.1
 pillow==10.3.0
 wheel<=0.45.1
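For clarity, the compatible-release operator pins only the minor series: per PEP 440 the new specifier is equivalent to

    transformers>=4.51.1,<4.52.0

so 4.51.x patch releases are picked up automatically, whereas the previous exact pin excluded even 4.51.1.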

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 88 additions & 6 deletions
@@ -1,15 +1,22 @@
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torch
+from PIL.Image import Image
 from torch import nn
-from transformers import Llama4Config, LlamaConfig
+from transformers import (AutoProcessor, Llama4Config, Llama4VisionModel,
+                          LlamaConfig)
+from transformers.modeling_utils import load_sharded_checkpoint
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector

 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              AllReduceParams, DeepseekAllReduce)
 from tensorrt_llm._torch.pipeline_interface import PipelineInterface
 from tensorrt_llm.functional import PositionEmbeddingType

+from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+                       register_input_processor)
+from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
@@ -26,6 +33,7 @@
 from ..modules.rms_norm import RMSNorm
 from ..modules.rotary_embedding import RotaryEmbedding
 from ..speculative import Eagle3SpecMetadata, SpecMetadata
+from .modeling_multimodal_utils import fuse_input_embeds
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              EagerFusionConfig, MissingLayer,
                              register_auto_model, support_pp,
@@ -829,15 +837,13 @@ def __init__(
             vocab_size=model_config.pretrained_config.vocab_size)


-@register_auto_model("Llama4ForConditionalGeneration")
-class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
-                                                             Llama4Config]):
+@register_auto_model("Llama4ForCausalLM")
+class Llama4ForCausalLM(DecoderModelForCausalLM[LlamaModel, Llama4Config]):

     def __init__(
         self,
         model_config: ModelConfig[Llama4Config],
     ):
-        # TODO: figure out a better way to handle multimodality.
         model_config = copy.copy(model_config)
         architectures = model_config.pretrained_config.architectures
         model_config.pretrained_config = model_config.pretrained_config.text_config
@@ -876,6 +882,82 @@ def load_weights(self, weights: Dict):
                         idx + 1].input_layernorm


+class Llama4InputProcessor(InputProcessor):
+
+    def __init__(self, model_path, model_config, tokenizer):
+        self.processor = AutoProcessor.from_pretrained(model_path,
+                                                       use_fast=True)
+        self.model_config = model_config
+        self.tokenizer = tokenizer
+        self.vocab_size = model_config.text_config.vocab_size
+        self.image_token_index = model_config.image_token_index
+
+        self.encoder = nn.ModuleDict({
+            "vision_model":
+            Llama4VisionModel(model_config.vision_config),
+            "multi_modal_projector":
+            Llama4MultiModalProjector(model_config)
+        }).cuda()
+        load_sharded_checkpoint(self.encoder, model_path, strict=False)
+
+    @torch.inference_mode()
+    def __call__(
+        self, inputs: TextPrompt, sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        text_prompt, mm_data = inputs.get("prompt"), inputs.get(
+            "multi_modal_data")
+        images, do_rescale = None, True
+
+        if mm_data and mm_data.get("image"):
+            images = mm_data["image"]
+            img_type = type(mm_data["image"][0])
+            do_rescale = (img_type == Image)
+            assert all(isinstance(img, img_type) for img in mm_data["image"])
+
+        # preprocess images and insert image tokens
+        processed = self.processor(text=text_prompt,
+                                   images=images,
+                                   return_tensors="pt",
+                                   device="cuda",
+                                   do_rescale=do_rescale,
+                                   add_special_tokens=False)
+        if images:
+            token_ids, pixel_values = processed["input_ids"].squeeze(
+            ), processed["pixel_values"]
+            mm_embeds = self.encoder.vision_model(
+                pixel_values.float().cuda()).last_hidden_state.flatten(0, 1)
+            mm_embeds = self.encoder.multi_modal_projector(mm_embeds)
+            # for fuse_input_embeds
+            token_ids[token_ids == self.image_token_index] = self.vocab_size + 1
+            return token_ids.tolist(), {
+                "prompt_tuning_config": [mm_embeds, None, None]
+            }
+        else:
+            return processed["input_ids"].squeeze().tolist(), {}
+
+
+@register_auto_model("Llama4ForConditionalGeneration")
+@register_input_processor(Llama4InputProcessor)
+class Llama4ForConditionalGeneration(Llama4ForCausalLM):
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        attn_metadata: AttentionMetadata,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        return_context_logits: Optional[bool] = False,
+        **kwargs,
+    ) -> torch.Tensor:
+        mm_embed = kwargs.get("multi_modal_data", [])
+        input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
+                                                     input_ids, mm_embed)
+        logits = super().forward(attn_metadata, input_ids, position_ids,
+                                 inputs_embeds, return_context_logits)
+        return logits
+
+
 @register_auto_model("MistralForCausalLM")
 class MistralForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
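The contract between Llama4InputProcessor and the decoder is the part worth calling out: image placeholder tokens are rewritten to an ID just past the text vocabulary, and the projected vision features ride along as a prompt-tuning table that fuse_input_embeds later splices in. A minimal, self-contained sketch of that remapping with toy sizes (plain PyTorch, not the TRT-LLM runtime; values are illustrative only):

    import torch

    vocab_size = 100          # toy stand-in for text_config.vocab_size
    image_token_index = 7     # toy stand-in for model_config.image_token_index

    # Token IDs as the HF processor would emit them: three image-patch placeholders.
    token_ids = torch.tensor([1, 5, 7, 7, 7, 9])
    mm_embeds = torch.randn(3, 16)   # projected vision features, one row per placeholder

    # Rewrite placeholders to an out-of-vocab ID so the fused-embedding step can
    # separate text tokens from multimodal tokens by comparing against vocab_size.
    token_ids[token_ids == image_token_index] = vocab_size + 1

    extra_inputs = {"prompt_tuning_config": [mm_embeds, None, None]}
    print(token_ids.tolist())        # [1, 5, 101, 101, 101, 9]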

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 2 additions & 1 deletion
@@ -221,7 +221,8 @@ def forward(
             mm_embed
         ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"

-        input_ids, inputs_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, inputs_embeds = fuse_input_embeds(
+            self.llm.model.embed_tokens, input_ids, mm_embed)
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 16 additions & 10 deletions
@@ -25,9 +25,11 @@
 from PIL import Image
 from torchvision.transforms import Normalize, Resize, ToTensor

+from tensorrt_llm._torch.modules.embedding import Embedding
+

 def fuse_input_embeds(
-    model,
+    embedding_layer: Embedding,
     input_ids: torch.LongTensor,
     mm_embeds: List[torch.Tensor],
 ) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
@@ -44,20 +46,24 @@ def fuse_input_embeds(
     if len(mm_embeds) == 0:
         return input_ids, None

+    vocab_size = embedding_layer.num_embeddings
     mm_embed = torch.cat(mm_embeds, dim=0)
+
+    text_token_indices = torch.where(input_ids < vocab_size)[0]
+    mm_token_indices = torch.where(input_ids >= vocab_size)[0]
+
+    text_embed = embedding_layer(input_ids[text_token_indices])
     input_embeds = torch.empty(input_ids.shape[0],
                                mm_embed.shape[-1],
-                               device=input_ids.device,
-                               dtype=model.model_dtype)
-
-    text_token_indices = torch.where(input_ids < model.vocab_size)[0]
-    mm_token_indices = torch.where(input_ids >= model.vocab_size)[0]
+                               device=text_embed.device,
+                               dtype=text_embed.dtype)

-    text_embed = model.llm.model.embed_tokens(input_ids[text_token_indices])
-    input_embeds[text_token_indices, :] = text_embed.to(model.model_dtype)
-    input_embeds[mm_token_indices, :] = mm_embed.to(model.model_dtype)
+    input_embeds[text_token_indices, :] = text_embed.to(
+        dtype=input_embeds.dtype, device=input_embeds.device)
+    input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype,
+                                                    device=input_embeds.device)

-    return None, input_embeds.to(model.dtype)
+    return None, input_embeds


 #region VILA utils
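Since fuse_input_embeds now takes only the embedding layer and reads its num_embeddings attribute (exposed by the Embedding change further down), the splice can be reproduced in isolation. A toy sketch with torch.nn.Embedding standing in for TRT-LLM's Embedding module:

    import torch
    from torch import nn

    # Stand-ins; the real call sites pass self.llm.model.embed_tokens and the
    # embeddings produced by the model's input processor.
    embedding_layer = nn.Embedding(100, 16)
    input_ids = torch.tensor([1, 5, 101, 101, 101, 9])   # IDs >= vocab size mark multimodal slots
    mm_embeds = [torch.randn(3, 16)]                     # one projected vision row per slot

    with torch.inference_mode():
        vocab_size = embedding_layer.num_embeddings
        mm_embed = torch.cat(mm_embeds, dim=0)

        text_idx = torch.where(input_ids < vocab_size)[0]
        mm_idx = torch.where(input_ids >= vocab_size)[0]

        text_embed = embedding_layer(input_ids[text_idx])
        fused = torch.empty(input_ids.shape[0], mm_embed.shape[-1],
                            device=text_embed.device, dtype=text_embed.dtype)
        fused[text_idx] = text_embed
        fused[mm_idx] = mm_embed.to(dtype=fused.dtype, device=fused.device)

    print(fused.shape)  # torch.Size([6, 16]) -- one embedding row per token, text and image interleaved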

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 2 additions & 1 deletion
@@ -417,7 +417,8 @@ def forward(
         assert mm_embed == [] or len(
             mm_embed) == num_context_requests, error_msg

-        input_ids, input_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
+                                                    input_ids, mm_embed)

         mrope_config = kwargs.get("mrope_config", {})
         if mrope_config:

tensorrt_llm/_torch/models/modeling_vila.py

Lines changed: 2 additions & 1 deletion
@@ -1167,7 +1167,8 @@ def forward(
             mm_embed
         ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"

-        input_ids, inputs_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, inputs_embeds = fuse_input_embeds(
+            self.llm.model.embed_tokens, input_ids, mm_embed)
         logits = self.llm.forward(attn_metadata=attn_metadata,
                                   input_ids=input_ids,
                                   position_ids=position_ids,

tensorrt_llm/_torch/modules/embedding.py

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ def __init__(
             local_in_features -= self.padding_size
         self.in_features = local_in_features
         self.out_features = local_out_features
+        self.num_embeddings = num_embeddings

         weight_shape = (self.out_features, self.in_features)
         self.weight = Parameter(torch.empty(weight_shape, dtype=dtype))

tensorrt_llm/inputs/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 from .registry import (ExtraProcessedInputs, InputProcessor,
                        create_input_processor, register_input_processor)
 from .utils import (INPUT_FORMATTER_MAP, default_image_loader,
-                    default_video_loader, format_llava_next_input,
+                    default_video_loader, format_generic_input,
                     format_qwen2_vl_input, format_vila_input, load_image,
                     load_video)

@@ -11,5 +11,5 @@
     "InputProcessor", "create_input_processor", "register_input_processor",
     "ExtraProcessedInputs", "load_image", "load_video", "INPUT_FORMATTER_MAP",
     "default_image_loader", "default_video_loader", "format_vila_input",
-    "format_llava_next_input", "format_qwen2_vl_input"
+    "format_generic_input", "format_qwen2_vl_input"
 ]

tensorrt_llm/inputs/utils.py

Lines changed: 7 additions & 5 deletions
@@ -104,7 +104,7 @@ def add_media_token(prompt, multi_modal_data):
     return inputs


-def format_llava_next_input(model_dir, inputs):
+def format_generic_input(model_dir, inputs):
     """
     This function formats the input for the Llava Next VL model.

@@ -122,15 +122,16 @@ def format_llava_next_input(model_dir, inputs):
     def apply_template(prompt, multimodal_data):
         conversation = [
             {
-                "role": "user",
+                "role":
+                "user",
                 "content": [
                     {
                         "type": "text",
                         "text": prompt
                     },
-                    {
+                    *[{
                         "type": "image"
-                    },
+                    } for _ in multimodal_data["image"]],
                 ],
             },
         ]
@@ -220,7 +221,8 @@ def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):

 INPUT_FORMATTER_MAP = {
     "llava_llama": format_vila_input,
-    "llava_next": format_llava_next_input,
+    "llava_next": format_generic_input,
     "qwen2_vl": format_qwen2_vl_input,
     "qwen2_5_vl": format_qwen2_vl_input,
+    "llama4": format_generic_input,
 }
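The starred comprehension is the functional change in the renamed formatter: one image content entry per input image rather than a single fixed entry, which is what lets the chat template insert the right number of image placeholders for Llama 4 and LLaVA-NeXT alike. A standalone sketch of the conversation it builds for a two-image prompt (plain Python, no processor required; file names are stand-ins for loaded images):

    prompt = "Compare these two photos."
    multimodal_data = {"image": ["photo_a.png", "photo_b.png"]}

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                *[{"type": "image"} for _ in multimodal_data["image"]],
            ],
        },
    ]

    # -> one text entry followed by two image entries; the processor's chat
    #    template then emits matching image placeholder tokens for each entry.
    print(conversation[0]["content"])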
