
Commit 7bc2a40

llama4 vlm

Signed-off-by: Lucas Liebenwein <[email protected]>

1 parent fb3a82b

File tree

6 files changed: +393, −42 lines changed


examples/auto_deploy/build_and_run_ad.py

Lines changed: 45 additions & 11 deletions
@@ -26,6 +26,9 @@
 # Global torch config, set the torch compile cache to fix up to llama 405B
 torch._dynamo.config.cache_size_limit = 20

+# simple string, TRT-LLM style text-only prompt or full-scale HF message template
+PromptInput = Union[str, Dict, List[Dict]]
+

 class PromptConfig(BaseModel):
     """Prompt configuration.
@@ -35,17 +38,27 @@ class PromptConfig(BaseModel):
     """

     batch_size: int = Field(default=2, description="Number of queries")
-    queries: Union[str, List[str]] = Field(
+    queries: Union[PromptInput, List[PromptInput]] = Field(
         default_factory=lambda: [
+            # OPTION 1: simple text prompt
             "How big is the universe? ",
-            "In simple words and in a single sentence, explain the concept of gravity: ",
-            "How to fix slicing in golf? ",
-            "Where is the capital of Iceland? ",
-            "How big is the universe? ",
-            "In simple words and in a single sentence, explain the concept of gravity: ",
-            "How to fix slicing in golf? ",
-            "Where is the capital of Iceland? ",
-        ]
+            # OPTION 2: wrapped text prompt for TRT-LLM
+            {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
+            # OPTION 3: a full-scale HF message template (this one works for text-only models!)
+            # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
+            # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
+            [
+                {
+                    "role": "user",
+                    "content": "How to fix slicing in golf?",
+                }
+            ],
+            # More prompts...
+            {"prompt": "Where is the capital of Iceland? "},
+        ],
+        description="Example queries to prompt the model with. We support both TRT-LLM text-only "
+        "queries via the 'prompt' key and full-scale HF message templates called via "
+        "apply_chat_template.",
     )
     sp_kwargs: Dict[str, Any] = Field(
         default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0},
@@ -59,10 +72,28 @@ def model_post_init(self, __context: Any):
         NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
         validators are only run if a value is provided.
         """
-        queries = [self.queries] if isinstance(self.queries, str) else self.queries
+        queries = self.queries if isinstance(self.queries, list) else [self.queries]
         batch_size = self.batch_size
         queries = queries * (batch_size // len(queries) + 1)
-        self.queries = queries[:batch_size]
+        queries = queries[:batch_size]
+
+        # now let's standardize the queries for the LLM api to understand them
+        queries_processed = []
+        for query in queries:
+            if isinstance(query, str):
+                queries_processed.append({"prompt": query})
+            elif isinstance(query, dict):
+                queries_processed.append(query)
+            elif isinstance(query, list):
+                queries_processed.append(
+                    {
+                        "prompt": "Fake prompt. Check out messages field for the HF chat template.",
+                        "messages": query,  # contains the actual HF chat template
+                    }
+                )
+            else:
+                raise ValueError(f"Invalid query type: {type(query)}")
+        self.queries = queries_processed

     @field_validator("sp_kwargs", mode="after")
     @classmethod
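
In effect, model_post_init now tiles the queries up to batch_size and normalizes every entry to a dict. A sketch of the before/after shape (values illustrative, batch_size=3 assumed):

    # before normalization (as the user may pass them):
    queries = [
        "How big is the universe? ",
        {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
        [{"role": "user", "content": "How to fix slicing in golf?"}],
    ]

    # after model_post_init (what llm.generate receives):
    queries_processed = [
        {"prompt": "How big is the universe? "},
        {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
        {
            "prompt": "Fake prompt. Check out messages field for the HF chat template.",
            "messages": [{"role": "user", "content": "How to fix slicing in golf?"}],
        },
    ]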
@@ -239,6 +270,9 @@ def main(config: Optional[ExperimentConfig] = None):

     # prompt the model and print its output
     ad_logger.info("Running example prompts...")
+
+    # now let's try piping through multimodal data
+
     outs = llm.generate(
         config.prompt.queries,
         sampling_params=SamplingParams(**config.prompt.sp_kwargs),

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 102 additions & 7 deletions
@@ -4,19 +4,21 @@
 import os
 import types
 from contextlib import contextmanager, nullcontext
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 from accelerate import init_empty_weights, load_checkpoint_in_model
 from accelerate.utils import modeling
 from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id
+from PIL import Image
 from torch._prims_common import DeviceLikeType
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoModelForImageTextToText,
+    AutoProcessor,
     AutoTokenizer,
     PretrainedConfig,
 )
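
The new AutoProcessor import backs the VLM factory changes below: the processor bundles the image preprocessor with the tokenizer, so the factory can expose processor.tokenizer where a bare tokenizer used to be loaded. A minimal sketch (the checkpoint name is a placeholder, not taken from this diff):

    from transformers import AutoProcessor

    name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # placeholder checkpoint
    processor = AutoProcessor.from_pretrained(name)  # image processor + tokenizer bundle
    tokenizer = processor.tokenizer  # same tokenizer the text-only path would load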
@@ -27,7 +29,7 @@
     WEIGHTS_NAME,
 )

-from ..custom_ops.attention_interface import CacheConfig
+from ..custom_ops.attention_interface import CacheConfig, Dim, DynamicShapeCallback
 from ..utils._config import deep_merge_dicts
 from ..utils.logger import ad_logger
 from .factory import ModelFactory, ModelFactoryRegistry
@@ -108,10 +110,6 @@ def __init__(self, *args, **kwargs):
     def autoconfig_from_pretrained(self):
         return AutoConfig.from_pretrained

-    @property
-    def autotokenizer_from_pretrained(self):
-        return AutoTokenizer.from_pretrained
-
     # TODO (@lucaslie): Do we ever want to switch to from_pretrained?
     @property
     def automodel_from_config(self):
@@ -200,7 +198,7 @@ def init_tokenizer(self) -> Optional[Any]:
         """Initialize the tokenizer—either a custom name or the model's default."""
         if self.tokenizer is None:
             return None
-        return self.autotokenizer_from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
+        return AutoTokenizer.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)

     @staticmethod
     def _get_ignore_patterns(repo_id: str, skip_prefetch_weights: bool) -> List[str]:
@@ -366,3 +364,100 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config
+
+    def init_tokenizer(self) -> Optional[Any]:
+        """Initialize the tokenizer—either a custom name or the model's default."""
+        processor = self.init_processor()
+        if processor is None:
+            return None
+        return processor.tokenizer
+
+    def init_processor(self) -> Optional[Any]:
+        """Initialize the processor for the model."""
+        if self.tokenizer is None:
+            return None
+        return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
+
+    @staticmethod
+    def _simple_forward(
+        model: nn.Module,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        pixel_values: torch.Tensor,
+    ):
+        """A simple forward pass for the model to functionalize the args.
+
+        This follows the standard function signature as expected by factory.py.
+        """
+        return type(model).forward(
+            model,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+        )
+
+    def get_example_inputs(self) -> Dict[str, torch.Tensor]:
+        """Return a dictionary of example inputs for the model."""
+
+        def _prep_seq(text, img1, img2):
+            return [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": img1},
+                        {"type": "image", "image": img2},
+                        {"type": "text", "text": text},
+                    ],
+                }
+            ]
+
+        # Create a batch of conversations (batch_size = 2)
+        batch_messages = [
+            _prep_seq(
+                "Describe what you see in the two images and their differences.",
+                Image.new("RGB", (16, 16), color=(128, 128, 128)),
+                Image.new("RGB", (16, 16), color=(64, 64, 64)),
+            ),
+            _prep_seq(
+                "What are the main differences between these two images?",
+                Image.new("RGB", (16, 16), color=(255, 0, 0)),
+                Image.new("RGB", (16, 16), color=(0, 255, 0)),
+            ),
+        ]
+
+        processor = AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
+        inputs = processor.apply_chat_template(
+            batch_messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            padding=True,
+            return_attention_mask=False,
+        )
+
+        return {
+            "input_ids": inputs["input_ids"],
+            "pixel_values": inputs["pixel_values"],
+        }
+
+    def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]:
+        """Return a dictionary of extra inputs for the model.
+
+        Returns:
+            A dictionary of extra inputs for the model where the key corresponds to the argument
+            name and the value corresponds to a tuple of (example_input, dynamic_shape_callback).
+            The dynamic shape callback is a function that returns the dynamic shape of the extra
+            input.
+        """
+
+        def _get_dynamic_shape():
+            return {
+                # TODO (lucaslie): how to set default values for dynamic shapes?
+                0: Dim("img_batch_size", max=10),
+                2: Dim("img_height", min=32, max=2048),
+                3: Dim("img_width", min=32, max=2048),
+            }
+
+        none_pixel_values = torch.zeros(0, 3, 336, 336)
+        return {"pixel_values": (none_pixel_values, _get_dynamic_shape)}
