
Commit 7020b3e

integrated vlm code for benchmark for Eagle2 and Qwen2.5-VL (#3698)
1 parent 0559c35 commit 7020b3e

File tree

10 files changed (+1388 −37 lines)


docsrc/tutorials/compile_hf_models.rst

Lines changed: 61 additions & 2 deletions
@@ -18,6 +18,7 @@ Overview of tools/llm Directory
 The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface:
 
 * **run_llm.py**: Main entry point for model compilation, generating outputs, and benchmarking
+* **run_vlm.py**: Entry point for compiling and benchmarking Visual Language Models (VLMs)
 * **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization
 * **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering scaled dot-product attention converter and lowering pass.
 * **Testing Components**: Model-specific test files for validation
@@ -64,6 +65,30 @@ We have officially verified support for the following LLM families:
     - FP16, FP32
     - Yes
 
+Supported VLM Models
+--------------------
+We have officially verified support for the following Visual Language Models (VLMs):
+
+.. list-table::
+   :widths: 20 40 20 20 20
+   :header-rows: 1
+
+   * - Model Series
+     - HuggingFace Model Card
+     - Precision
+     - KV Cache Support ?
+     - Component Support
+   * - Qwen 2.5 VL
+     - Qwen/Qwen2.5-VL-3B-Instruct
+     - FP16, FP32
+     - Yes (static_v1 only)
+     - Language Model only (Image Encoder not supported)
+   * - Eagle2
+     - nvidia/Eagle2-2B
+     - FP16, FP32
+     - Yes (static_v1 only)
+     - Language Model and Image Encoder both supported
+
 Getting Started with run_llm.py
 -------------------------------
 
@@ -116,6 +141,36 @@ Other Usage Examples
    python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark
 
 
+Getting Started with run_vlm.py
+-------------------------------
+
+For Visual Language Models (VLMs), use ``run_vlm.py`` to compile and benchmark models that process both text and images.
+
+Basic Usage
+^^^^^^^^^^^
+
+.. code-block:: bash
+
+   python tools/llm/run_vlm.py \
+     --model Qwen/Qwen2.5-VL-3B-Instruct \
+     --precision FP16 \
+     --num_tokens 128 \
+     --cache static_v1 \
+     --enable_pytorch_run \
+     --benchmark
+
+Key Arguments
+^^^^^^^^^^^^^
+
+* ``--model``: Name or path of the HuggingFace VLM
+* ``--prompt``: Input prompt for generation
+* ``--image_path``: (Optional) Path to input image file. If not provided, will use a sample image
+* ``--precision``: Precision mode (``FP16``, ``FP32``)
+* ``--num_tokens``: Number of output tokens to generate
+* ``--cache``: KV cache type (``static_v1`` or empty for no KV caching)
+* ``--benchmark``: Enable benchmarking mode
+* ``--enable_pytorch_run``: Also run and compare PyTorch baseline
+
 KV Caching in Torch-TensorRT
 ---------------------------------
 
@@ -126,7 +181,7 @@ The length of KV cache = input sequence length + output sequence length (specifi
 Static Cache v1
 ^^^^^^^^^^^^^^^^
 
-The ``static_cache_v1.py`` implements KV cache in the model graph as follows:
+The ``static_cache_v1.py`` implements KV cache in the model graph as follows:
 
 .. code-block:: python
 
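The Static Cache v1 change above appears to be whitespace-only, but since this part of the doc describes the cache (its length is the input sequence length plus the number of generated tokens), here is a conceptual sketch of a static KV cache in plain PyTorch. This is illustrative only and not the actual `static_cache_v1.py` graph transformation; the `StaticKVCache` class and its method names are made up for the example.

```python
import torch

class StaticKVCache:
    """Preallocated per-layer key/value buffers written in place at each decode step."""

    def __init__(self, max_len: int, num_heads: int, head_dim: int, dtype=torch.float16):
        # Cache length = input sequence length + number of tokens to generate.
        self.k = torch.zeros(1, num_heads, max_len, head_dim, dtype=dtype)
        self.v = torch.zeros(1, num_heads, max_len, head_dim, dtype=dtype)

    def update(self, pos: int, k_new: torch.Tensor, v_new: torch.Tensor):
        # Write new keys/values at the current position instead of concatenating,
        # so tensor shapes stay fixed across decode steps.
        seq = k_new.shape[2]
        self.k[:, :, pos : pos + seq] = k_new
        self.v[:, :, pos : pos + seq] = v_new
        # Attention should only read the filled prefix of the cache.
        return self.k[:, :, : pos + seq], self.v[:, :, : pos + seq]

cache = StaticKVCache(max_len=128, num_heads=8, head_dim=64)
k, v = cache.update(0, torch.randn(1, 8, 16, 64, dtype=torch.float16),
                    torch.randn(1, 8, 16, 64, dtype=torch.float16))
print(k.shape)  # torch.Size([1, 8, 16, 64])
```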
@@ -214,9 +269,13 @@ Limitations and Known Issues
 
 * Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported
 * Some model architectures (e.g. Phi-4) have issues with exporting the torch model.
+* For VLMs, Qwen2.5-VL image encoder compilation is not supported due to dynamic operations incompatible with torch.export.
 
 Requirements
 ^^^^^^^^^^^^
 
 * Torch-TensorRT 2.8.0 or later
-* Transformers v4.52.3
+* Transformers v4.52.3
+* For VLM models (run_vlm.py):
+    - ``pip install qwen-vl-utils`` (for Qwen2.5-VL-3B-Instruct model)
+    - ``pip install flash-attn --no-build-isolation -v`` (for Eagle2-2B model)

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 6 additions & 1 deletion
@@ -48,9 +48,13 @@ def matrix_multiply(
     input, other = broadcast(
         ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
+    # Get the original input dtype
+    input_dtype = _enums.dtype._from(input.dtype).to(torch.dtype)
+
     if (
         ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
         and ctx.compilation_settings.use_fp32_acc
+        and input_dtype == torch.float16
     ):
         input = cast_trt_tensor(ctx, input, torch.float32, f"{name}_input_casted")
         other = cast_trt_tensor(ctx, other, torch.float32, f"{name}_other_casted")
@@ -63,9 +67,10 @@ def matrix_multiply(
     if (
         ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
         and ctx.compilation_settings.use_fp32_acc
+        and input_dtype == torch.float16
     ):
         matmul_output = cast_trt_tensor(
-            ctx, matmul_output, torch.float16, f"{name}_output_casted"
+            ctx, matmul_output, input_dtype, f"{name}_output_casted"
         )
 
     set_layer_name(matmul_layer, target, name, source_ir)
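In effect, the converter change above inserts the FP32-accumulation casts only when the matmul operands are FP16, and it restores the original input dtype on the output instead of hard-coding FP16. A minimal PyTorch sketch of the same pattern, illustrative only (the real converter manipulates TensorRT tensors via `cast_trt_tensor`; `matmul_fp32_acc` is a made-up helper name):

```python
import torch

def matmul_fp32_acc(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Remember the original dtype so the result can be cast back to it,
    # mirroring the `input_dtype` bookkeeping added in the diff above.
    input_dtype = a.dtype
    if input_dtype == torch.float16:
        # Upcast FP16 operands so accumulation happens in FP32 ...
        out = a.float() @ b.float()
        # ... then hand the result back in the original precision.
        return out.to(input_dtype)
    # Non-FP16 inputs are left untouched, matching the new dtype guard.
    return a @ b

a = torch.randn(64, 64, dtype=torch.float16)
b = torch.randn(64, 64, dtype=torch.float16)
print(matmul_fp32_acc(a, b).dtype)  # torch.float16
```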

tests/py/dynamo/models/test_llm_models.py

Lines changed: 1 addition & 4 deletions
@@ -44,10 +44,7 @@ def test_llm_decoder_layer(precision):
         .to("cuda")
     )
 
-    if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
-        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
-    else:
-        register_sdpa._SDPA_MAPPING["default"](model_config=model.config)
+    register_sdpa.enable_sdpa_converter(args.model, model.config)
     model = model.to(dtype)
     # use randint will generate nan values in the logits, use a fixed input_ids for now
     # input_ids = torch.randint(0, model.config.vocab_size, (1, args.num_tokens)).to("cuda")
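This call site, and the matching one in `tools/llm/run_llm.py` further down, folds the explicit `_SDPA_MAPPING` lookup into a single `register_sdpa.enable_sdpa_converter` call. Judging from the removed lines, the helper presumably performs the same lookup with a fallback to the `"default"` entry; here is a hedged, self-contained sketch of that behaviour, where the mapping contents are placeholders rather than the real registry:

```python
from typing import Any, Callable, Dict

# Placeholder stand-in for register_sdpa._SDPA_MAPPING: the real mapping holds
# model-specific SDPA registration callables keyed by HuggingFace model name.
_SDPA_MAPPING: Dict[str, Callable[..., None]] = {
    "default": lambda model_config: print("registering default SDPA lowering"),
    "some-org/some-model": lambda model_config: print("registering model-specific SDPA lowering"),
}

def enable_sdpa_converter(model_name: str, model_config: Any) -> None:
    # Use the model-specific handler when one is registered, otherwise fall
    # back to "default" -- the branch the old call sites spelled out by hand.
    handler = _SDPA_MAPPING.get(model_name, _SDPA_MAPPING["default"])
    handler(model_config=model_config)

enable_sdpa_converter("meta-llama/Llama-3.2-1B-Instruct", model_config=None)
```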

tools/llm/README.md

Lines changed: 25 additions & 4 deletions
@@ -1,10 +1,11 @@
 # Optimizing LLMs in Torch-TensorRT
 
-This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry point is `run_llm.py`, which demonstrates how to export, compile, and run LLMs with various caching strategies and precision modes. Note that this is an **experimental release** and APIs may change in future versions.
+This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) and Visual Language Models (VLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry points are `run_llm.py` for text-only LLMs and `run_vlm.py` for vision-language models. Note that this is an **experimental release** and APIs may change in future versions.
 
 ### Key Features
 
 - **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc.
+- **VLM Support:** Supports Visual Language Models like Qwen2.5-VL and Eagle2.
 - **Precision Modes:** Supports FP16, BF16, and FP32.
 - **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding.
 - **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends.
@@ -25,20 +26,33 @@ We have officially verified support for the following models:
 | Qwen 3 | Qwen/Qwen3-0.6B<br>Qwen/Qwen3-1.7B<br>Qwen/Qwen3-4B<br>Qwen/Qwen3-8B | FP16, FP32 | Yes |
 | Gemma 3 | google/gemma-3-1b-it | FP16, FP32 | Yes |
 
+### Supported VLM Models
+
+| Model Series | HF Model Card | Precision | KV Cache Supported ? |
+|--------------|---------------|-----------|-------------------|
+| Qwen 2.5 VL | Qwen/Qwen2.5-VL-3B-Instruct | FP16, FP32 | Yes |
+| Eagle2 | nvidia/Eagle2-2B | FP16, FP32 | Yes |
 
 ### Usage
 
-The main entry point is : `run_llm.py`
+#### Text-only LLMs: `run_llm.py`
 
 ```bash
 python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark
 ```
 
+#### Vision Language Models: `run_vlm.py`
+
+```bash
+python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 128 --cache static_v1 --enable_pytorch_run --benchmark
+```
+
 #### Key Arguments
 
-- `--model`: Name or path of the HuggingFace LLM.
+- `--model`: Name or path of the HuggingFace LLM/VLM.
 - `--tokenizer`: (Optional) Tokenizer name; defaults to model.
 - `--prompt`: Input prompt for generation.
+- `--image_path`: (Optional) Path to input image file for VLM models. If not provided, will use a sample image.
 - `--precision`: Precision mode (`FP16`, `FP32`).
 - `--num_tokens`: Number of output tokens to generate.
 - `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching).
@@ -61,8 +75,15 @@ This codebase can be extended to
 
 ## Limitations
 - We do not currently support sliding window attention (used in Gemma3 and Qwen 3 models) yet.
+- **Flash Attention Limitation**: Some models (e.g., Eagle2-2B) internally use flash attention operations (`torch.ops.flash_attn._flash_attn_forward.default`) which require the `flash-attn` package to be installed. Without flash-attn, these models will fail to load or run properly.
+- **Qwen2.5‑VL vision is not compiled (LLM-only)**: We only compile the language model for Qwen2.5‑VL. The vision encoder is skipped because its `get_window_index` relies on dynamic Python operations.
 
 ## Requirements
 
 - Torch-TensorRT 2.8.0
-- Transformers v4.52.3
+- Transformers v4.52.3
+- For VLM models (run_vlm.py):
+  - `pip install qwen-vl-utils` (for Qwen2.5-VL-3B-Instruct model)
+  - **Flash Attention**: For models using flash attention operations (e.g., Eagle2-2B), install one of the following:
+    - **Fast installation (recommended)**: `pip install flash-attn==2.8.1` (pre-built wheel, should work)
+    - **Source build (slow)**: `pip install flash-attn --no-build-isolation -v` (fallback if pre-built wheels fail)

tools/llm/run_llm.py

Lines changed: 1 addition & 4 deletions
@@ -59,10 +59,7 @@ def get_model(args):
         .cuda()
     )
     # register SDPA variant for the model
-    if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
-        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
-    else:
-        register_sdpa._SDPA_MAPPING["default"](model_config=model.config)
+    register_sdpa.enable_sdpa_converter(args.model, model.config)
 
     if args.precision == "FP16":
         model = model.to(torch.float16)
