1 | 1 | import copy
2 |   | -from typing import Any, Dict, Optional, Tuple
  | 2 | +from typing import Any, Dict, List, Optional, Tuple
3 | 3 |
4 | 4 | import torch
  | 5 | +from PIL.Image import Image
5 | 6 | from torch import nn
6 |   | -from transformers import Llama4Config, LlamaConfig
  | 7 | +from transformers import (AutoProcessor, Llama4Config, Llama4VisionModel,
  | 8 | +                          LlamaConfig)
  | 9 | +from transformers.modeling_utils import load_sharded_checkpoint
  | 10 | +from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
7 | 11 |
8 | 12 | from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
9 | 13 |                                              AllReduceParams, DeepseekAllReduce)
10 | 14 | from tensorrt_llm._torch.pipeline_interface import PipelineInterface
11 | 15 | from tensorrt_llm.functional import PositionEmbeddingType
12 | 16 |
   | 17 | +from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
   | 18 | +                       register_input_processor)
   | 19 | +from ...sampling_params import SamplingParams
13 | 20 | from ..attention_backend import AttentionMetadata
14 | 21 | from ..attention_backend.interface import (PositionalEmbeddingParams,
15 | 22 |                                            PredefinedAttentionMask, RopeParams)

26 | 33 | from ..modules.rms_norm import RMSNorm
27 | 34 | from ..modules.rotary_embedding import RotaryEmbedding
28 | 35 | from ..speculative import Eagle3SpecMetadata, SpecMetadata
   | 36 | +from .modeling_multimodal_utils import fuse_input_embeds
29 | 37 | from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
30 | 38 |                              EagerFusionConfig, MissingLayer,
31 | 39 |                              register_auto_model, support_pp,

@@ -829,15 +837,13 @@ def __init__(
829 | 837 |                          vocab_size=model_config.pretrained_config.vocab_size)
830 | 838 |
831 | 839 |
832 |     | -@register_auto_model("Llama4ForConditionalGeneration")
833 |     | -class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
834 |     | -                                                             Llama4Config]):
    | 840 | +@register_auto_model("Llama4ForCausalLM")
    | 841 | +class Llama4ForCausalLM(DecoderModelForCausalLM[LlamaModel, Llama4Config]):
835 | 842 |
836 | 843 |     def __init__(
837 | 844 |         self,
838 | 845 |         model_config: ModelConfig[Llama4Config],
839 | 846 |     ):
840 |     | -        # TODO: figure out a better way to handle multimodality.
841 | 847 |         model_config = copy.copy(model_config)
842 | 848 |         architectures = model_config.pretrained_config.architectures
843 | 849 |         model_config.pretrained_config = model_config.pretrained_config.text_config

@@ -876,6 +882,82 @@ def load_weights(self, weights: Dict):
876 | 882 |                     idx + 1].input_layernorm
877 | 883 |
878 | 884 |
    | 885 | +class Llama4InputProcessor(InputProcessor):
    | 886 | +
    | 887 | +    def __init__(self, model_path, model_config, tokenizer):
    | 888 | +        self.processor = AutoProcessor.from_pretrained(model_path,
    | 889 | +                                                       use_fast=True)
    | 890 | +        self.model_config = model_config
    | 891 | +        self.tokenizer = tokenizer
    | 892 | +        self.vocab_size = model_config.text_config.vocab_size
    | 893 | +        self.image_token_index = model_config.image_token_index
    | 894 | +
    | 895 | +        self.encoder = nn.ModuleDict({
    | 896 | +            "vision_model":
    | 897 | +            Llama4VisionModel(model_config.vision_config),
    | 898 | +            "multi_modal_projector":
    | 899 | +            Llama4MultiModalProjector(model_config)
    | 900 | +        }).cuda()
    | 901 | +        load_sharded_checkpoint(self.encoder, model_path, strict=False)
    | 902 | +
    | 903 | +    @torch.inference_mode()
    | 904 | +    def __call__(
    | 905 | +        self, inputs: TextPrompt, sampling_params: SamplingParams
    | 906 | +    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
    | 907 | +        text_prompt, mm_data = inputs.get("prompt"), inputs.get(
    | 908 | +            "multi_modal_data")
    | 909 | +        images, do_rescale = None, True
    | 910 | +
    | 911 | +        if mm_data and mm_data.get("image"):
    | 912 | +            images = mm_data["image"]
    | 913 | +            img_type = type(mm_data["image"][0])
    | 914 | +            do_rescale = (img_type == Image)
    | 915 | +            assert all(isinstance(img, img_type) for img in mm_data["image"])
    | 916 | +
    | 917 | +        # preprocess images and insert image tokens
    | 918 | +        processed = self.processor(text=text_prompt,
    | 919 | +                                   images=images,
    | 920 | +                                   return_tensors="pt",
    | 921 | +                                   device="cuda",
    | 922 | +                                   do_rescale=do_rescale,
    | 923 | +                                   add_special_tokens=False)
    | 924 | +        if images:
    | 925 | +            token_ids, pixel_values = processed["input_ids"].squeeze(
    | 926 | +            ), processed["pixel_values"]
    | 927 | +            mm_embeds = self.encoder.vision_model(
    | 928 | +                pixel_values.float().cuda()).last_hidden_state.flatten(0, 1)
    | 929 | +            mm_embeds = self.encoder.multi_modal_projector(mm_embeds)
    | 930 | +            # for fuse_input_embeds
    | 931 | +            token_ids[token_ids == self.image_token_index] = self.vocab_size + 1
    | 932 | +            return token_ids.tolist(), {
    | 933 | +                "prompt_tuning_config": [mm_embeds, None, None]
    | 934 | +            }
    | 935 | +        else:
    | 936 | +            return processed["input_ids"].squeeze().tolist(), {}
    | 937 | +
    | 938 | +
    | 939 | +@register_auto_model("Llama4ForConditionalGeneration")
    | 940 | +@register_input_processor(Llama4InputProcessor)
    | 941 | +class Llama4ForConditionalGeneration(Llama4ForCausalLM):
    | 942 | +
    | 943 | +    @torch.inference_mode()
    | 944 | +    def forward(
    | 945 | +        self,
    | 946 | +        attn_metadata: AttentionMetadata,
    | 947 | +        input_ids: Optional[torch.LongTensor] = None,
    | 948 | +        position_ids: Optional[torch.LongTensor] = None,
    | 949 | +        inputs_embeds: Optional[torch.FloatTensor] = None,
    | 950 | +        return_context_logits: Optional[bool] = False,
    | 951 | +        **kwargs,
    | 952 | +    ) -> torch.Tensor:
    | 953 | +        mm_embed = kwargs.get("multi_modal_data", [])
    | 954 | +        input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
    | 955 | +                                                     input_ids, mm_embed)
    | 956 | +        logits = super().forward(attn_metadata, input_ids, position_ids,
    | 957 | +                                 inputs_embeds, return_context_logits)
    | 958 | +        return logits
    | 959 | +
    | 960 | +
879 | 961 | @register_auto_model("MistralForCausalLM")
880 | 962 | class MistralForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
881 | 963 |
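
For context, here is a minimal sketch of how the new Llama4InputProcessor could be exercised on its own, outside the registration path. It is an illustration only, not part of the diff: the checkpoint path, image file, and prompt string are placeholders, it assumes a CUDA GPU with a local Llama 4 checkpoint, and it assumes the checkpoint's AutoProcessor expands an image placeholder token in the text prompt.

# Illustrative usage sketch; paths and the prompt format are assumptions.
from PIL import Image
from transformers import AutoConfig, AutoTokenizer

from tensorrt_llm.sampling_params import SamplingParams

model_path = "/path/to/llama4-checkpoint"  # placeholder path
config = AutoConfig.from_pretrained(model_path)  # Llama4Config with text/vision sub-configs
tokenizer = AutoTokenizer.from_pretrained(model_path)

processor = Llama4InputProcessor(model_path, config, tokenizer)
prompt = {
    "prompt": "<|image|> Describe this picture.",  # assumed image placeholder token
    "multi_modal_data": {"image": [Image.open("demo.jpg").convert("RGB")]},  # placeholder image
}
token_ids, extra = processor(prompt, SamplingParams(max_tokens=64))
# Image positions in token_ids now hold the out-of-vocabulary id (vocab_size + 1),
# and extra["prompt_tuning_config"][0] carries the projected vision embeddings that
# Llama4ForConditionalGeneration.forward later merges via fuse_input_embeds.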
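
The token_ids rewrite to vocab_size + 1 is what lets fuse_input_embeds tell text tokens from image positions downstream. The helper below is not the library's fuse_input_embeds; it is a self-contained sketch of that convention, under the assumption that every out-of-vocabulary id marks one image position and that all tensors live on the same device: in-vocabulary ids go through the text embedding table, while placeholder positions take rows of the projected vision embeddings in order.

import torch
from torch import nn


def splice_multimodal_embeds(embed_tokens: nn.Embedding,
                             input_ids: torch.Tensor,
                             mm_embeds: torch.Tensor,
                             vocab_size: int) -> torch.Tensor:
    # Hypothetical stand-in for fuse_input_embeds, for illustration only.
    is_mm = input_ids >= vocab_size  # placeholder ids written by the input processor
    fused = torch.empty(input_ids.numel(), embed_tokens.embedding_dim,
                        dtype=mm_embeds.dtype, device=mm_embeds.device)
    fused[~is_mm] = embed_tokens(input_ids[~is_mm]).to(fused.dtype)
    fused[is_mm] = mm_embeds  # one projected vision embedding per image position
    return fused


# Toy check: a 4-token prompt whose middle two positions are image placeholders.
emb = nn.Embedding(100, 16)
ids = torch.tensor([7, 101, 101, 3])  # ids >= vocab_size mark image positions
vision = torch.randn(2, 16)           # one projected embedding per placeholder
out = splice_multimodal_embeds(emb, ids, vision, vocab_size=100)
assert out.shape == (4, 16)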