diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst
index f6da87b145..406cbb9140 100644
--- a/docsrc/tutorials/compile_hf_models.rst
+++ b/docsrc/tutorials/compile_hf_models.rst
@@ -18,6 +18,7 @@ Overview of tools/llm Directory
The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface:
* **run_llm.py**: Main entry point for model compilation, generating outputs, and benchmarking
+* **run_vlm.py**: Entry point for compiling and benchmarking Visual Language Models (VLMs)
* **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization
* **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering scaled dot-product attention converter and lowering pass.
* **Testing Components**: Model-specific test files for validation
@@ -60,6 +61,30 @@ We have officially verified support for the following LLM families:
- FP16, FP32
- Yes
+Supported VLM Models
+--------------------
+We have officially verified support for the following Visual Language Models (VLMs):
+
+.. list-table::
+ :widths: 20 40 20 20 20
+ :header-rows: 1
+
+ * - Model Series
+ - HuggingFace Model Card
+ - Precision
+ - KV Cache Support ?
+ - Component Support
+ * - Qwen 2.5 VL
+ - Qwen/Qwen2.5-VL-3B-Instruct
+ - FP16, FP32
+ - Yes (static_v1 only)
+ - Language Model only (Image Encoder not supported)
+ * - Eagle2
+ - nvidia/Eagle2-2B
+ - FP16, FP32
+ - Yes (static_v1 only)
+ - Language Model and Image Encoder both supported
+
Getting Started with run_llm.py
-------------------------------
@@ -112,6 +137,36 @@ Other Usage Examples
python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark
+Getting Started with run_vlm.py
+-------------------------------
+
+For Visual Language Models (VLMs), use ``run_vlm.py`` to compile and benchmark models that process both text and images.
+
+Basic Usage
+^^^^^^^^^^^
+
+.. code-block:: bash
+
+ python tools/llm/run_vlm.py \
+ --model Qwen/Qwen2.5-VL-3B-Instruct \
+ --precision FP16 \
+ --num_tokens 128 \
+ --cache static_v1 \
+ --enable_pytorch_run \
+ --benchmark
+
+Key Arguments
+^^^^^^^^^^^^^
+
+* ``--model``: Name or path of the HuggingFace VLM
+* ``--prompt``: Input prompt for generation
+* ``--image_path``: (Optional) Path to a local input image file. If omitted, a sample image is downloaded and used
+* ``--precision``: Precision mode (``FP16``, ``FP32``)
+* ``--num_tokens``: Number of output tokens to generate
+* ``--cache``: KV cache type (``static_v1`` or empty for no KV caching)
+* ``--benchmark``: Enable benchmarking mode
+* ``--enable_pytorch_run``: Also run and compare PyTorch baseline
+
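+Under the hood, ``run_vlm.py`` merges the image features into the text embedding
+sequence and compiles the language-model component (and, for Eagle2, the image
+encoder) with the Torch-TensorRT dynamo backend, so the exported language-model
+graph consumes ``inputs_embeds`` rather than ``input_ids``. The snippet below is a
+simplified, illustrative sketch of that flow; ``language_model``, the hidden size,
+and the example shapes are placeholders, and the actual implementation lives in
+``tools/llm/run_vlm.py``.
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    class LMNoCache(torch.nn.Module):
+        """Expose a language model through inputs_embeds (no KV cache)."""
+
+        def __init__(self, lm):
+            super().__init__()
+            self.lm = lm
+
+        def forward(self, inputs_embeds, position_ids):
+            out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids)
+            return out.logits if hasattr(out, "logits") else out
+
+    # `language_model` stands for the VLM's text decoder, e.g. `model.language_model`
+    # for Eagle2 or `model.model` for Qwen2.5-VL.
+    # Example tensors are used only to export and compile the graph (placeholder sizes).
+    hidden_size = 2048
+    example_embeds = torch.randn(1, 2560, hidden_size, dtype=torch.float16, device="cuda")
+    position_ids = torch.arange(example_embeds.shape[1]).unsqueeze(0).cuda()
+
+    seq = torch.export.Dim("seq", min=1, max=2688)
+    exported = torch.export.export(
+        LMNoCache(language_model).eval().cuda(),
+        (example_embeds, position_ids),
+        dynamic_shapes=({1: seq}, {1: seq}),
+        strict=False,
+    )
+    trt_lm = torch_tensorrt.dynamo.compile(
+        exported,
+        inputs=[example_embeds, position_ids],
+        enabled_precisions={torch.float32},
+        use_explicit_typing=True,  # FP16 weights drive kernel precision
+        use_fp32_acc=True,         # keep matmul accumulation in FP32
+        device="cuda:0",
+    )
+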
KV Caching in Torch-TensorRT
---------------------------------
@@ -122,7 +177,7 @@ The length of KV cache = input sequence length + output sequence length (specifi
Static Cache v1
^^^^^^^^^^^^^^^^
-The ``static_cache_v1.py`` implements KV cache in the model graph as follows:
+The ``static_cache_v1.py`` implements KV cache in the model graph as follows:
.. code-block:: python
@@ -210,9 +265,13 @@ Limitations and Known Issues
* Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported
* Some model architectures (e.g. Phi-4) have issues with exporting the torch model.
+* For VLMs, Qwen2.5-VL image encoder compilation is not supported due to dynamic operations incompatible with torch.export.
Requirements
^^^^^^^^^^^^
* Torch-TensorRT 2.8.0 or later
-* Transformers v4.52.3
\ No newline at end of file
+* Transformers v4.52.3
+* For VLM models (run_vlm.py):
+
+  - ``pip install qwen-vl-utils`` (for Qwen2.5-VL-3B-Instruct model)
+  - ``pip install flash-attn --no-build-isolation -v`` (for Eagle2-2B model)
\ No newline at end of file
diff --git a/tools/llm/README.md b/tools/llm/README.md
index a141505517..c0a88635f6 100644
--- a/tools/llm/README.md
+++ b/tools/llm/README.md
@@ -1,10 +1,11 @@
# Optimizing LLMs in Torch-TensorRT
-This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry point is `run_llm.py`, which demonstrates how to export, compile, and run LLMs with various caching strategies and precision modes. Note that this is an **experimental release** and APIs may change in future versions.
+This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) and Visual Language Models (VLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry points are `run_llm.py` for text-only LLMs and `run_vlm.py` for vision-language models. Note that this is an **experimental release** and APIs may change in future versions.
### Key Features
- **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc.
+- **VLM Support:** Supports Visual Language Models like Qwen2.5-VL and Eagle2.
- **Precision Modes:** Supports FP16, BF16, and FP32.
- **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding.
- **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends.
@@ -24,20 +25,33 @@ We have officially verified support for the following models:
| Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct<br>Qwen/Qwen2.5-1.5B-Instruct<br>Qwen/Qwen2.5-4B-Instruct<br>Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes |
| Qwen 3 | Qwen/Qwen3-0.6B<br>Qwen/Qwen3-1.7B<br>Qwen/Qwen3-4B<br>Qwen/Qwen3-8B | FP16, FP32 | Yes |
+### Supported VLM Models
+
+| Model Series | HF Model Card | Precision | KV Cache Supported ? |
+|--------------|---------------|-----------|-------------------|
+| Qwen 2.5 VL | Qwen/Qwen2.5-VL-3B-Instruct | FP16, FP32 | Yes |
+| Eagle2 | nvidia/Eagle2-2B | FP16, FP32 | Yes |
### Usage
-The main entry point is : `run_llm.py`
+#### Text-only LLMs: `run_llm.py`
```bash
python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark
```
+#### Visual Language Models (VLMs): `run_vlm.py`
+
+```bash
+python run_vlm.py --model Qwen/Qwen2.5-VL-3B-Instruct --precision FP16 --num_tokens 128 --cache static_v1 --enable_pytorch_run --benchmark
+```
+
#### Key Arguments
-- `--model`: Name or path of the HuggingFace LLM.
+- `--model`: Name or path of the HuggingFace LLM/VLM.
- `--tokenizer`: (Optional) Tokenizer name; defaults to model.
- `--prompt`: Input prompt for generation.
+- `--image_path`: (Optional) Path to a local input image file (VLMs only). If omitted, a sample image is downloaded and used.
- `--precision`: Precision mode (`FP16`, `FP32`).
- `--num_tokens`: Number of output tokens to generate.
- `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching).
@@ -64,4 +78,7 @@ This codebase can be extended to
## Requirements
- Torch-TensorRT 2.8.0
-- Transformers v4.52.3
\ No newline at end of file
+- Transformers v4.52.3
+- For VLM models (run_vlm.py):
+ - `pip install qwen-vl-utils` (for Qwen2.5-VL-3B-Instruct model)
+ - `pip install flash-attn --no-build-isolation -v` (for Eagle2-2B model)
\ No newline at end of file
diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py
new file mode 100644
index 0000000000..e26d971f32
--- /dev/null
+++ b/tools/llm/run_vlm.py
@@ -0,0 +1,678 @@
+"""
+.. _run_vlm:
+
+Benchmarking VLM Inference with Torch-TensorRT
+==========================================================
+
+This script provides a framework for benchmarking the performance of Visual-Language
+Models (VLMs). It optimizes the two most computationally intensive components of a
+VLM—the language model and the vision model (image feature extraction)—using
+the Torch-TensorRT dynamo backend.
+
+Key Features:
+- **Component-wise Optimization**: Compiles both the language and vision models
+ separately with Torch-TensorRT to accelerate inference.
+- **Performance Benchmarking**: Runs the model for multiple iterations to
+ measure and compare inference latency against the PyTorch baseline.
+- **Output Verification**: Checks for token-level consistency between the optimized
+ TensorRT model and the original PyTorch model to ensure correctness.
+- **KV Cache Testing**: Includes options to test inference with and without
+ KV caching to evaluate its impact on performance.
+
+This tool mirrors the style and structure of `run_llm.py`, providing a clear
+workflow for VLM optimization and analysis.
+
+Dependencies:
+- For Qwen VLM models: pip install qwen-vl-utils
+- For Eagle2 models: pip install flash-attn --no-build-isolation -v
+"""
+
+import argparse
+import copy
+import os
+import sys
+from contextlib import nullcontext
+from typing import Tuple
+
+import requests
+import torch
+import torch_tensorrt
+
+# The Qwen2 modeling module is imported so its global attention function map
+# can be patched below (see the Eagle2 SDPA workaround that follows the imports).
+import transformers.models.qwen2.modeling_qwen2 as mq
+from PIL import Image
+from torchtrt_ext import register_sdpa
+from transformers import AutoConfig, AutoModel, AutoProcessor
+from utils import (
+ export_llm,
+ generate_mm,
+ generate_mm_qwen2_5_vl,
+ generate_mm_qwen2_5_vl_with_static_cache,
+ generate_mm_with_static_cache,
+ record_stats,
+)
+
+# --- WORKAROUND FOR EAGLE2 SDPA COMPILATION ---
+# Eagle2's language model (Qwen2) implicitly defaults to "flash_attention_2"
+# due to settings in its remote code and config.json. This prevents direct
+# compilation with SDPA. To work around this without modifying the library,
+# we monkey-patch the global attention function map for Qwen2 so that any
+# part of the code (including torch.export) requesting "flash_attention_2"
+# receives the "sdpa" implementation instead. The patch is global for this
+# script's execution context.
+mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"]
+# --- END WORKAROUND ---
+
+# --- Model-specific constants for benchmark and compilation ---
+# Centralizing these values improves readability and maintainability.
+MODEL_CONSTANTS = {
+ "nvidia/Eagle2-2B": {
+ "EXAMPLE_SEQLEN": 2560, # A fixed sequence length for creating the example tensor for TRT compilation.
+ "IMAGE_TOKENS": 1792, # Number of special tokens used to represent the image patch embeddings in the input sequence for Eagle2-2B VLM.
+ "PROMPT_WRAPPER_TOKENS": 26, # The number of special/processing tokens added by the processor's chat template in benchmark mode.
+ },
+ "Qwen/Qwen2.5-VL-3B-Instruct": {
+ "EXAMPLE_SEQLEN": 2560,
+ "IMAGE_TOKENS": 1426,
+ "PROMPT_WRAPPER_TOKENS": 21,
+ },
+}
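+# In benchmark mode, load_inputs() pads the text prompt so that IMAGE_TOKENS +
+# PROMPT_WRAPPER_TOKENS + prompt tokens approximately equals the requested --isl,
+# keeping the effective input sequence length comparable across models.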
+# --- END Model-specific constants ---
+
+# -----------------------------------------------------------------------------#
+# Model loading helpers
+# -----------------------------------------------------------------------------#
+
+
+def _load_eagle2(device: torch.device, torch_dtype: torch.dtype):
+ """
+ Load nvidia/Eagle2-2B model and processor, ensuring the language model uses SDPA.
+
+ Returns
+ -------
+ tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding]
+ The model, its processor and the language-model input embedding layer.
+ """
+ model_id = "nvidia/Eagle2-2B"
+ try:
+ with torch.no_grad():
+ model = (
+ AutoModel.from_pretrained(
+ model_id,
+ trust_remote_code=True,
+ torch_dtype=torch_dtype,
+ # attn_implementation="sdpa" is ignored due to the model's remote code.
+ )
+ .eval()
+ .to(device)
+ )
+ except ImportError as e:
+ if "flash_attn" in str(e):
+ raise ImportError(
+ "FlashAttention2 is required for Eagle2 models but not installed. "
+ "Please install it using: pip install flash-attn --no-build-isolation -v"
+ ) from e
+ raise
+
+ processor = AutoProcessor.from_pretrained(
+ model_id, trust_remote_code=True, use_fast=True
+ )
+ if hasattr(processor, "tokenizer"):
+ processor.tokenizer.padding_side = "left"
+
+ emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device)
+ return model, processor, emb_layer
+
+
+def _load_qwen2_5_vl(device, torch_dtype: torch.dtype):
+ """
+ Load Qwen2.5-VL model and processor.
+ """
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+ model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_id, torch_dtype=torch_dtype, device_map=device
+ ).eval()
+ processor = AutoProcessor.from_pretrained(model_id)
+ emb_layer = model.model.get_input_embeddings().to(torch_dtype).to(device)
+ return model, processor, emb_layer
+
+
+def load_model(
+ model_name: str, device: torch.device, torch_dtype: torch.dtype
+) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]:
+ """Dispatch helper for supported VLMs."""
+ if model_name == "nvidia/Eagle2-2B":
+ return _load_eagle2(device, torch_dtype)
+ elif model_name == "Qwen/Qwen2.5-VL-3B-Instruct":
+ return _load_qwen2_5_vl(device, torch_dtype)
+ msg = f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']"
+ raise ValueError(msg)
+
+
+# -----------------------------------------------------------------------------#
+# Input loading helpers
+# -----------------------------------------------------------------------------#
+
+
+def load_inputs(args: argparse.Namespace, processor, device: torch.device):
+ """
+ Loads and constructs the input dictionary for the specified VLM model.
+ """
+ # Load image from local path if provided, otherwise use default URL
+ if args.image_path is not None:
+ # Use local image file
+ image = Image.open(args.image_path)
+ else:
+ # Use default URL image
+ url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ if args.benchmark:
+ model_constants = MODEL_CONSTANTS[args.model]
+ image_tokens = model_constants["IMAGE_TOKENS"]
+ wrapper_tokens = model_constants["PROMPT_WRAPPER_TOKENS"]
+
+ prompt_len = args.isl - image_tokens - wrapper_tokens
+ prompt_txt = " ".join(["token"] * max(prompt_len, 0))
+ else:
+ prompt_txt = args.prompt or "Describe this image."
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": image},
+ {"type": "text", "text": prompt_txt},
+ ],
+ }
+ ]
+
+ text = [
+ processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ ]
+
+ # --- Model-specific vision processing ---
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ try:
+ from qwen_vl_utils import process_vision_info
+ except ImportError:
+ raise ImportError(
+ "The 'qwen-vl-utils' package is required for Qwen VLM models. "
+ "Please install it using: pip install qwen-vl-utils"
+ )
+
+ image_inputs, video_inputs = process_vision_info(messages)
+ else: # eagle2
+ image_inputs, video_inputs = processor.process_vision_info(messages)
+
+ inputs = processor(
+ text=text,
+ images=image_inputs,
+ videos=video_inputs,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
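+    # The processor output contains (at least) input_ids and pixel_values;
+    # Qwen2.5-VL additionally returns image_grid_thw, which the generation
+    # helpers require.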
+ return inputs
+
+
+# -----------------------------------------------------------------------------#
+# Torch-TensorRT compilation helpers
+# -----------------------------------------------------------------------------#
+
+
+class _LMNoCache(torch.nn.Module):
+ """
+ Thin wrapper that exposes a language model via ``inputs_embeds`` without KV-cache.
+ """
+
+ def __init__(self, lm):
+ super().__init__()
+ self.lm = lm
+
+ def forward(self, inputs_embeds, position_ids):
+ out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids)
+ return (
+ out.logits
+ if hasattr(out, "logits")
+ else out.last_hidden_state if hasattr(out, "last_hidden_state") else out
+ )
+
+
+def _compile_lm(
+ language_model: torch.nn.Module,
+ input_embeds: torch.Tensor,
+ args: argparse.Namespace,
+ device: torch.device,
+) -> torch.nn.Module:
+ """
+ Compile the language model component of a VLM with Torch-TensorRT
+ """
+ lm_wrap = _LMNoCache(language_model).to(device).eval()
+    max_seq_len = input_embeds.shape[1] + args.num_tokens
+
+    position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(device)
+
+ use_fp32_acc = False
+ use_explicit_typing = False
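+    # For FP16 we rely on explicit (strong) typing: the FP16 weights in the
+    # exported graph determine kernel precision, use_fp32_acc keeps matmul
+    # accumulation in FP32, and enabled_precisions remains {torch.float32},
+    # as required when explicit typing is enabled.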
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ else: # FP32
+ enabled_precisions = {torch.float32}
+
+    exported_program = export_llm(
+        lm_wrap, input_embeds, min_seq_len=1, max_seq_len=max_seq_len
+    )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_mod = torch_tensorrt.dynamo.compile(
+ exported_program,
+ inputs=[input_embeds, position_ids],
+ enabled_precisions=enabled_precisions,
+ use_explicit_typing=use_explicit_typing,
+ use_fp32_acc=use_fp32_acc,
+ device=device,
+ disable_tf32=args.disable_tf32,
+ use_python_runtime=args.use_python_runtime,
+ offload_module_to_cpu=args.offload_module_to_cpu,
+ min_block_size=args.min_block_size,
+ )
+ return trt_mod
+
+
+def compile_lm_torchtrt(
+ model: torch.nn.Module, args: argparse.Namespace, device: torch.device
+) -> torch.nn.Module:
+ """
+ Compiles the Language Model (LLM) component of the VLM using Torch-TensorRT.
+ """
+ torch_dtype = {
+ "FP16": torch.float16,
+ "BF16": torch.bfloat16,
+ }.get(args.precision, torch.float32)
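+    # NOTE: BF16 is kept in this mapping for parity with run_llm.py, but
+    # run_vlm.py currently only exposes FP16 and FP32 via --precision.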
+
+ lm_model = (
+ model.model
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct"
+ else model.language_model
+ )
+
+ model_constants = MODEL_CONSTANTS.get(
+ args.model, {"EXAMPLE_SEQLEN": args.num_tokens}
+ )
+ example_seq_len = model_constants["EXAMPLE_SEQLEN"]
+
+ example_embeds = torch.randn(
+ args.batch_size,
+ example_seq_len,
+ lm_model.config.hidden_size,
+ dtype=torch_dtype,
+ device=device,
+ )
+
+ # All supported models use the same compilation helper.
+ if args.model in ["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"]:
+ return _compile_lm(lm_model, example_embeds, args, device)
+ else:
+ msg = f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']"
+ raise ValueError(msg)
+
+
+def _compile_eagle2_vision(
+ vision_model: torch.nn.Module,
+ example_pixel_values: torch.Tensor,
+ args: argparse.Namespace,
+ device: torch.device,
+) -> torch.nn.Module:
+ """
+ Compile Eagle2 vision model with Torch-TensorRT.
+ """
+ # Set precision-specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ else: # FP32
+ enabled_precisions = {torch.float32}
+
+ with torch.inference_mode():
+ exported_program = torch.export.export(
+ vision_model,
+ (example_pixel_values,),
+ strict=False,
+ )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_mod = torch_tensorrt.dynamo.compile(
+ exported_program,
+ inputs=[example_pixel_values],
+ enabled_precisions=enabled_precisions,
+ use_explicit_typing=use_explicit_typing,
+ use_fp32_acc=use_fp32_acc,
+ device=device,
+ disable_tf32=args.disable_tf32,
+ use_python_runtime=args.use_python_runtime,
+ offload_module_to_cpu=args.offload_module_to_cpu,
+ min_block_size=args.min_block_size,
+ )
+ return trt_mod
+
+
+def compile_vision_torchtrt(
+ model: torch.nn.Module,
+ args: argparse.Namespace,
+ example_pixel_values: torch.Tensor,
+ device: torch.device,
+) -> torch.nn.Module:
+ """
+ Dispatcher function for vision model compilation.
+ """
+ if args.model == "nvidia/Eagle2-2B":
+ return _compile_eagle2_vision(
+ model.vision_model, example_pixel_values, args, device
+ )
+ elif args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ # TODO: Vision model compilation for Qwen2.5-VL is currently skipped.
+ # The model's `get_window_index` method uses dynamic Python list operations
+ # (e.g., .tolist(), .extend()) to process variable-sized image grids for
+ # windowed attention. These operations are incompatible with torch.export's
+ # static graph tracing, preventing successful compilation.
+ return model.visual
+ else:
+ raise ValueError(f"Unsupported model: {args.model}")
+
+
+# -----------------------------------------------------------------------------#
+# Utility helpers
+# -----------------------------------------------------------------------------#
+
+
+def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer):
+ """Pretty-print generated text for comparison."""
+ print(f"========= {backend_name} =========")
+ print(
+ f"{backend_name} model generated text: ",
+ tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
+ )
+ print("===================================")
+
+
+# -----------------------------------------------------------------------------#
+# Main driver
+# -----------------------------------------------------------------------------#
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Run VLM inference (PyTorch & TensorRT back-ends)"
+ )
+ parser.add_argument(
+ "--model",
+ default="nvidia/Eagle2-2B",
+ choices=["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"],
+ help="VLM model name",
+ )
+ parser.add_argument("--prompt", default="Describe this image.", help="Prompt text")
+ parser.add_argument(
+ "--precision",
+ default="FP16",
+ choices=["FP16", "FP32"],
+ help="Computation precision",
+ )
+ parser.add_argument("--iterations", type=int, default=5, help="# iterations")
+ parser.add_argument("--min_block_size", type=int, default=1, help="Min block size")
+ parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens")
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+ parser.add_argument("--isl", type=int, default=2048, help="Input seq length")
+ parser.add_argument(
+ "--enable_pytorch_run",
+ action="store_true",
+ help="Run the PyTorch baseline as well",
+ )
+ parser.add_argument(
+ "--cache",
+ default="",
+ choices=["", "static_v1"],
+ help="KV-cache variant to use",
+ )
+ parser.add_argument(
+ "--debug", action="store_true", help="Enable Torch-TensorRT debug logs"
+ )
+ parser.add_argument(
+ "--benchmark", action="store_true", help="Enable benchmarking mode"
+ )
+ parser.add_argument(
+ "--image_path",
+ type=str,
+ default=None,
+ help="Path to local image file. If not provided, uses default URL image.",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda:0",
+ help="Device to run inference on (e.g., 'cuda:0', 'cuda:1')",
+ )
+    parser.add_argument(
+        "--disable_tf32",
+        action="store_false",
+        default=True,
+        help="TF32 is disabled by default; pass this flag to allow TF32 during TensorRT compilation",
+    )
+    parser.add_argument(
+        "--use_python_runtime",
+        action="store_false",
+        default=True,
+        help="The Python runtime is used by default; pass this flag to use the C++ runtime instead",
+    )
+    parser.add_argument(
+        "--offload_module_to_cpu",
+        action="store_false",
+        default=True,
+        help="Offloading the module to CPU during compilation is enabled by default; pass this flag to keep it on the GPU",
+    )
+
+ args = parser.parse_args()
+
+ device = torch.device(args.device)
+ if device.type == "cuda":
+ torch.cuda.set_device(device)
+
+ # -------------------------------------------------------------------------#
+ # 1. Model / processor / embeddings
+ # -------------------------------------------------------------------------#
+ dtype = {
+ "FP16": torch.float16,
+ "BF16": torch.bfloat16,
+ }.get(args.precision, torch.float32)
+
+ model, processor, emb_layer = load_model(args.model, device, dtype)
+
+ # -------------------------------------------------------------------------#
+ # 2. Input construction (image + text prompt)
+ # -------------------------------------------------------------------------#
+ inputs = load_inputs(args, processor, device)
+
+ max_output_len = inputs["input_ids"].shape[1] + args.num_tokens
+
+ # -------------------------------------------------------------------------#
+ # 3. Optional: PyTorch baseline
+ # -------------------------------------------------------------------------#
+ pyt_gen_tokens = pyt_timings = pyt_stats = None
+ if args.enable_pytorch_run:
+ # For benchmarking, we run the generation with timing enabled.
+ # For regular runs, we run without timing for a single output.
+ if args.benchmark:
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ (
+ pyt_gen_tokens,
+ _,
+ overall_time,
+ _,
+ _,
+ ) = generate_mm_qwen2_5_vl(
+ model,
+ inputs["pixel_values"],
+ inputs["input_ids"],
+ inputs["image_grid_thw"],
+ processor.tokenizer.eos_token_id,
+ emb_layer,
+ max_new_tokens=args.num_tokens,
+ with_timing=True,
+ )
+ else: # eagle2
+ (
+ pyt_gen_tokens,
+ _,
+ overall_time,
+ _,
+ _,
+ ) = generate_mm(
+ model,
+ inputs["pixel_values"],
+ inputs["input_ids"],
+ processor.tokenizer.eos_token_id,
+ emb_layer,
+ max_new_tokens=args.num_tokens,
+ with_timing=True,
+ )
+ pyt_stats = record_stats(
+ "PyTorch",
+                [overall_time / 1000],  # overall_time is in ms; convert to s
+ args.precision,
+ batch_size=args.batch_size,
+ )
+ else:
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ pyt_gen_tokens = generate_mm_qwen2_5_vl(
+ model,
+ inputs["pixel_values"],
+ inputs["input_ids"],
+ inputs["image_grid_thw"],
+ processor.tokenizer.eos_token_id,
+ emb_layer,
+ max_new_tokens=args.num_tokens,
+ )
+ else: # eagle2
+ pyt_gen_tokens = generate_mm(
+ model,
+ inputs["pixel_values"],
+ inputs["input_ids"],
+ processor.tokenizer.eos_token_id,
+ emb_layer,
+ max_new_tokens=args.num_tokens,
+ )
+
+ # -------------------------------------------------------------------------#
+ # 4. Torch-TensorRT compile & run
+ # -------------------------------------------------------------------------#
+
+ trt_model = copy.deepcopy(model)
+ # 4.1. Vision model compilation
+ # --- Add vision model compilation --- #
+ example_pixel_values = inputs["pixel_values"]
+ trt_vision = compile_vision_torchtrt(model, args, example_pixel_values, device)
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ trt_model.visual = trt_vision
+ else:
+ trt_model.vision_model = trt_vision
+
+ # -------------------------------------------------------------------------#
+ # 4.2. Language model compilation
+ # -------------------------------------------------------------------------#
+ # Register static cache lowering passes if requested
+ # Cache is not applied to vision model.
+ if args.cache == "static_v1":
+ import static_cache_v1 # noqa: F401
+ elif args.cache not in ("", None):
+ raise ValueError(
+ f"Cache mode '{args.cache}' is not supported. Only 'static_v1' is supported."
+ )
+
+ trt_lm = compile_lm_torchtrt(model, args, device)
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ trt_model.model = trt_lm
+ else:
+ trt_model.language_model = trt_lm
+
+ emb_layer = emb_layer.to(device)
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ trt_model.lm_head = trt_model.lm_head.to(device)
+
+ if args.cache == "static_v1":
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ trt_generate = generate_mm_qwen2_5_vl_with_static_cache
+ else: # eagle2
+ trt_generate = generate_mm_with_static_cache
+ else:
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ trt_generate = generate_mm_qwen2_5_vl
+ else: # eagle2
+ trt_generate = generate_mm
+
+ # Prepare args for generate function
+ generate_args = {
+ "model": trt_model,
+ "pixel_values": inputs["pixel_values"],
+ "input_ids": inputs["input_ids"],
+ "eos_token_id": processor.tokenizer.eos_token_id,
+ "emb_layer": emb_layer,
+ "max_new_tokens": args.num_tokens,
+ }
+ if args.model == "Qwen/Qwen2.5-VL-3B-Instruct":
+ generate_args["image_grid_thw"] = inputs["image_grid_thw"]
+
+ if args.cache == "static_v1" or args.benchmark:
+ generate_args["with_timing"] = True
+
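+    # The static-cache generate helpers allocate zeroed KV tensors via
+    # get_zeroed_static_cache_inputs, so they need the target device explicitly.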
+ if args.cache == "static_v1":
+ generate_args["device"] = device
+
+ # Run TRT generation
+ trt_output = trt_generate(**generate_args)
+
+ # Unpack results
+ if args.benchmark or args.cache == "static_v1":
+ trt_gen_tokens, _, overall_time, _, _ = trt_output
+ trt_stats = record_stats(
+ "TensorRT",
+ [overall_time / 1000], # time is in ms, convert to s
+ args.precision,
+ batch_size=args.batch_size,
+ )
+ else:
+ trt_gen_tokens = trt_output
+
+ # -------------------------------------------------------------------------#
+ # 5. Reporting
+ # -------------------------------------------------------------------------#
+ if not args.benchmark:
+ if args.enable_pytorch_run:
+ print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer)
+ print_outputs("TensorRT", trt_gen_tokens, processor.tokenizer)
+
+ if args.enable_pytorch_run:
+ print(
+ f"PyTorch and TensorRT outputs match: "
+ f"{torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
+ )
+
+ if args.benchmark:
+ if args.enable_pytorch_run:
+ print("========= PyTorch PERFORMANCE =========\n")
+ print(pyt_stats)
+ print("=====================\n")
+ print("========= TensorRT PERFORMANCE =========\n")
+ print(trt_stats)
diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py
index b60396c08b..58daacedf5 100644
--- a/tools/llm/static_cache_v1.py
+++ b/tools/llm/static_cache_v1.py
@@ -201,7 +201,7 @@ def insert_kv_slicing_before_sdpa(
args=(slice_7, 3),
kwargs={},
)
- # =============================================== #
+
# Concatenate the sliced tensors to build KV cache
cat = gm.graph.create_node(
"call_function",
diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py
index 60482bf22d..1c1366bd0c 100644
--- a/tools/llm/test_qwen2.5_components.py
+++ b/tools/llm/test_qwen2.5_components.py
@@ -16,7 +16,7 @@
# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
-from register_sdpa import *
+from torchtrt_ext import register_sdpa
ATOL = 1e-5
RTOL = 1e-5
diff --git a/tools/llm/utils.py b/tools/llm/utils.py
index 2c3434b0ed..77ef26b33a 100644
--- a/tools/llm/utils.py
+++ b/tools/llm/utils.py
@@ -47,7 +47,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
return ep
-def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule):
+def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule, device: str = "cuda:0"):
"""
Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2.
@@ -71,7 +71,7 @@ def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule):
torch.zeros(
input.meta["val"].shape,
dtype=input.meta["val"].dtype,
- device=torch.device("cuda:0"),
+ device=torch.device(device),
)
)
@@ -242,3 +242,628 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None)
"Compile Time(s)": compile_time_s,
}
return stats
+
+
+def _prepare_mm_inputs(
+ model,
+ pixel_values: torch.Tensor | None,
+ input_ids: torch.Tensor,
+ emb_layer: torch.nn.Embedding,
+ with_timing: bool = False,
+):
+ """
+ Prepares multimodal inputs for Eagle2-style VLMs by encoding images and merging with text embeddings.
+ Optionally times the vision and MLP parts.
+ """
+ vision_time = 0.0
+ mlp_time = 0.0
+ vit_embeds = None
+
+ if pixel_values is not None:
+ if with_timing:
+ vision_start = torch.cuda.Event(enable_timing=True)
+ vision_end = torch.cuda.Event(enable_timing=True)
+ mlp_start = torch.cuda.Event(enable_timing=True)
+ mlp_end = torch.cuda.Event(enable_timing=True)
+
+ vision_start.record()
+ vit_out = model.vision_model(pixel_values)
+ vision_end.record()
+ torch.cuda.synchronize()
+ vision_time = vision_start.elapsed_time(vision_end)
+ else:
+ vit_out = model.vision_model(pixel_values)
+
+ vit_embeds = (
+ vit_out.last_hidden_state
+ if hasattr(vit_out, "last_hidden_state")
+ else vit_out
+ )
+
+ if with_timing:
+ mlp_start.record()
+
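+        # Pixel-shuffle rearranges the ViT patch grid (scaling each spatial dim
+        # by downsample_ratio), reducing the number of visual tokens before mlp1
+        # projects them into the language-model embedding space.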
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = model.pixel_shuffle(
+ vit_embeds, scale_factor=model.downsample_ratio
+ )
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+ vit_embeds = model.mlp1(vit_embeds)
+
+ if with_timing:
+ mlp_end.record()
+ torch.cuda.synchronize()
+ mlp_time = mlp_start.elapsed_time(mlp_end)
+
+ seq_tokens = input_ids.clone()
+ seq_embeds = emb_layer(seq_tokens)
+
+ if vit_embeds is not None:
+ B, N, C = seq_embeds.shape
+ flat_emb = seq_embeds.view(B * N, C)
+ mask = seq_tokens.view(B * N) == model.image_token_index
+ try:
+ flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()]
+ except Exception:
+ # Fallback in unlikely size-mismatch cases
+ flat_emb[mask] = vit_embeds.reshape(-1, C)[: mask.sum()].to(flat_emb.dtype)
+ seq_embeds = flat_emb.view(B, N, C)
+
+ if with_timing:
+ return seq_tokens, seq_embeds, vision_time, mlp_time
+ else:
+ return seq_tokens, seq_embeds
+
+
+def generate_mm(
+ model,
+ pixel_values: torch.Tensor | None,
+ input_ids: torch.Tensor,
+ eos_token_id: int,
+ emb_layer: torch.nn.Embedding,
+ max_new_tokens: int = 64,
+ device: str = "cuda:0",
+ with_timing: bool = False,
+):
+ """Greedy decode for Eagle2-style VLM, with optional detailed timing.
+
+ Parameters
+ ----------
+ model : nn.Module
+ Must expose vision_model, mlp1, language_model, pixel_shuffle, downsample_ratio, image_token_index.
+ pixel_values : Tensor | None
+ Input image batch (B,C,H,W) or None.
+ input_ids : LongTensor (B, N_prompt)
+ Text prompt token ids including [IMG] placeholder(s).
+ eos_token_id : int
+ Stop generation when all sequences emit EOS.
+ emb_layer : nn.Embedding
+ Embedding layer for input_ids.
+ max_new_tokens : int
+ Maximum number of new tokens to generate.
+ with_timing : bool
+ If True, returns detailed timing information.
+
+ Returns
+ -------
+ if with_timing is False:
+ torch.LongTensor: Generated token sequence (only new tokens).
+ if with_timing is True:
+ tuple: (
+ seq_tokens: Full generated token sequence,
+ step_times: List of latencies for each generation step,
+ overall_time: Total generation time,
+ vision_time: Vision encoder latency,
+ mlp_time: MLP latency
+ )
+ """
+ if with_timing:
+ overall_start = torch.cuda.Event(enable_timing=True)
+ overall_end = torch.cuda.Event(enable_timing=True)
+ lm_start = torch.cuda.Event(enable_timing=True)
+ lm_end = torch.cuda.Event(enable_timing=True)
+ overall_start.record()
+
+ # --- Input preparation ---
+ if with_timing:
+ seq_tokens, seq_embeds, vision_time, mlp_time = _prepare_mm_inputs(
+ model, pixel_values, input_ids, emb_layer, with_timing=True
+ )
+ else:
+ seq_tokens, seq_embeds = _prepare_mm_inputs(
+ model, pixel_values, input_ids, emb_layer, with_timing=False
+ )
+
+ # ───────────────────────────────── Greedy loop ───────────────────────────────────────────────────
+ step_times = []
+ generated = 0
+
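+    # Without a KV cache, every step re-runs the language model over the full
+    # embedding sequence, so per-step latency grows with the sequence length.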
+ while generated < max_new_tokens:
+ if with_timing:
+ lm_start.record()
+
+ cur_embeds = seq_embeds
+ position_ids = (
+ torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
+ )
+ with torch.no_grad():
+ logits = model.language_model(
+ inputs_embeds=cur_embeds, position_ids=position_ids
+ )
+ if hasattr(logits, "logits"):
+ logits = logits.logits
+
+ next_tok = torch.argmax(logits[:, -1, :], dim=-1)
+
+ if with_timing:
+ lm_end.record()
+ torch.cuda.synchronize()
+ step_times.append(lm_start.elapsed_time(lm_end))
+
+ seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1)
+ seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1)
+
+ generated += 1
+ if (next_tok == eos_token_id).all():
+ break
+
+ if with_timing:
+ overall_end.record()
+ torch.cuda.synchronize()
+ overall_time = overall_start.elapsed_time(overall_end)
+ return (
+ seq_tokens[:, input_ids.shape[1] :],
+ step_times,
+ overall_time,
+ vision_time,
+ mlp_time,
+ )
+ else:
+ return seq_tokens[:, input_ids.shape[1] :]
+
+
+@torch.inference_mode()
+def generate_mm_with_static_cache(
+ model,
+ pixel_values: torch.Tensor | None,
+ input_ids: torch.Tensor,
+ eos_token_id: int,
+ emb_layer: torch.nn.Embedding,
+ max_new_tokens: int = 64,
+ device: str = "cuda:0",
+ with_timing: bool = False,
+):
+ """
+ Greedy Decoder for multimodal VLM (using static KV-cache v1), with optional timing.
+ Basic structure is identical to LM version (generate_with_static_cache) but
+ * Input is `inputs_embeds`
+ * Vision tokens are sent together only in the first step
+ """
+ if with_timing:
+ overall_start = torch.cuda.Event(enable_timing=True)
+ overall_end = torch.cuda.Event(enable_timing=True)
+ lm_start = torch.cuda.Event(enable_timing=True)
+ lm_end = torch.cuda.Event(enable_timing=True)
+ overall_start.record()
+ vision_time, mlp_time = 0.0, 0.0
+
+ if with_timing:
+ seq_tokens, seq_embeds, vision_time, mlp_time = _prepare_mm_inputs(
+ model, pixel_values, input_ids, emb_layer, with_timing=True
+ )
+ else:
+ seq_tokens, seq_embeds = _prepare_mm_inputs(
+ model, pixel_values, input_ids, emb_layer, with_timing=False
+ )
+
+ # ───────────────────── KV-cache initialization ─────────────────────
+ kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device)
+ start_idx = 0
+ end_idx = seq_embeds.size(1)
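+    # start_idx/end_idx delimit the slice of the static KV cache written by the
+    # current step: the full prompt occupies [0, end_idx) on the first step, and
+    # each subsequent step appends exactly one position.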
+ generated = 0
+ max_total_len = end_idx + max_new_tokens
+ output_tokens = seq_tokens.clone()
+ step_times = []
+
+ # ───────────────────── Greedy loop ───────────────────────
+ while output_tokens.size(1) < max_total_len:
+ if with_timing:
+ lm_start.record()
+
+ cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :]
+
+ if generated == 0:
+ position_ids = (
+ torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
+ )
+ else:
+ position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to(
+ cur_embeds.device
+ )
+
+ input_signature = (
+ cur_embeds,
+ position_ids,
+ *kv_cache,
+ start_idx,
+ end_idx,
+ )
+
+ logits_and_kv = model.language_model(*input_signature)
+ logits, kv_cache = logits_and_kv[0], logits_and_kv[1:]
+
+ next_tok = logits[:, -1, :].argmax(dim=-1)
+ output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1)
+
+ next_embed = emb_layer(next_tok)[:, None, :]
+ seq_embeds = next_embed
+
+ generated += 1
+ start_idx = end_idx
+ end_idx += 1
+
+ if with_timing:
+ lm_end.record()
+ torch.cuda.synchronize()
+ step_times.append(lm_start.elapsed_time(lm_end))
+
+ if (next_tok == eos_token_id).all():
+ break
+
+ if with_timing:
+ overall_end.record()
+ torch.cuda.synchronize()
+ overall_time = overall_start.elapsed_time(overall_end)
+ return (
+ output_tokens[:, input_ids.shape[1] :],
+ step_times,
+ overall_time,
+ vision_time,
+ mlp_time,
+ )
+ else:
+ return output_tokens[:, input_ids.shape[1] :]
+
+
+def _prepare_qwen_mm_inputs(
+ model,
+ pixel_values: torch.Tensor | None,
+ input_ids: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ emb_layer: torch.nn.Embedding,
+ with_timing: bool = False,
+):
+ """
+ Prepares multimodal inputs for Qwen2.5-VL by encoding images and merging with text embeddings.
+ Optionally times the vision part.
+ """
+ vision_time = 0.0
+ image_embeds = None
+
+ if pixel_values is not None:
+ if with_timing:
+ vision_start = torch.cuda.Event(enable_timing=True)
+ vision_end = torch.cuda.Event(enable_timing=True)
+ vision_start.record()
+
+ image_embeds = model.visual(pixel_values, image_grid_thw)
+
+ if with_timing:
+ vision_end.record()
+ torch.cuda.synchronize()
+ vision_time = vision_start.elapsed_time(vision_end)
+
+ seq_tokens = input_ids.clone()
+ seq_embeds = emb_layer(seq_tokens)
+
+ if image_embeds is not None:
+ mask = seq_tokens == model.config.image_token_id
+ num_image_tokens = mask.sum().item()
+ if num_image_tokens != image_embeds.shape[0]:
+ raise ValueError(
+ f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})."
+ )
+ mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds)
+ seq_embeds = seq_embeds.masked_scatter(
+ mask_expanded, image_embeds.to(seq_embeds.dtype)
+ )
+
+ if with_timing:
+ return seq_tokens, seq_embeds, vision_time
+ else:
+ return seq_tokens, seq_embeds
+
+
+def generate_mm_qwen2_5_vl(
+ model,
+ pixel_values: torch.Tensor | None,
+ input_ids: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ eos_token_id: int,
+ emb_layer: torch.nn.Embedding,
+ max_new_tokens: int = 64,
+ with_timing: bool = False,
+):
+ """
+ Custom generation function for the Qwen2_5_VLForConditionalGeneration model, with optional timing.
+ """
+ if with_timing:
+ overall_start = torch.cuda.Event(enable_timing=True)
+ overall_end = torch.cuda.Event(enable_timing=True)
+ lm_start = torch.cuda.Event(enable_timing=True)
+ lm_end = torch.cuda.Event(enable_timing=True)
+ overall_start.record()
+
+ if with_timing:
+ seq_tokens, seq_embeds, vision_time = _prepare_qwen_mm_inputs(
+ model,
+ pixel_values,
+ input_ids,
+ image_grid_thw,
+ emb_layer,
+ with_timing=True,
+ )
+ else:
+ seq_tokens, seq_embeds = _prepare_qwen_mm_inputs(
+ model,
+ pixel_values,
+ input_ids,
+ image_grid_thw,
+ emb_layer,
+ with_timing=False,
+ )
+
+ step_times = []
+ generated = 0
+ while generated < max_new_tokens:
+ if with_timing:
+ lm_start.record()
+
+ position_ids = (
+ torch.arange(
+ 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device
+ )
+ .unsqueeze(0)
+ .expand(seq_embeds.size(0), seq_embeds.size(1))
+ )
+
+ with torch.no_grad():
+ outputs = model.model(
+ inputs_embeds=seq_embeds,
+ position_ids=position_ids,
+ )
+ hidden_states = (
+ outputs
+ if isinstance(outputs, torch.Tensor)
+ else outputs.last_hidden_state
+ )
+
+ logits = model.lm_head(hidden_states[:, -1, :])
+ next_tok = torch.argmax(logits, dim=-1)
+
+ if with_timing:
+ lm_end.record()
+ torch.cuda.synchronize()
+ step_times.append(lm_start.elapsed_time(lm_end))
+
+ seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1)
+ next_emb = emb_layer(next_tok)[:, None, :]
+ seq_embeds = torch.cat([seq_embeds, next_emb], dim=1)
+
+ generated += 1
+ if (next_tok == eos_token_id).all():
+ break
+
+ if with_timing:
+ overall_end.record()
+ torch.cuda.synchronize()
+ overall_time = overall_start.elapsed_time(overall_end)
+ return (
+ seq_tokens[:, input_ids.shape[1] :],
+ step_times,
+ overall_time,
+ vision_time,
+ 0.0,
+ )
+ else:
+ return seq_tokens[:, input_ids.shape[1] :]
+
+
+def generate_mm_qwen2_5_vl_with_static_cache(
+ model,
+ pixel_values: torch.Tensor | None,
+ input_ids: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ eos_token_id: int,
+ emb_layer: torch.nn.Embedding,
+ max_new_tokens: int = 64,
+ device: str = "cuda:0",
+ with_timing: bool = False,
+) -> torch.LongTensor:
+ """
+ Greedy Decoder for Qwen-2.5-VL using static KV-cache, with optional timing.
+ """
+ if with_timing:
+ overall_start = torch.cuda.Event(enable_timing=True)
+ overall_end = torch.cuda.Event(enable_timing=True)
+ lm_start = torch.cuda.Event(enable_timing=True)
+ lm_end = torch.cuda.Event(enable_timing=True)
+ overall_start.record()
+
+ if with_timing:
+ seq_tokens, seq_embeds, vision_time = _prepare_qwen_mm_inputs(
+ model,
+ pixel_values,
+ input_ids,
+ image_grid_thw,
+ emb_layer,
+ with_timing=True,
+ )
+ else:
+ seq_tokens, seq_embeds = _prepare_qwen_mm_inputs(
+ model,
+ pixel_values,
+ input_ids,
+ image_grid_thw,
+ emb_layer,
+ with_timing=False,
+ )
+
+ kv_cache = get_zeroed_static_cache_inputs(model.model, device=device)
+ start_idx = 0
+ end_idx = seq_embeds.size(1)
+ generated = 0
+ max_total_len = end_idx + max_new_tokens
+ output_tokens = seq_tokens.clone()
+ step_times = []
+
+ while output_tokens.size(1) < max_total_len:
+ if with_timing:
+ lm_start.record()
+
+ cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :]
+
+ if generated == 0:
+ position_ids = (
+ torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
+ )
+ else:
+ position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to(
+ cur_embeds.device
+ )
+
+ input_signature = (
+ cur_embeds,
+ position_ids,
+ *kv_cache,
+ start_idx,
+ end_idx,
+ )
+
+ outputs_and_kv = model.model(*input_signature)
+ hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:]
+
+ logits = model.lm_head(hidden_states[:, -1, :])
+ next_tok = logits.argmax(dim=-1)
+ output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1)
+
+ next_embed = emb_layer(next_tok)[:, None, :]
+ seq_embeds = next_embed
+
+ generated += 1
+ start_idx = end_idx
+ end_idx += 1
+
+ if with_timing:
+ lm_end.record()
+ torch.cuda.synchronize()
+ step_times.append(lm_start.elapsed_time(lm_end))
+
+ if (next_tok == eos_token_id).all():
+ break
+
+ if with_timing:
+ overall_end.record()
+ torch.cuda.synchronize()
+ overall_time = overall_start.elapsed_time(overall_end)
+ # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0.
+ return (
+ output_tokens[:, input_ids.shape[1] :],
+ step_times,
+ overall_time,
+ vision_time,
+ 0.0,
+ )
+ else:
+ return output_tokens[:, input_ids.shape[1] :]
+
+
+@torch.inference_mode()
+def generate_mm_qwen2_5_vl_with_timing(
+ model,
+ pixel_values: torch.Tensor | None,
+ input_ids: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ eos_token_id: int,
+ emb_layer: torch.nn.Embedding,
+ max_new_tokens: int = 64,
+):
+ """
+ Custom generation function for the Qwen2_5_VLForConditionalGeneration model with timing.
+ """
+ overall_start = torch.cuda.Event(enable_timing=True)
+ overall_end = torch.cuda.Event(enable_timing=True)
+ vision_start = torch.cuda.Event(enable_timing=True)
+ vision_end = torch.cuda.Event(enable_timing=True)
+ lm_start = torch.cuda.Event(enable_timing=True)
+ lm_end = torch.cuda.Event(enable_timing=True)
+
+ overall_start.record()
+
+ vision_time = 0.0
+ image_embeds = None
+ if pixel_values is not None:
+ vision_start.record()
+ image_embeds = model.visual(pixel_values, image_grid_thw)
+ vision_end.record()
+ torch.cuda.synchronize()
+ vision_time = vision_start.elapsed_time(vision_end)
+
+ seq_tokens = input_ids.clone()
+ seq_embeds = emb_layer(seq_tokens)
+
+ if image_embeds is not None:
+ mask = seq_tokens == model.config.image_token_id
+ num_image_tokens = mask.sum().item()
+ if num_image_tokens != image_embeds.shape[0]:
+ raise ValueError(
+ f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})."
+ )
+ mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds)
+ seq_embeds = seq_embeds.masked_scatter(
+ mask_expanded, image_embeds.to(seq_embeds.dtype)
+ )
+
+ step_times = []
+ generated = 0
+ while generated < max_new_tokens:
+ lm_start.record()
+ position_ids = (
+ torch.arange(
+ 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device
+ )
+ .unsqueeze(0)
+ .expand(seq_embeds.size(0), seq_embeds.size(1))
+ )
+
+ with torch.no_grad():
+ outputs = model.model(
+ inputs_embeds=seq_embeds,
+ position_ids=position_ids,
+ )
+ hidden_states = (
+ outputs
+ if isinstance(outputs, torch.Tensor)
+ else outputs.last_hidden_state
+ )
+
+ logits = model.lm_head(hidden_states[:, -1, :])
+ next_tok = torch.argmax(logits, dim=-1)
+
+ lm_end.record()
+ torch.cuda.synchronize()
+ step_times.append(lm_start.elapsed_time(lm_end))
+
+ seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1)
+ next_emb = emb_layer(next_tok)[:, None, :]
+ seq_embeds = torch.cat([seq_embeds, next_emb], dim=1)
+
+ generated += 1
+
+ overall_end.record()
+ torch.cuda.synchronize()
+ overall_time = overall_start.elapsed_time(overall_end)
+
+ # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0.
+ return seq_tokens, step_times, overall_time, vision_time, 0.0