Skip to content

Commit 5bf54e0

Browse files
mmoskal and JC1DA authored
add vision model support via python plugin (#17)
* start on python scripting * more python * somewhat working chat templates * experiments with tensor wrapping * typed output from input_processor * prompt params plumbing * rework dtype support * make TlcShape into value type * plumb args py -> C++ * pass bos/eos to python * enable debug mode in release profile * add additional args * fix script * Add simple qwen2-vl preprocess func test (#15) * Simple test for Qwen2-VL image processing * update vlm test * move files out of tests folder * model is bf16 native; disable double-load of tokenizer * disable warmup when py enabled * better logging * fix tensor location * prompt_tasks -> input_token_extra_ids * add standalone trt example * print out errors * mrope on CPU * allow image inputs * Add input_processor for phi-3.5-vision * better error for missing n_vocab_override; abort on failed CHECK() * bump llg * fix copy instructions in docs * use strum::FromRepr on C enums * fixes for batch size = 1 * add more kv cache control params * bump llg * pass max_new_tokens etc to plugin * Add llama-3.2-vision input_processor * cargo update * Clean up input_processor for qwen2-vl * remove unused imports * remove playground files * don't pip install on plugin construction * add link to gh issue --------- Co-authored-by: JC1DA <jc1da.3011@gmail.com>
1 parent 5573f01 commit 5bf54e0

32 files changed

+1684
-172
lines changed

Cargo.lock

Lines changed: 32 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ exclude = [
1616
resolver = "2"
1717

1818
[profile.release]
19-
# debug = 1
19+
debug = 1
2020

2121
[patch.crates-io]
2222
derivre = { path = "derivre" }

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ trtllm-build --checkpoint_dir /models/model-ckpt \
139139
# Clean up checkpoint (optional)
140140
rm -rf /models/model-ckpt
141141

142-
# Finally, copy tokenizer.json and tokenizer_config.json
143-
cp /models/Meta-Llama-3.1-8B-Instruct/tokenizer.json /models/model-engine
144-
cp /models/Meta-Llama-3.1-8B-Instruct/tokenizer_config.json /models/model-engine
142+
# Finally, copy tokenizer and preprocessor files to engine folder
143+
cp /models/Meta-Llama-3.1-8B-Instruct/tokenizer*.json /models/model-engine
144+
cp /models/Meta-Llama-3.1-8B-Instruct/preprocessor*.json /models/model-engine # this may be missing
145145

146146
# Exit the container
147147
exit

docker/Dockerfile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ RUN pip uninstall -y guidance
6363

6464
RUN pip install --upgrade transformers
6565

66+
# TODO test this
67+
# RUN pip install flash-attn --no-build-isolation
68+
# RUN pip install qwen_vl_utils
69+
6670
RUN cd /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/ && \
6771
ln -s libnvinfer_plugin_tensorrt_llm.so libnvinfer_plugin_tensorrt_llm.so.10
6872

llgtrt/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ version = "0.16.6"
44
edition = "2021"
55

66
[dependencies]
7-
axum = { version = "0.7" }
7+
axum = { version = "0.7", features = ["macros"] }
88
tokio = { version = "1.33.0", features = ["full"] }
99
async-stream = "0.3.5"
1010
anyhow = { version = "1.0.75", features = ["backtrace"] }
@@ -27,3 +27,5 @@ json5 = "0.4.1"
2727
minijinja-contrib = { version = "2.3.1", features = ["pycompat"] }
2828
safetensors = "0.5.2"
2929
memmap2 = "0.9.5"
30+
pyo3 = { version = "0.23.4", features = ["anyhow", "serde"] }
31+
num-traits = "0.2.19"
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# this currently doesn't work due to https://github.com/NVIDIA/TensorRT-LLM/issues/2796
2+
3+
import copy
4+
import requests
5+
import llgtrt_base
6+
import torch
7+
8+
from PIL import Image
9+
from transformers import MllamaForConditionalGeneration, AutoProcessor
10+
from llgtrt_native import PluginInit
11+
12+
class Plugin(llgtrt_base.PluginBase):
    """Input-processor plugin for Llama-3.2-vision (Mllama).

    The HF model is loaded on CPU; its vision tower and multi-modal projector
    are used to pre-compute cross-attention features, which are returned to
    the llgtrt runtime together with the prompt tokens.
    """

    def __init__(self, init: PluginInit):
        """Load the Mllama model and its processor from ``init.hf_model_dir``."""
        super().__init__(init)
        self.model = MllamaForConditionalGeneration.from_pretrained(
            init.hf_model_dir,
            device_map="cpu",
            trust_remote_code=True,
        )
        self.processor = AutoProcessor.from_pretrained(
            init.hf_model_dir,
            trust_remote_code=True,
        )
        # NOTE(review): the original comment said "move visual model to gpu",
        # but both modules are explicitly kept on CPU -- presumably a
        # workaround related to the TensorRT-LLM issue referenced at the top
        # of this file; confirm before changing.
        self.model.vision_model = self.model.vision_model.to("cpu")
        self.model.multi_modal_projector = self.model.multi_modal_projector.to("cpu")
        print("Plugin initialized from HF model directory:", init.hf_model_dir)

    def process_input(
        self, params: llgtrt_base.ProcessInputParams
    ) -> llgtrt_base.ProcessInputResult:
        """Turn chat messages (optionally containing image URLs) into engine input.

        Text-only requests are tokenized with the chat template.  Requests
        with images additionally run the vision tower + projector and attach
        the resulting cross-attention tensors to the result.
        """
        messages = params.messages
        print("process_input called, ", messages)

        # Rewrite OpenAI-style "image_url" content parts to HF "image" parts
        # and collect the URLs so the images can be fetched below.
        messages, urls = self._process_messages(messages)

        images = [Image.open(requests.get(url, stream=True).raw) for url in urls]

        prompt = self.processor.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Fast path: no images -> plain tokenized chat prompt, no vision work.
        if not images:
            return llgtrt_base.ProcessInputResult(
                prompt=prompt,
                tokens=self.processor.tokenizer.apply_chat_template(
                    messages,
                    tokenize=True,
                    add_generation_prompt=True,
                ),
            )

        inputs = self.processor(
            images,
            prompt,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.model.vision_model.device)

        vision_outputs = self.model.vision_model(
            pixel_values=inputs["pixel_values"],
            aspect_ratio_ids=inputs["aspect_ratio_ids"],
            aspect_ratio_mask=inputs["aspect_ratio_mask"],
            output_hidden_states=False,
            output_attentions=False,
            return_dict=True,
        )
        cross_attention_states = vision_outputs[0]
        # Project vision features into the text hidden size and flatten to
        # (num_vision_tokens_total, hidden_size).
        cross_attention_states = self.model.multi_modal_projector(
            cross_attention_states
        ).reshape(-1, self.model.hidden_size)

        cross_attention_mask = _prepare_cross_attention_mask(
            inputs["cross_attention_mask"][0],
            num_vision_tokens=self.model.vision_model.num_patches,
            dtype=self.model.dtype,
            max_new_tokens=params.max_new_tokens,
        )
        cross_attention_mask = cross_attention_mask.reshape(
            -1, cross_attention_states.shape[0]
        )

        r = llgtrt_base.ProcessInputResult(
            prompt=prompt,
            tokens=inputs["input_ids"].cpu().numpy()[0].tolist(),
        )
        # Change this to bfloat16 if engine is using bfloat16.
        r.encoder_input_features = cross_attention_states.cuda().half()
        r.cross_attention_masks = cross_attention_mask.cuda()
        r.skip_cross_attn_blocks = torch.Tensor([False]).cuda()
        r.encoder_output_length = cross_attention_states.shape[0]

        return r

    def _process_messages(self, messages: list[dict]):
        """Return ``(messages, urls)``: a deep copy of *messages* with every
        ``image_url`` content part rewritten to an ``image`` part, plus the
        list of extracted URLs (in encounter order)."""
        urls = []
        messages = copy.deepcopy(messages)
        for m in messages:
            c = m.get("content", None)
            if isinstance(c, list):
                parts_to_change = []
                for part in c:
                    if part["type"] == "image_url":
                        url = part["image_url"]["url"]
                        urls.append(url)
                        parts_to_change.append(part)

                # Mutate after the scan so we don't alter the list mid-iteration.
                for part in parts_to_change:
                    part["type"] = "image"
                    part.pop("image_url", None)

        return messages, urls
113+
114+
115+
def _prepare_cross_attention_mask(
116+
cross_attention_mask: torch.Tensor,
117+
num_vision_tokens: int,
118+
dtype: str,
119+
max_new_tokens=100,
120+
) -> torch.Tensor:
121+
text_total_length, *_ = cross_attention_mask.shape
122+
cross_attention_mask = cross_attention_mask.repeat_interleave(
123+
num_vision_tokens, dim=2)
124+
125+
cross_attention_mask = cross_attention_mask.view(
126+
text_total_length, -1)
127+
cross_attention_mask = cross_attention_mask.unsqueeze(1)
128+
cross_attention_mask = cross_attention_mask.to(
129+
dtype).to(torch.bool).reshape(
130+
[-1, cross_attention_mask.shape[-1]])
131+
132+
# prepare cross_attention_mask for generation phase and concat them
133+
tmp_mask = [cross_attention_mask] + [
134+
cross_attention_mask[-1:, :] for _ in range(max_new_tokens)
135+
]
136+
cross_attention_mask = torch.concat(tmp_mask)
137+
138+
return cross_attention_mask

0 commit comments

Comments (0)