diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst index f6da87b145..406cbb9140 100644 --- a/docsrc/tutorials/compile_hf_models.rst +++ b/docsrc/tutorials/compile_hf_models.rst @@ -18,6 +18,7 @@ Overview of tools/llm Directory The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface: * **run_llm.py**: Main entry point for model compilation, generating outputs, and benchmarking +* **run_vlm.py**: Entry point for compiling and benchmarking Visual Language Models (VLMs) * **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization * **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering scaled dot-product attention converter and lowering pass. * **Testing Components**: Model-specific test files for validation @@ -60,6 +61,30 @@ We have officially verified support for the following LLM families: - FP16, FP32 - Yes +Supported VLM Models +-------------------- +We have officially verified support for the following Visual Language Models (VLMs): + +.. list-table:: + :widths: 20 40 20 20 20 + :header-rows: 1 + + * - Model Series + - HuggingFace Model Card + - Precision + - KV Cache Support ? + - Component Support + * - Qwen 2.5 VL + - Qwen/Qwen2.5-VL-3B-Instruct + - FP16, FP32 + - Yes (static_v1 only) + - Language Model only (Image Encoder not supported) + * - Eagle2 + - nvidia/Eagle2-2B + - FP16, FP32 + - Yes (static_v1 only) + - Language Model and Image Encoder both supported + Getting Started with run_llm.py ------------------------------- @@ -112,6 +137,36 @@ Other Usage Examples python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark +Getting Started with run_vlm.py +------------------------------- + +For Visual Language Models (VLMs), use ``run_vlm.py`` to compile and benchmark models that process both text and images. + +Basic Usage +^^^^^^^^^^^ + +.. code-block:: bash + + python tools/llm/run_vlm.py \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --precision FP16 \ + --num_tokens 128 \ + --cache static_v1 \ + --enable_pytorch_run \ + --benchmark + +Key Arguments +^^^^^^^^^^^^^ + +* ``--model``: Name or path of the HuggingFace VLM +* ``--prompt``: Input prompt for generation +* ``--image_path``: (Optional) Path to input image file. If not provided, will use a sample image +* ``--precision``: Precision mode (``FP16``, ``FP32``) +* ``--num_tokens``: Number of output tokens to generate +* ``--cache``: KV cache type (``static_v1`` or empty for no KV caching) +* ``--benchmark``: Enable benchmarking mode +* ``--enable_pytorch_run``: Also run and compare PyTorch baseline + KV Caching in Torch-TensorRT --------------------------------- @@ -122,7 +177,7 @@ The length of KV cache = input sequence length + output sequence length (specifi Static Cache v1 ^^^^^^^^^^^^^^^^ -The ``static_cache_v1.py`` implements KV cache in the model graph as follows: +The ``static_cache_v1.py`` implements KV cache in the model graph as follows: .. code-block:: python @@ -210,9 +265,13 @@ Limitations and Known Issues * Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported * Some model architectures (e.g. Phi-4) have issues with exporting the torch model. +* For VLMs, Qwen2.5-VL image encoder compilation is not supported due to dynamic operations incompatible with torch.export. Requirements ^^^^^^^^^^^^ * Torch-TensorRT 2.8.0 or later -* Transformers v4.52.3 \ No newline at end of file +* Transformers v4.52.3 +* For VLM models (run_vlm.py): + - ``pip install qwen-vl-utils`` (for Qwen2.5-VL-3B-Instruct model) + - ``pip install flash-attn --no-build-isolation -v`` (for Eagle2-2B model) \ No newline at end of file diff --git a/tools/llm/README.md b/tools/llm/README.md index a141505517..c0a88635f6 100644 --- a/tools/llm/README.md +++ b/tools/llm/README.md @@ -1,10 +1,11 @@ # Optimizing LLMs in Torch-TensorRT -This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry point is `run_llm.py`, which demonstrates how to export, compile, and run LLMs with various caching strategies and precision modes. Note that this is an **experimental release** and APIs may change in future versions. +This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) and Visual Language Models (VLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry points are `run_llm.py` for text-only LLMs and `run_vlm.py` for vision-language models. Note that this is an **experimental release** and APIs may change in future versions. ### Key Features - **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc. +- **VLM Support:** Supports Visual Language Models like Qwen2.5-VL and Eagle2. - **Precision Modes:** Supports FP16, BF16, and FP32. - **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding. - **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends. @@ -24,20 +25,33 @@ We have officially verified support for the following models: | Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct
Qwen/Qwen2.5-1.5B-Instruct
Qwen/Qwen2.5-4B-Instruct
Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes | | Qwen 3 | Qwen/Qwen3-0.6B
Qwen/Qwen3-1.7B
Qwen/Qwen3-4B
Qwen/Qwen3-8B | FP16, FP32 | Yes | +### Supported VLM Models + +| Model Series | HF Model Card | Precision | KV Cache Supported ? | +|--------------|---------------|-----------|-------------------| +| Qwen 2.5 VL | Qwen/Qwen2.5-VL-3B-Instruct | FP16, FP32 | Yes | +| Eagle2 | nvidia/Eagle2-2B | FP16, FP32 | Yes | ### Usage -The main entry point is : `run_llm.py` +#### Text-only LLMs: `run_llm.py` ```bash python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark ``` +#### Vision Language Models: `run_vlm.py` + +```bash +python run_vlm.py --model Qwen/Qwen2.5-VL-3B-Instruct --precision FP16 --num_tokens 128 --cache static_v1 --enable_pytorch_run --benchmark +``` + #### Key Arguments -- `--model`: Name or path of the HuggingFace LLM. +- `--model`: Name or path of the HuggingFace LLM/VLM. - `--tokenizer`: (Optional) Tokenizer name; defaults to model. - `--prompt`: Input prompt for generation. +- `--image_path`: (Optional) Path to input image file for VLM models. If not provided, will use a sample image. - `--precision`: Precision mode (`FP16`, `FP32`). - `--num_tokens`: Number of output tokens to generate. - `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching). @@ -64,4 +78,7 @@ This codebase can be extended to ## Requirements - Torch-TensorRT 2.8.0 -- Transformers v4.52.3 \ No newline at end of file +- Transformers v4.52.3 +- For VLM models (run_vlm.py): + - `pip install qwen-vl-utils` (for Qwen2.5-VL-3B-Instruct model) + - `pip install flash-attn --no-build-isolation -v` (for Eagle2-2B model) \ No newline at end of file diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py new file mode 100644 index 0000000000..e26d971f32 --- /dev/null +++ b/tools/llm/run_vlm.py @@ -0,0 +1,678 @@ +""" +.. _run_vlm: + +Benchmarking VLM Inference with Torch-TensorRT +========================================================== + +This script provides a framework for benchmarking the performance of Visual-Language +Models (VLMs). It optimizes the two most computationally intensive components of a +VLM—the language model and the vision model (image feature extraction)—using +the Torch-TensorRT dynamo backend. + +Key Features: +- **Component-wise Optimization**: Compiles both the language and vision models + separately with Torch-TensorRT to accelerate inference. +- **Performance Benchmarking**: Runs the model for multiple iterations to + measure and compare inference latency against the PyTorch baseline. +- **Output Verification**: Checks for token-level consistency between the optimized + TensorRT model and the original PyTorch model to ensure correctness. +- **KV Cache Testing**: Includes options to test inference with and without + KV caching to evaluate its impact on performance. + +This tool mirrors the style and structure of `run_llm.py`, providing a clear +workflow for VLM optimization and analysis. + +Dependencies: +- For Qwen VLM models: pip install qwen-vl-utils +- For Eagle2 models: pip install flash-attn --no-build-isolation -v +""" + +import argparse +import copy +import os +import sys +from contextlib import nullcontext +from typing import Tuple + +import requests +import torch +import torch_tensorrt + +# we "monkey-patch" the global attention function map for Qwen2. +# This ensures that any part of the code (including torch.export) requesting +# "flash_attention_2" will receive the "sdpa" implementation instead. +# This patch is global for the script's execution context. +import transformers.models.qwen2.modeling_qwen2 as mq +from PIL import Image +from torchtrt_ext import register_sdpa +from transformers import AutoConfig, AutoModel, AutoProcessor +from utils import ( + export_llm, + generate_mm, + generate_mm_qwen2_5_vl, + generate_mm_qwen2_5_vl_with_static_cache, + generate_mm_with_static_cache, + record_stats, +) + +# --- WORKAROUND FOR EAGLE2 SDPA COMPILATION --- +# Eagle2's language model (Qwen2) implicitly defaults to "flash_attention_2" +# due to settings in its remote code and config.json. This prevents direct +# compilation with SDPA. To work around this without modifying the library, + + +mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] +# --- END WORKAROUND --- + +# --- Model-specific constants for benchmark and compilation --- +# Centralizing these values improves readability and maintainability. +MODEL_CONSTANTS = { + "nvidia/Eagle2-2B": { + "EXAMPLE_SEQLEN": 2560, # A fixed sequence length for creating the example tensor for TRT compilation. + "IMAGE_TOKENS": 1792, # Number of special tokens used to represent the image patch embeddings in the input sequence for Eagle2-2B VLM. + "PROMPT_WRAPPER_TOKENS": 26, # The number of special/processing tokens added by the processor's chat template in benchmark mode. + }, + "Qwen/Qwen2.5-VL-3B-Instruct": { + "EXAMPLE_SEQLEN": 2560, + "IMAGE_TOKENS": 1426, + "PROMPT_WRAPPER_TOKENS": 21, + }, +} +# --- END Model-specific constants --- + +# -----------------------------------------------------------------------------# +# Model loading helpers +# -----------------------------------------------------------------------------# + + +def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): + """ + Load nvidia/Eagle2-2B model and processor, ensuring the language model uses SDPA. + + Returns + ------- + tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding] + The model, its processor and the language-model input embedding layer. + """ + model_id = "nvidia/Eagle2-2B" + try: + with torch.no_grad(): + model = ( + AutoModel.from_pretrained( + model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + # attn_implementation="sdpa" is ignored due to the model's remote code. + ) + .eval() + .to(device) + ) + except ImportError as e: + if "flash_attn" in str(e): + raise ImportError( + "FlashAttention2 is required for Eagle2 models but not installed. " + "Please install it using: pip install flash-attn --no-build-isolation -v" + ) from e + raise + + processor = AutoProcessor.from_pretrained( + model_id, trust_remote_code=True, use_fast=True + ) + if hasattr(processor, "tokenizer"): + processor.tokenizer.padding_side = "left" + + emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + +def _load_qwen2_5_vl(device, torch_dtype: torch.dtype): + """ + Load Qwen2.5-VL model and processor. + """ + from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + model_id = "Qwen/Qwen2.5-VL-3B-Instruct" + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch_dtype, device_map=device + ).eval() + processor = AutoProcessor.from_pretrained(model_id) + emb_layer = model.model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + +def load_model( + model_name: str, device: torch.device, torch_dtype: torch.dtype +) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: + """Dispatch helper for supported VLMs.""" + if model_name == "nvidia/Eagle2-2B": + return _load_eagle2(device, torch_dtype) + elif model_name == "Qwen/Qwen2.5-VL-3B-Instruct": + return _load_qwen2_5_vl(device, torch_dtype) + msg = f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Input loading helpers +# -----------------------------------------------------------------------------# + + +def load_inputs(args: argparse.Namespace, processor, device: torch.device): + """ + Loads and constructs the input dictionary for the specified VLM model. + """ + # Load image from local path if provided, otherwise use default URL + if args.image_path is not None: + # Use local image file + image = Image.open(args.image_path) + else: + # Use default URL image + url = "https://www.ilankelman.org/stopsigns/australia.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + if args.benchmark: + model_constants = MODEL_CONSTANTS[args.model] + image_tokens = model_constants["IMAGE_TOKENS"] + wrapper_tokens = model_constants["PROMPT_WRAPPER_TOKENS"] + + prompt_len = args.isl - image_tokens - wrapper_tokens + prompt_txt = " ".join(["token"] * max(prompt_len, 0)) + else: + prompt_txt = args.prompt or "Describe this image." + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt_txt}, + ], + } + ] + + text = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + ] + + # --- Model-specific vision processing --- + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + try: + from qwen_vl_utils import process_vision_info + except ImportError: + raise ImportError( + "The 'qwen-vl-utils' package is required for Qwen VLM models. " + "Please install it using: pip install qwen-vl-utils" + ) + + image_inputs, video_inputs = process_vision_info(messages) + else: # eagle2 + image_inputs, video_inputs = processor.process_vision_info(messages) + + inputs = processor( + text=text, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ).to(device) + + return inputs + + +# -----------------------------------------------------------------------------# +# Torch-TensorRT compilation helpers +# -----------------------------------------------------------------------------# + + +class _LMNoCache(torch.nn.Module): + """ + Thin wrapper that exposes a language model via ``inputs_embeds`` without KV-cache. + """ + + def __init__(self, lm): + super().__init__() + self.lm = lm + + def forward(self, inputs_embeds, position_ids): + out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids) + return ( + out.logits + if hasattr(out, "logits") + else out.last_hidden_state if hasattr(out, "last_hidden_state") else out + ) + + +def _compile_lm( + language_model: torch.nn.Module, + input_embeds: torch.Tensor, + args: argparse.Namespace, + device: torch.device, +) -> torch.nn.Module: + """ + Compile the language model component of a VLM with Torch-TensorRT + """ + lm_wrap = _LMNoCache(language_model).to(device).eval() + max_seq_len = input_embeds.shape[1] + args.num_tokens + + seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(device) + + dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} + + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + else: # FP32 + enabled_precisions = {torch.float32} + + exported_program = export_llm( + lm_wrap, input_embeds, min_seq_len=1, max_seq_len=2560 + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[input_embeds, position_ids], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=device, + disable_tf32=args.disable_tf32, + use_python_runtime=args.use_python_runtime, + offload_module_to_cpu=args.offload_module_to_cpu, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_lm_torchtrt( + model: torch.nn.Module, args: argparse.Namespace, device: torch.device +) -> torch.nn.Module: + """ + Compiles the Language Model (LLM) component of the VLM using Torch-TensorRT. + """ + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + lm_model = ( + model.model + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct" + else model.language_model + ) + + model_constants = MODEL_CONSTANTS.get( + args.model, {"EXAMPLE_SEQLEN": args.num_tokens} + ) + example_seq_len = model_constants["EXAMPLE_SEQLEN"] + + example_embeds = torch.randn( + args.batch_size, + example_seq_len, + lm_model.config.hidden_size, + dtype=torch_dtype, + device=device, + ) + + # All supported models use the same compilation helper. + if args.model in ["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"]: + return _compile_lm(lm_model, example_embeds, args, device) + else: + msg = f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" + raise ValueError(msg) + + +def _compile_eagle2_vision( + vision_model: torch.nn.Module, + example_pixel_values: torch.Tensor, + args: argparse.Namespace, + device: torch.device, +) -> torch.nn.Module: + """ + Compile Eagle2 vision model with Torch-TensorRT. + """ + # Set precision-specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: # FP32 + enabled_precisions = {torch.float32} + + with torch.inference_mode(): + exported_program = torch.export.export( + vision_model, + (example_pixel_values,), + strict=False, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[example_pixel_values], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=device, + disable_tf32=args.disable_tf32, + use_python_runtime=args.use_python_runtime, + offload_module_to_cpu=args.offload_module_to_cpu, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_vision_torchtrt( + model: torch.nn.Module, + args: argparse.Namespace, + example_pixel_values: torch.Tensor, + device: torch.device, +) -> torch.nn.Module: + """ + Dispatcher function for vision model compilation. + """ + if args.model == "nvidia/Eagle2-2B": + return _compile_eagle2_vision( + model.vision_model, example_pixel_values, args, device + ) + elif args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + # TODO: Vision model compilation for Qwen2.5-VL is currently skipped. + # The model's `get_window_index` method uses dynamic Python list operations + # (e.g., .tolist(), .extend()) to process variable-sized image grids for + # windowed attention. These operations are incompatible with torch.export's + # static graph tracing, preventing successful compilation. + return model.visual + else: + raise ValueError(f"Unsupported model: {args.model}") + + +# -----------------------------------------------------------------------------# +# Utility helpers +# -----------------------------------------------------------------------------# + + +def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): + """Pretty-print generated text for comparison.""" + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +# -----------------------------------------------------------------------------# +# Main driver +# -----------------------------------------------------------------------------# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run VLM inference (PyTorch & TensorRT back-ends)" + ) + parser.add_argument( + "--model", + default="nvidia/Eagle2-2B", + choices=["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"], + help="VLM model name", + ) + parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") + parser.add_argument( + "--precision", + default="FP16", + choices=["FP16", "FP32"], + help="Computation precision", + ) + parser.add_argument("--iterations", type=int, default=5, help="# iterations") + parser.add_argument("--min_block_size", type=int, default=1, help="Min block size") + parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--isl", type=int, default=2048, help="Input seq length") + parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Run the PyTorch baseline as well", + ) + parser.add_argument( + "--cache", + default="", + choices=["", "static_v1"], + help="KV-cache variant to use", + ) + parser.add_argument( + "--debug", action="store_true", help="Enable Torch-TensorRT debug logs" + ) + parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmarking mode" + ) + parser.add_argument( + "--image_path", + type=str, + default=None, + help="Path to local image file. If not provided, uses default URL image.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Device to run inference on (e.g., 'cuda:0', 'cuda:1')", + ) + parser.add_argument( + "--disable_tf32", + action="store_false", + default=True, + help="Disable TF32 precision for TensorRT compilation (default: True)", + ) + parser.add_argument( + "--use_python_runtime", + action="store_false", + default=True, + help="Use Python runtime for TensorRT compilation (default: True)", + ) + parser.add_argument( + "--offload_module_to_cpu", + action="store_false", + default=True, + help="Offload module to CPU for TensorRT compilation (default: True)", + ) + + args = parser.parse_args() + + device = torch.device(args.device) + if device.type == "cuda": + torch.cuda.set_device(device) + + # -------------------------------------------------------------------------# + # 1. Model / processor / embeddings + # -------------------------------------------------------------------------# + dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + model, processor, emb_layer = load_model(args.model, device, dtype) + + # -------------------------------------------------------------------------# + # 2. Input construction (image + text prompt) + # -------------------------------------------------------------------------# + inputs = load_inputs(args, processor, device) + + max_output_len = inputs["input_ids"].shape[1] + args.num_tokens + + # -------------------------------------------------------------------------# + # 3. Optional: PyTorch baseline + # -------------------------------------------------------------------------# + pyt_gen_tokens = pyt_timings = pyt_stats = None + if args.enable_pytorch_run: + # For benchmarking, we run the generation with timing enabled. + # For regular runs, we run without timing for a single output. + if args.benchmark: + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + ( + pyt_gen_tokens, + _, + overall_time, + _, + _, + ) = generate_mm_qwen2_5_vl( + model, + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], + processor.tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + with_timing=True, + ) + else: # eagle2 + ( + pyt_gen_tokens, + _, + overall_time, + _, + _, + ) = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + processor.tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + with_timing=True, + ) + pyt_stats = record_stats( + "PyTorch", + [overall_time / 1000], # time_generate returns seconds + args.precision, + batch_size=args.batch_size, + ) + else: + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + pyt_gen_tokens = generate_mm_qwen2_5_vl( + model, + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], + processor.tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + ) + else: # eagle2 + pyt_gen_tokens = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + processor.tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + ) + + # -------------------------------------------------------------------------# + # 4. Torch-TensorRT compile & run + # -------------------------------------------------------------------------# + + trt_model = copy.deepcopy(model) + # 4.1. Vision model compilation + # --- Add vision model compilation --- # + example_pixel_values = inputs["pixel_values"] + trt_vision = compile_vision_torchtrt(model, args, example_pixel_values, device) + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + trt_model.visual = trt_vision + else: + trt_model.vision_model = trt_vision + + # -------------------------------------------------------------------------# + # 4.2. Language model compilation + # -------------------------------------------------------------------------# + # Register static cache lowering passes if requested + # Cache is not applied to vision model. + if args.cache == "static_v1": + import static_cache_v1 # noqa: F401 + elif args.cache not in ("", None): + raise ValueError( + f"Cache mode '{args.cache}' is not supported. Only 'static_v1' is supported." + ) + + trt_lm = compile_lm_torchtrt(model, args, device) + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + trt_model.model = trt_lm + else: + trt_model.language_model = trt_lm + + emb_layer = emb_layer.to(device) + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + trt_model.lm_head = trt_model.lm_head.to(device) + + if args.cache == "static_v1": + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + trt_generate = generate_mm_qwen2_5_vl_with_static_cache + else: # eagle2 + trt_generate = generate_mm_with_static_cache + else: + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + trt_generate = generate_mm_qwen2_5_vl + else: # eagle2 + trt_generate = generate_mm + + # Prepare args for generate function + generate_args = { + "model": trt_model, + "pixel_values": inputs["pixel_values"], + "input_ids": inputs["input_ids"], + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + "max_new_tokens": args.num_tokens, + } + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + generate_args["image_grid_thw"] = inputs["image_grid_thw"] + + if args.cache == "static_v1" or args.benchmark: + generate_args["with_timing"] = True + + if args.cache == "static_v1": + generate_args["device"] = device + + # Run TRT generation + trt_output = trt_generate(**generate_args) + + # Unpack results + if args.benchmark or args.cache == "static_v1": + trt_gen_tokens, _, overall_time, _, _ = trt_output + trt_stats = record_stats( + "TensorRT", + [overall_time / 1000], # time is in ms, convert to s + args.precision, + batch_size=args.batch_size, + ) + else: + trt_gen_tokens = trt_output + + # -------------------------------------------------------------------------# + # 5. Reporting + # -------------------------------------------------------------------------# + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + print_outputs("TensorRT", trt_gen_tokens, processor.tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: " + f"{torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("========= PyTorch PERFORMANCE =========\n") + print(pyt_stats) + print("=====================\n") + print("========= TensorRT PERFORMANCE =========\n") + print(trt_stats) diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py index b60396c08b..58daacedf5 100644 --- a/tools/llm/static_cache_v1.py +++ b/tools/llm/static_cache_v1.py @@ -201,7 +201,7 @@ def insert_kv_slicing_before_sdpa( args=(slice_7, 3), kwargs={}, ) - # =============================================== # + # Concatenate the sliced tensors to build KV cache cat = gm.graph.create_node( "call_function", diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py index 60482bf22d..1c1366bd0c 100644 --- a/tools/llm/test_qwen2.5_components.py +++ b/tools/llm/test_qwen2.5_components.py @@ -16,7 +16,7 @@ # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from register_sdpa import * +from torchtrt_ext import register_sdpa ATOL = 1e-5 RTOL = 1e-5 diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..77ef26b33a 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -47,7 +47,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): return ep -def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): +def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule, device: str = "cuda:0"): """ Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2. @@ -71,7 +71,7 @@ def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): torch.zeros( input.meta["val"].shape, dtype=input.meta["val"].dtype, - device=torch.device("cuda:0"), + device=torch.device(device), ) ) @@ -242,3 +242,628 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None) "Compile Time(s)": compile_time_s, } return stats + + +def _prepare_mm_inputs( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + emb_layer: torch.nn.Embedding, + with_timing: bool = False, +): + """ + Prepares multimodal inputs for Eagle2-style VLMs by encoding images and merging with text embeddings. + Optionally times the vision and MLP parts. + """ + vision_time = 0.0 + mlp_time = 0.0 + vit_embeds = None + + if pixel_values is not None: + if with_timing: + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + + vision_start.record() + vit_out = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + else: + vit_out = model.vision_model(pixel_values) + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + if with_timing: + mlp_start.record() + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + + if with_timing: + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + try: + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + except Exception: + # Fallback in unlikely size-mismatch cases + flat_emb[mask] = vit_embeds.reshape(-1, C)[: mask.sum()].to(flat_emb.dtype) + seq_embeds = flat_emb.view(B, N, C) + + if with_timing: + return seq_tokens, seq_embeds, vision_time, mlp_time + else: + return seq_tokens, seq_embeds + + +def generate_mm( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + device: str = "cuda:0", + with_timing: bool = False, +): + """Greedy decode for Eagle2-style VLM, with optional detailed timing. + + Parameters + ---------- + model : nn.Module + Must expose vision_model, mlp1, language_model, pixel_shuffle, downsample_ratio, image_token_index. + pixel_values : Tensor | None + Input image batch (B,C,H,W) or None. + input_ids : LongTensor (B, N_prompt) + Text prompt token ids including [IMG] placeholder(s). + eos_token_id : int + Stop generation when all sequences emit EOS. + emb_layer : nn.Embedding + Embedding layer for input_ids. + max_new_tokens : int + Maximum number of new tokens to generate. + with_timing : bool + If True, returns detailed timing information. + + Returns + ------- + if with_timing is False: + torch.LongTensor: Generated token sequence (only new tokens). + if with_timing is True: + tuple: ( + seq_tokens: Full generated token sequence, + step_times: List of latencies for each generation step, + overall_time: Total generation time, + vision_time: Vision encoder latency, + mlp_time: MLP latency + ) + """ + if with_timing: + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + overall_start.record() + + # --- Input preparation --- + if with_timing: + seq_tokens, seq_embeds, vision_time, mlp_time = _prepare_mm_inputs( + model, pixel_values, input_ids, emb_layer, with_timing=True + ) + else: + seq_tokens, seq_embeds = _prepare_mm_inputs( + model, pixel_values, input_ids, emb_layer, with_timing=False + ) + + # ───────────────────────────────── Greedy loop ─────────────────────────────────────────────────── + step_times = [] + generated = 0 + + while generated < max_new_tokens: + if with_timing: + lm_start.record() + + cur_embeds = seq_embeds + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model( + inputs_embeds=cur_embeds, position_ids=position_ids + ) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) + + if with_timing: + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + if with_timing: + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + return ( + seq_tokens[:, input_ids.shape[1] :], + step_times, + overall_time, + vision_time, + mlp_time, + ) + else: + return seq_tokens[:, input_ids.shape[1] :] + + +@torch.inference_mode() +def generate_mm_with_static_cache( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + device: str = "cuda:0", + with_timing: bool = False, +): + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1), with optional timing. + Basic structure is identical to LM version (generate_with_static_cache) but + * Input is `inputs_embeds` + * Vision tokens are sent together only in the first step + """ + if with_timing: + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + overall_start.record() + vision_time, mlp_time = 0.0, 0.0 + + if with_timing: + seq_tokens, seq_embeds, vision_time, mlp_time = _prepare_mm_inputs( + model, pixel_values, input_ids, emb_layer, with_timing=True + ) + else: + seq_tokens, seq_embeds = _prepare_mm_inputs( + model, pixel_values, input_ids, emb_layer, with_timing=False + ) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device) + start_idx = 0 + end_idx = seq_embeds.size(1) + generated = 0 + max_total_len = end_idx + max_new_tokens + output_tokens = seq_tokens.clone() + step_times = [] + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + if with_timing: + lm_start.record() + + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed + + generated += 1 + start_idx = end_idx + end_idx += 1 + + if with_timing: + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + if (next_tok == eos_token_id).all(): + break + + if with_timing: + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + return ( + output_tokens[:, input_ids.shape[1] :], + step_times, + overall_time, + vision_time, + mlp_time, + ) + else: + return output_tokens[:, input_ids.shape[1] :] + + +def _prepare_qwen_mm_inputs( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + emb_layer: torch.nn.Embedding, + with_timing: bool = False, +): + """ + Prepares multimodal inputs for Qwen2.5-VL by encoding images and merging with text embeddings. + Optionally times the vision part. + """ + vision_time = 0.0 + image_embeds = None + + if pixel_values is not None: + if with_timing: + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + vision_start.record() + + image_embeds = model.visual(pixel_values, image_grid_thw) + + if with_timing: + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + if with_timing: + return seq_tokens, seq_embeds, vision_time + else: + return seq_tokens, seq_embeds + + +def generate_mm_qwen2_5_vl( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + with_timing: bool = False, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model, with optional timing. + """ + if with_timing: + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + overall_start.record() + + if with_timing: + seq_tokens, seq_embeds, vision_time = _prepare_qwen_mm_inputs( + model, + pixel_values, + input_ids, + image_grid_thw, + emb_layer, + with_timing=True, + ) + else: + seq_tokens, seq_embeds = _prepare_qwen_mm_inputs( + model, + pixel_values, + input_ids, + image_grid_thw, + emb_layer, + with_timing=False, + ) + + step_times = [] + generated = 0 + while generated < max_new_tokens: + if with_timing: + lm_start.record() + + position_ids = ( + torch.arange( + 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device + ) + .unsqueeze(0) + .expand(seq_embeds.size(0), seq_embeds.size(1)) + ) + + with torch.no_grad(): + outputs = model.model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + ) + hidden_states = ( + outputs + if isinstance(outputs, torch.Tensor) + else outputs.last_hidden_state + ) + + logits = model.lm_head(hidden_states[:, -1, :]) + next_tok = torch.argmax(logits, dim=-1) + + if with_timing: + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + next_emb = emb_layer(next_tok)[:, None, :] + seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + if with_timing: + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + return ( + seq_tokens[:, input_ids.shape[1] :], + step_times, + overall_time, + vision_time, + 0.0, + ) + else: + return seq_tokens[:, input_ids.shape[1] :] + + +def generate_mm_qwen2_5_vl_with_static_cache( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + device: str = "cuda:0", + with_timing: bool = False, +) -> torch.LongTensor: + """ + Greedy Decoder for Qwen-2.5-VL using static KV-cache, with optional timing. + """ + if with_timing: + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + overall_start.record() + + if with_timing: + seq_tokens, seq_embeds, vision_time = _prepare_qwen_mm_inputs( + model, + pixel_values, + input_ids, + image_grid_thw, + emb_layer, + with_timing=True, + ) + else: + seq_tokens, seq_embeds = _prepare_qwen_mm_inputs( + model, + pixel_values, + input_ids, + image_grid_thw, + emb_layer, + with_timing=False, + ) + + kv_cache = get_zeroed_static_cache_inputs(model.model, device=device) + start_idx = 0 + end_idx = seq_embeds.size(1) + generated = 0 + max_total_len = end_idx + max_new_tokens + output_tokens = seq_tokens.clone() + step_times = [] + + while output_tokens.size(1) < max_total_len: + if with_timing: + lm_start.record() + + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + ) + + outputs_and_kv = model.model(*input_signature) + hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:] + + logits = model.lm_head(hidden_states[:, -1, :]) + next_tok = logits.argmax(dim=-1) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed + + generated += 1 + start_idx = end_idx + end_idx += 1 + + if with_timing: + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + if (next_tok == eos_token_id).all(): + break + + if with_timing: + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0. + return ( + output_tokens[:, input_ids.shape[1] :], + step_times, + overall_time, + vision_time, + 0.0, + ) + else: + return output_tokens[:, input_ids.shape[1] :] + + +@torch.inference_mode() +def generate_mm_qwen2_5_vl_with_timing( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model with timing. + """ + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + vision_time = 0.0 + image_embeds = None + if pixel_values is not None: + vision_start.record() + image_embeds = model.visual(pixel_values, image_grid_thw) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + step_times = [] + generated = 0 + while generated < max_new_tokens: + lm_start.record() + position_ids = ( + torch.arange( + 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device + ) + .unsqueeze(0) + .expand(seq_embeds.size(0), seq_embeds.size(1)) + ) + + with torch.no_grad(): + outputs = model.model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + ) + hidden_states = ( + outputs + if isinstance(outputs, torch.Tensor) + else outputs.last_hidden_state + ) + + logits = model.lm_head(hidden_states[:, -1, :]) + next_tok = torch.argmax(logits, dim=-1) + + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + next_emb = emb_layer(next_tok)[:, None, :] + seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) + + generated += 1 + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0. + return seq_tokens, step_times, overall_time, vision_time, 0.0