From e4e09bb82964960b4bda44b684d7f0f5e6e933b1 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 2 Jul 2025 14:59:45 +0000 Subject: [PATCH 1/9] integrated vlm code for benchmark --- tools/llm/run_vlm.py | 387 ++++++++++++++++++++++++++++++++++++ tools/llm/utils.py | 461 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 848 insertions(+) create mode 100644 tools/llm/run_vlm.py diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py new file mode 100644 index 0000000000..f6bd62624f --- /dev/null +++ b/tools/llm/run_vlm.py @@ -0,0 +1,387 @@ +""" +.. _run_vlm: + +Running VLM inference with Torch-TensorRT +========================================================== + +This script mirrors the style and structure of *run_llm.py*, illustrating a +Torch-TensorRT (dynamo backend) workflow for Visual-Language Models (VLMs). +""" + +import argparse +import copy +import os +import sys +from contextlib import nullcontext +from typing import Tuple + +import requests +import torch +import torch_tensorrt +from PIL import Image +from torchtrt_ext import register_sdpa +from transformers import AutoModel, AutoProcessor +from utils import ( + generate_mm, + generate_mm_with_static_cache, + record_stats, + time_generate_mm, +) + +# -----------------------------------------------------------------------------# +# Global configuration +# -----------------------------------------------------------------------------# +DEVICE = torch.device("cuda:0") + +# Register SDPA as a standalone operator. Converter & lowering pass are defined +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +import transformers.models.qwen2.modeling_qwen2 as mq # noqa: E402 + +mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] + +# -----------------------------------------------------------------------------# +# Model loading helpers +# -----------------------------------------------------------------------------# + + +def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): + """ + Load Eagle2 model and processor. + + Returns + ------- + tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding] + The model, its processor and the language-model input embedding layer. + """ + model_id = "nvidia/Eagle2-2B" + with torch.no_grad(): + model = ( + AutoModel.from_pretrained( + model_id, trust_remote_code=True, torch_dtype=torch_dtype + ) + .eval() + .to(device) + ) + + processor = AutoProcessor.from_pretrained( + model_id, trust_remote_code=True, use_fast=True + ) + if hasattr(processor, "tokenizer"): + processor.tokenizer.padding_side = "left" + + emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + +def _load_model( + model_name: str, device: torch.device, torch_dtype: torch.dtype +) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: + """Dispatch helper for supported VLMs.""" + if model_name.lower() == "eagle2": + return _load_eagle2(device, torch_dtype) + msg = f"Unsupported model: {model_name}" + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Torch-TensorRT compilation helpers +# -----------------------------------------------------------------------------# + + +class _LMNoCache(torch.nn.Module): + """ + Thin wrapper that exposes a language model via ``inputs_embeds`` without KV-cache. 
+ """ + + def __init__(self, lm): + super().__init__() + self.lm = lm + + def forward(self, inputs_embeds, position_ids): + out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids) + return out.logits if hasattr(out, "logits") else out + + +def _compile_eagle2_lm( + language_model: torch.nn.Module, + input_embeds: torch.Tensor, + args: argparse.Namespace, +) -> torch.nn.Module: + """ + Compile Eagle2 language model with Torch-TensorRT. + + The function follows the same precision-specific flag logic used in + *run_llm.py* for consistency. + """ + lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() + max_seq_len = input_embeds.shape[1] + args.num_tokens + + S = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + dyn_shapes = {"inputs_embeds": {1: S}, "position_ids": {1: S}} + + # Precision-specific flags --------------------------------------------------# + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: # FP32 + enabled_precisions = {torch.float32} + + with torch.inference_mode(): + exported = torch.export.export( + lm_wrap, + (input_embeds, position_ids), + dynamic_shapes=dyn_shapes, + strict=False, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported, + inputs=[input_embeds, position_ids], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_torchtrt( + model: torch.nn.Module, args: argparse.Namespace +) -> torch.nn.Module: + """ + Front-end dispatcher mirroring *run_llm.py*’s `compile_torchtrt`. + + Depending on the target VLM, delegates to the appropriate compile routine. 
+ """ + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + example_embeds = torch.randn( + 1, + 2560, + model.language_model.config.hidden_size, + dtype=torch_dtype, + device=DEVICE, + ) + + if args.model.lower() == "eagle2": + return _compile_eagle2_lm(model.language_model, example_embeds, args) + + msg = f"Unsupported model for compilation: {args.model}" + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Utility helpers +# -----------------------------------------------------------------------------# + + +def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): + """Pretty-print generated text for comparison.""" + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +# -----------------------------------------------------------------------------# +# Main driver +# -----------------------------------------------------------------------------# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run VLM inference (PyTorch & TensorRT back-ends)" + ) + parser.add_argument("--model", default="eagle2", help="VLM model name") + parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") + parser.add_argument( + "--precision", + default="FP16", + choices=["FP16", "BF16", "FP32"], + help="Computation precision", + ) + parser.add_argument("--iterations", type=int, default=5, help="# iterations") + parser.add_argument("--min_block_size", type=int, default=1, help="Min block size") + parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--isl", type=int, default=2048, help="Input seq length") + parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Run the PyTorch baseline as well", + ) + parser.add_argument( + "--cache", + default="", + choices=["", "static_v1"], + help="KV-cache variant to use", + ) + parser.add_argument( + "--debug", action="store_true", help="Enable Torch-TensorRT debug logs" + ) + parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmarking mode" + ) + + args = parser.parse_args() + + # -------------------------------------------------------------------------# + # 1. Model / processor / embeddings + # -------------------------------------------------------------------------# + dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + model, processor, emb_layer = _load_model(args.model, DEVICE, dtype) + + # -------------------------------------------------------------------------# + # 2. 
Input construction (image + text prompt) + # -------------------------------------------------------------------------# + url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + if args.benchmark: + prompt_len = args.isl - 1792 - 26 + prompt_txt = " ".join(["token"] * max(prompt_len, 0)) + else: + prompt_txt = args.prompt + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt_txt}, + ], + } + ] + + txt = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + ] + img_in, vid_in = processor.process_vision_info(messages) + inputs = processor( + text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True + ).to(DEVICE) + + max_output_len = inputs["input_ids"].shape[1] + args.num_tokens + + # -------------------------------------------------------------------------# + # 3. Optional: PyTorch baseline + # -------------------------------------------------------------------------# + pyt_gen_tokens = pyt_timings = pyt_stats = None + if args.enable_pytorch_run: + pyt_gen_tokens = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + ) + if args.benchmark: + pyt_timings = time_generate_mm( + generate_mm, + model, + inputs["pixel_values"].clone(), + inputs["input_ids"].clone(), + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + iterations=args.iterations, + ) + pyt_stats = record_stats( + "PyTorch", + pyt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + # Register static cache lowering passes if requested + if args.cache == "static_v1": + import static_cache_v1 # noqa: F401 + + # -------------------------------------------------------------------------# + # 4. Torch-TensorRT compile & run + # -------------------------------------------------------------------------# + trt_lm = compile_torchtrt(model, args) + trt_model = copy.deepcopy(model) + trt_model.language_model = trt_lm + + emb_layer = emb_layer.to(DEVICE) + + if args.cache == "static_v1": + trt_generate = generate_mm_with_static_cache + else: + trt_generate = generate_mm + + trt_gen_tokens = trt_generate( + trt_model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + DEVICE if args.cache == "static_v1" else None, # device arg only for static_v1 + ) + + if args.benchmark: + trt_timings = time_generate_mm( + trt_generate, + trt_model, + inputs["pixel_values"].clone(), + inputs["input_ids"].clone(), + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + iterations=args.iterations, + device=DEVICE if args.cache == "static_v1" else None, + ) + trt_stats = record_stats( + "TensorRT", + trt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + # -------------------------------------------------------------------------# + # 5. 
Reporting + # -------------------------------------------------------------------------# + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + print_outputs("TensorRT", trt_gen_tokens, processor.tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: " + f"{torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("========= PyTorch PERFORMANCE =========\n") + print(pyt_stats) + print("=====================\n") + print("========= TensorRT PERFORMANCE =========\n") + print(trt_stats) diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..5e188f0e8b 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -242,3 +242,464 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None) "Compile Time(s)": compile_time_s, } return stats + + +def generate_mm( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + """Greedy decode for Eagle2-style VLM. + + Parameters + ---------- + model : nn.Module + Must expose vision_model, mlp1, language_model, pixel_shuffle, downsample_ratio, image_token_index. + pixel_values : Tensor | None + Input image batch (B,C,H,W) or None. + input_ids : LongTensor (B, N_prompt) + Text prompt token ids including [IMG] placeholder(s). + max_output_seq_length : int + Maximum tokens to generate **in addition to** the prompt. + eos_token_id : int + Stop generation when all sequences emit EOS. + emb_layer : nn.Embedding + Embedding layer for input_ids. + """ + + vit_embeds = None + + if pixel_values is not None: + # --- Vision encoder timing --- + vis_s = torch.cuda.Event(enable_timing=True) + vis_e = torch.cuda.Event(enable_timing=True) + vis_s.record() + vit_out = model.vision_model(pixel_values) + vis_e.record() + torch.cuda.synchronize() + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + + # 2) Text token embeddings + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + + mask = seq_tokens.view(B * N) == model.image_token_index + try: + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + except Exception: + # Fallback in unlikely size-mismatch cases + flat_emb[mask] = vit_embeds.reshape(-1, C)[: mask.sum()].to(flat_emb.dtype) + seq_embeds = flat_emb.view(B, N, C) + + # ───────────────────────────────── Greedy loop ─────────────────────────────────────────────────── + isl = seq_tokens.shape[1] + osl = max_output_seq_length - isl + + generated = 0 + + while generated < osl: + cur_embeds = seq_embeds # full seq first step or cache off + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model(inputs_embeds=cur_embeds, position_ids=position_ids) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) # (B,) + # append token 
& embed + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + return seq_tokens[:, input_ids.shape[1] :] + + +@torch.inference_mode() +def generate_mm_with_static_cache( + model, # Complete VLM module + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, # (B, N_prompt) + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: # (B, N_prompt + new) + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1). + Basic structure is identical to LM version (generate_with_static_cache) but + * Input is `inputs_embeds` + * Vision tokens are sent together only in the first step + """ + + # ───────────────────── Vision encoding ───────────────────── + vit_embeds = None + if pixel_values is not None: + vit_latent = model.vision_model(pixel_values) + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) + vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) + vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) + vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) + + # ───────────────────── Text embedding & [IMG] replacement ───────────── + seq_tokens = input_ids.clone() # (B, N_txt) + seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs( + model.language_model + ) + start_idx = 0 # First token index + end_idx = seq_embeds.size(1) # Prompt length + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + + # When using static cache: + # - First step: Use full prompt embedding + # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + # position_ids: Same pattern as generate_with_static_cache + # - First step: Position of entire sequence + # - Subsequent steps: Position of current token only + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + # is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + # is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + # Prepare for next step - Static cache only needs new token + next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) + seq_embeds = next_embed # Next step uses only new token + + generated += 1 + start_idx = end_idx + end_idx += 1 
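+        # start_idx/end_idx now mark the one-slot window where the next token's
+        # key/value will be written into the static cache.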
+ # is_causal = True # Causal mask active from now on + + if (next_tok == eos_token_id).all(): + break + + return output_tokens + + +def generate_mm_with_timing( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + use_cache: bool = False, +): + # Create timing events + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + vit_embeds = None + if pixel_values is not None: + vision_start.record() + vit_out = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + seq_embeds = flat_emb.view(B, N, C) + + step_times = [] + generated = 0 + past_key_values = None + + while generated < max_output_seq_length: + lm_start.record() + cur_embeds = seq_embeds + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model( + inputs_embeds=cur_embeds, position_ids=position_ids + ) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + return seq_tokens, step_times, overall_time, vision_time, mlp_time + + +@torch.inference_mode() +def generate_mm_with_static_cache_timing( + model, # Complete VLM module + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, # (B, N_prompt) + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + device: str = "cuda:0", +) -> tuple: # (seq_tokens, step_times, overall_time, vision_time, mlp_time) + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1) + detailed timing measurement. 
+ + Returns: + seq_tokens: Generated token sequence + step_times: Language model inference time for each step (ms) + overall_time: Total execution time (ms) + vision_time: Vision encoding time (ms) + mlp_time: MLP processing time (ms) + """ + + # ───────────────────── Create timing events ───────────────────── + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + # ───────────────────── Vision encoding ───────────────────── + vit_embeds = None + vision_time = 0.0 + mlp_time = 0.0 + + if pixel_values is not None: + vision_start.record() + vit_latent = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) + vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) + vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) + vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + # ───────────────────── Text embedding & [IMG] replacement ───────────── + seq_tokens = input_ids.clone() # (B, N_txt) + seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs( + model.language_model + ) + start_idx = 0 # First token index + end_idx = seq_embeds.size(1) # Prompt length + generated = 0 + max_total_len = end_idx + max_new_tokens + output_tokens = seq_tokens.clone() + step_times = [] # Timing for each step + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + lm_start.record() + + # When using static cache: + # - First step: Use full prompt embedding + # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + # position_ids: Same pattern as generate_with_static_cache + # - First step: Position of entire sequence + # - Subsequent steps: Position of current token only + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + # is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + # is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + output_tokens = torch.cat([output_tokens, 
next_tok[:, None]], dim=-1) + + # Prepare for next step - Static cache only needs new token + next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) + seq_embeds = next_embed # Next step uses only new token + + generated += 1 + start_idx = end_idx + end_idx += 1 + + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + if (next_tok == eos_token_id).all(): + break + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + return output_tokens, step_times, overall_time, vision_time, mlp_time + + +def time_generate_mm( + generate_fn, + model, + pixel_values, + input_ids, + output_seq_length, + eos_token_id, + emb_layer, + iterations=10, + device="cuda:0", +): + """ + Measure the time for generating a sentence over certain number of iterations + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + _ = generate_fn( + model, pixel_values, input_ids, output_seq_length, eos_token_id, emb_layer + ) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings From 9980c4cc741852a51d13ea698281d14125a2c98b Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 17 Jul 2025 16:16:22 +0000 Subject: [PATCH 2/9] add vision_model compile --- tools/llm/run_vlm.py | 78 +++++++++++++++++++++++++++++++++++++++++--- tools/llm/utils.py | 17 +++++----- 2 files changed, 82 insertions(+), 13 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index f6bd62624f..bae102aef8 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -185,6 +185,64 @@ def compile_torchtrt( raise ValueError(msg) +def _compile_eagle2_vision( + vision_model: torch.nn.Module, + example_pixel_values: torch.Tensor, + args: argparse.Namespace, +) -> torch.nn.Module: + """ + Compile Eagle2 vision model with Torch-TensorRT. + """ + # Set precision-specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: # FP32 + enabled_precisions = {torch.float32} + + with torch.inference_mode(): + exported = torch.export.export( + vision_model, + (example_pixel_values,), + strict=False, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported, + inputs=[example_pixel_values], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_vision_torchtrt( + model: torch.nn.Module, + args: argparse.Namespace, + example_pixel_values: torch.Tensor, +) -> torch.nn.Module: + """ + Dispatcher function for vision model compilation. 
+ """ + if args.model.lower() == "eagle2": + return _compile_eagle2_vision(model.vision_model, example_pixel_values, args) + else: + raise ValueError(f"Unsupported model: {args.model}") + + # -----------------------------------------------------------------------------# # Utility helpers # -----------------------------------------------------------------------------# @@ -297,6 +355,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): processor.tokenizer.eos_token_id, emb_layer, ) + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) if args.benchmark: pyt_timings = time_generate_mm( generate_mm, @@ -316,15 +375,26 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): compile_time_s=None, ) + # -------------------------------------------------------------------------# + # 4. Torch-TensorRT compile & run + # -------------------------------------------------------------------------# + + trt_model = copy.deepcopy(model) + # 4.1. Vision model compilation + # --- Add vision model compilation --- # + example_pixel_values = inputs["pixel_values"] + trt_vision = compile_vision_torchtrt(model, args, example_pixel_values) + trt_model.vision_model = trt_vision + + # -------------------------------------------------------------------------# + # 4.2. Language model compilation + # -------------------------------------------------------------------------# # Register static cache lowering passes if requested + # Cache is not applied to vision model. if args.cache == "static_v1": import static_cache_v1 # noqa: F401 - # -------------------------------------------------------------------------# - # 4. Torch-TensorRT compile & run - # -------------------------------------------------------------------------# trt_lm = compile_torchtrt(model, args) - trt_model = copy.deepcopy(model) trt_model.language_model = trt_lm emb_layer = emb_layer.to(DEVICE) diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 5e188f0e8b..a5b8662f27 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -251,6 +251,7 @@ def generate_mm( max_output_seq_length: int, eos_token_id: int, emb_layer: torch.nn.Embedding, + device: str = "cuda:0", ): """Greedy decode for Eagle2-style VLM. 
@@ -320,10 +321,12 @@ def generate_mm( while generated < osl: cur_embeds = seq_embeds # full seq first step or cache off position_ids = ( - torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) - ) + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) with torch.no_grad(): - logits = model.language_model(inputs_embeds=cur_embeds, position_ids=position_ids) + logits = model.language_model( + inputs_embeds=cur_embeds, position_ids=position_ids + ) if hasattr(logits, "logits"): logits = logits.logits @@ -383,9 +386,7 @@ def generate_mm_with_static_cache( seq_embeds = flat.view(B, N, C) # ───────────────────── KV-cache initialization ───────────────────── - kv_cache = get_zeroed_static_cache_inputs( - model.language_model - ) + kv_cache = get_zeroed_static_cache_inputs(model.language_model) start_idx = 0 # First token index end_idx = seq_embeds.size(1) # Prompt length generated = 0 @@ -609,9 +610,7 @@ def generate_mm_with_static_cache_timing( seq_embeds = flat.view(B, N, C) # ───────────────────── KV-cache initialization ───────────────────── - kv_cache = get_zeroed_static_cache_inputs( - model.language_model - ) + kv_cache = get_zeroed_static_cache_inputs(model.language_model) start_idx = 0 # First token index end_idx = seq_embeds.size(1) # Prompt length generated = 0 From e5e63e58dc8138798d9a4b29db359c8b327b5b4c Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 21 Jul 2025 16:20:49 +0000 Subject: [PATCH 3/9] Improve clarity of naming and comments --- tools/llm/run_vlm.py | 183 ++++++++++++++++++++++++++++++------------- 1 file changed, 127 insertions(+), 56 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index bae102aef8..0bafd9ecc9 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -1,11 +1,26 @@ """ .. _run_vlm: -Running VLM inference with Torch-TensorRT +Benchmarking VLM Inference with Torch-TensorRT ========================================================== -This script mirrors the style and structure of *run_llm.py*, illustrating a -Torch-TensorRT (dynamo backend) workflow for Visual-Language Models (VLMs). +This script provides a framework for benchmarking the performance of Visual-Language +Models (VLMs). It optimizes the two most computationally intensive components of a +VLM—the language model and the vision model (image feature extraction)—using +the Torch-TensorRT dynamo backend. + +Key Features: +- **Component-wise Optimization**: Compiles both the language and vision models + separately with Torch-TensorRT to accelerate inference. +- **Performance Benchmarking**: Runs the model for multiple iterations to + measure and compare inference latency against the PyTorch baseline. +- **Output Verification**: Checks for token-level consistency between the optimized + TensorRT model and the original PyTorch model to ensure correctness. +- **KV Cache Testing**: Includes options to test inference with and without + KV caching to evaluate its impact on performance. + +This tool mirrors the style and structure of `run_llm.py`, providing a clear +workflow for VLM optimization and analysis. 
""" import argparse @@ -20,7 +35,7 @@ import torch_tensorrt from PIL import Image from torchtrt_ext import register_sdpa -from transformers import AutoModel, AutoProcessor +from transformers import AutoConfig, AutoModel, AutoProcessor from utils import ( generate_mm, generate_mm_with_static_cache, @@ -33,11 +48,30 @@ # -----------------------------------------------------------------------------# DEVICE = torch.device("cuda:0") -# Register SDPA as a standalone operator. Converter & lowering pass are defined -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -import transformers.models.qwen2.modeling_qwen2 as mq # noqa: E402 +# --- WORKAROUND FOR EAGLE2 SDPA COMPILATION --- +# Eagle2's language model (Qwen2) implicitly defaults to "flash_attention_2" +# due to settings in its remote code and config.json. This prevents direct +# compilation with SDPA. To work around this without modifying the library, + +# we "monkey-patch" the global attention function map for Qwen2. +# This ensures that any part of the code (including torch.export) requesting +# "flash_attention_2" will receive the "sdpa" implementation instead. +# This patch is global for the script's execution context. +import transformers.models.qwen2.modeling_qwen2 as mq mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] +# --- END WORKAROUND --- + +# --- Model-specific constants for benchmark and compilation --- +# Centralizing these values improves readability and maintainability. +MODEL_CONSTANTS = { + "nvidia/Eagle2-2B": { + "EXAMPLE_SEQLEN": 2560, # A fixed sequence length for creating the example tensor for TRT compilation. + "IMAGE_TOKENS": 1792, # Number of special tokens used to represent the image patch embeddings in the input sequence for Eagle2-2B VLM. + "PROMPT_WRAPPER_TOKENS": 26, # The number of special/processing tokens added by the processor's chat template in benchmark mode. + } +} +# --- END Model-specific constants --- # -----------------------------------------------------------------------------# # Model loading helpers @@ -46,7 +80,7 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): """ - Load Eagle2 model and processor. + Load nvidia/Eagle2-2B model and processor, ensuring the language model uses SDPA. Returns ------- @@ -57,7 +91,10 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): with torch.no_grad(): model = ( AutoModel.from_pretrained( - model_id, trust_remote_code=True, torch_dtype=torch_dtype + model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + # attn_implementation="sdpa" is ignored due to the model's remote code. ) .eval() .to(device) @@ -73,13 +110,68 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): return model, processor, emb_layer -def _load_model( +def load_model( model_name: str, device: torch.device, torch_dtype: torch.dtype ) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: """Dispatch helper for supported VLMs.""" - if model_name.lower() == "eagle2": + if model_name == "nvidia/Eagle2-2B": return _load_eagle2(device, torch_dtype) - msg = f"Unsupported model: {model_name}" + msg = ( + f"Unsupported model: '{model_name}'. 
Supported models are: ['nvidia/Eagle2-2B']" + ) + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Input loading helpers +# -----------------------------------------------------------------------------# + + +def _load_inputs_eagle2(args: argparse.Namespace, processor, device: torch.device): + """ + Loads the input dictionary for the Eagle2 model. + """ + url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + if args.benchmark: + model_constants = MODEL_CONSTANTS[args.model] + image_tokens = model_constants["IMAGE_TOKENS"] + wrapper_tokens = model_constants["PROMPT_WRAPPER_TOKENS"] + + prompt_len = args.isl - image_tokens - wrapper_tokens + prompt_txt = " ".join(["token"] * max(prompt_len, 0)) + else: + prompt_txt = args.prompt + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt_txt}, + ], + } + ] + + txt = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + ] + img_in, vid_in = processor.process_vision_info(messages) + inputs = processor( + text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True + ).to(device) + return inputs + + +def load_inputs(args: argparse.Namespace, processor, device: torch.device): + """Dispatch helper for input loading for supported VLMs.""" + if args.model == "nvidia/Eagle2-2B": + return _load_inputs_eagle2(args, processor, device) + + msg = f"Unsupported model for input loading: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B']" raise ValueError(msg) @@ -116,9 +208,9 @@ def _compile_eagle2_lm( lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() max_seq_len = input_embeds.shape[1] + args.num_tokens - S = torch.export.Dim("seq", min=1, max=max_seq_len) + seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) - dyn_shapes = {"inputs_embeds": {1: S}, "position_ids": {1: S}} + dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} # Precision-specific flags --------------------------------------------------# use_fp32_acc = False @@ -133,7 +225,7 @@ def _compile_eagle2_lm( enabled_precisions = {torch.float32} with torch.inference_mode(): - exported = torch.export.export( + exported_program = torch.export.export( lm_wrap, (input_embeds, position_ids), dynamic_shapes=dyn_shapes, @@ -142,7 +234,7 @@ def _compile_eagle2_lm( with torch_tensorrt.logging.debug() if args.debug else nullcontext(): trt_mod = torch_tensorrt.dynamo.compile( - exported, + exported_program, inputs=[input_embeds, position_ids], enabled_precisions=enabled_precisions, use_explicit_typing=use_explicit_typing, @@ -157,31 +249,37 @@ def _compile_eagle2_lm( return trt_mod -def compile_torchtrt( +def compile_lm_torchtrt( model: torch.nn.Module, args: argparse.Namespace ) -> torch.nn.Module: """ - Front-end dispatcher mirroring *run_llm.py*’s `compile_torchtrt`. + Compiles the Language Model (LLM) component of the VLM using Torch-TensorRT. - Depending on the target VLM, delegates to the appropriate compile routine. + This function acts as a dispatcher, delegating to the appropriate routine + (e.g., `_compile_eagle2_lm`) based on the target model. 
""" torch_dtype = { "FP16": torch.float16, "BF16": torch.bfloat16, }.get(args.precision, torch.float32) + model_constants = MODEL_CONSTANTS[args.model] + example_seq_len = model_constants["EXAMPLE_SEQLEN"] + example_embeds = torch.randn( 1, - 2560, + example_seq_len, model.language_model.config.hidden_size, dtype=torch_dtype, device=DEVICE, ) - if args.model.lower() == "eagle2": + if args.model == "nvidia/Eagle2-2B": return _compile_eagle2_lm(model.language_model, example_embeds, args) - msg = f"Unsupported model for compilation: {args.model}" + msg = ( + f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B']" + ) raise ValueError(msg) @@ -206,7 +304,7 @@ def _compile_eagle2_vision( enabled_precisions = {torch.float32} with torch.inference_mode(): - exported = torch.export.export( + exported_program = torch.export.export( vision_model, (example_pixel_values,), strict=False, @@ -214,7 +312,7 @@ def _compile_eagle2_vision( with torch_tensorrt.logging.debug() if args.debug else nullcontext(): trt_mod = torch_tensorrt.dynamo.compile( - exported, + exported_program, inputs=[example_pixel_values], enabled_precisions=enabled_precisions, use_explicit_typing=use_explicit_typing, @@ -237,7 +335,7 @@ def compile_vision_torchtrt( """ Dispatcher function for vision model compilation. """ - if args.model.lower() == "eagle2": + if args.model == "nvidia/Eagle2-2B": return _compile_eagle2_vision(model.vision_model, example_pixel_values, args) else: raise ValueError(f"Unsupported model: {args.model}") @@ -265,7 +363,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): parser = argparse.ArgumentParser( description="Run VLM inference (PyTorch & TensorRT back-ends)" ) - parser.add_argument("--model", default="eagle2", help="VLM model name") + parser.add_argument("--model", default="nvidia/Eagle2-2B", help="VLM model name") parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") parser.add_argument( "--precision", @@ -306,39 +404,12 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): "BF16": torch.bfloat16, }.get(args.precision, torch.float32) - model, processor, emb_layer = _load_model(args.model, DEVICE, dtype) + model, processor, emb_layer = load_model(args.model, DEVICE, dtype) # -------------------------------------------------------------------------# # 2. 
Input construction (image + text prompt) # -------------------------------------------------------------------------# - url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" - image = Image.open(requests.get(url, stream=True).raw) - - if args.benchmark: - prompt_len = args.isl - 1792 - 26 - prompt_txt = " ".join(["token"] * max(prompt_len, 0)) - else: - prompt_txt = args.prompt - - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image": image}, - {"type": "text", "text": prompt_txt}, - ], - } - ] - - txt = [ - processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - ] - img_in, vid_in = processor.process_vision_info(messages) - inputs = processor( - text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True - ).to(DEVICE) + inputs = load_inputs(args, processor, DEVICE) max_output_len = inputs["input_ids"].shape[1] + args.num_tokens @@ -394,7 +465,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): if args.cache == "static_v1": import static_cache_v1 # noqa: F401 - trt_lm = compile_torchtrt(model, args) + trt_lm = compile_lm_torchtrt(model, args) trt_model.language_model = trt_lm emb_layer = emb_layer.to(DEVICE) From 5d98dc47581c01c42fb1418f4d632f37ac1a05e8 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 24 Jul 2025 13:12:32 +0000 Subject: [PATCH 4/9] support qwen2.5_vl with cache --- tools/llm/run_vlm.py | 257 +++++++++++++------ tools/llm/static_cache_v1.py | 16 +- tools/llm/test_qwen2.5_components.py | 2 +- tools/llm/utils.py | 354 ++++++++++++++++++++++++++- 4 files changed, 537 insertions(+), 92 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index 0bafd9ecc9..02ffec7875 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -38,6 +38,8 @@ from transformers import AutoConfig, AutoModel, AutoProcessor from utils import ( generate_mm, + generate_mm_qwen2_5_vl, + generate_mm_qwen2_5_vl_with_static_cache, generate_mm_with_static_cache, record_stats, time_generate_mm, @@ -69,7 +71,12 @@ "EXAMPLE_SEQLEN": 2560, # A fixed sequence length for creating the example tensor for TRT compilation. "IMAGE_TOKENS": 1792, # Number of special tokens used to represent the image patch embeddings in the input sequence for Eagle2-2B VLM. "PROMPT_WRAPPER_TOKENS": 26, # The number of special/processing tokens added by the processor's chat template in benchmark mode. - } + }, + "Qwen/Qwen2.5-VL-3B-Instruct": { + "EXAMPLE_SEQLEN": 2560, + "IMAGE_TOKENS": 391, + "PROMPT_WRAPPER_TOKENS": 21, + }, } # --- END Model-specific constants --- @@ -110,15 +117,30 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): return model, processor, emb_layer +def _load_qwen2_5_vl(device, torch_dtype: torch.dtype): + """ + Load Qwen2.5-VL model and processor. 
+ """ + from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + model_id = "Qwen/Qwen2.5-VL-3B-Instruct" + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch_dtype, device_map=device + ).eval() + processor = AutoProcessor.from_pretrained(model_id) + emb_layer = model.model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + def load_model( model_name: str, device: torch.device, torch_dtype: torch.dtype ) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: """Dispatch helper for supported VLMs.""" if model_name == "nvidia/Eagle2-2B": return _load_eagle2(device, torch_dtype) - msg = ( - f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B']" - ) + elif model_name == "Qwen/Qwen2.5-VL-3B-Instruct": + return _load_qwen2_5_vl(device, torch_dtype) + msg = f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" raise ValueError(msg) @@ -127,11 +149,11 @@ def load_model( # -----------------------------------------------------------------------------# -def _load_inputs_eagle2(args: argparse.Namespace, processor, device: torch.device): +def load_inputs(args: argparse.Namespace, processor, device: torch.device): """ - Loads the input dictionary for the Eagle2 model. + Loads and constructs the input dictionary for the specified VLM model. """ - url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" image = Image.open(requests.get(url, stream=True).raw) if args.benchmark: @@ -142,7 +164,7 @@ def _load_inputs_eagle2(args: argparse.Namespace, processor, device: torch.devic prompt_len = args.isl - image_tokens - wrapper_tokens prompt_txt = " ".join(["token"] * max(prompt_len, 0)) else: - prompt_txt = args.prompt + prompt_txt = args.prompt or "Describe this image." messages = [ { @@ -154,25 +176,28 @@ def _load_inputs_eagle2(args: argparse.Namespace, processor, device: torch.devic } ] - txt = [ + text = [ processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) ] - img_in, vid_in = processor.process_vision_info(messages) - inputs = processor( - text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True - ).to(device) - return inputs + # --- Model-specific vision processing --- + if "qwen" in args.model.lower(): + from qwen_vl_utils import process_vision_info -def load_inputs(args: argparse.Namespace, processor, device: torch.device): - """Dispatch helper for input loading for supported VLMs.""" - if args.model == "nvidia/Eagle2-2B": - return _load_inputs_eagle2(args, processor, device) + image_inputs, video_inputs = process_vision_info(messages) + else: # eagle2 + image_inputs, video_inputs = processor.process_vision_info(messages) - msg = f"Unsupported model for input loading: '{args.model}'. 
Supported models are: ['nvidia/Eagle2-2B']" - raise ValueError(msg) + inputs = processor( + text=text, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ).to(device) + return inputs # -----------------------------------------------------------------------------# @@ -191,28 +216,39 @@ def __init__(self, lm): def forward(self, inputs_embeds, position_ids): out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids) - return out.logits if hasattr(out, "logits") else out + return ( + out.logits + if hasattr(out, "logits") + else out.last_hidden_state if hasattr(out, "last_hidden_state") else out + ) -def _compile_eagle2_lm( +def _compile_lm( language_model: torch.nn.Module, input_embeds: torch.Tensor, args: argparse.Namespace, ) -> torch.nn.Module: """ - Compile Eagle2 language model with Torch-TensorRT. - - The function follows the same precision-specific flag logic used in - *run_llm.py* for consistency. + Compile the language model component of a VLM with Torch-TensorRT """ lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() max_seq_len = input_embeds.shape[1] + args.num_tokens - seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) - position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + # --- Model-specific dynamic shape definition --- + if "qwen" in args.model.lower(): + _seq = torch.export.Dim("_seq", min=1, max=512) + seq_len = 8 * _seq + position_ids = ( + torch.arange(input_embeds.shape[1], device=DEVICE, dtype=torch.long) + .unsqueeze(0) + .expand(input_embeds.size(0), input_embeds.size(1)) + ) + else: # eagle2 + seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} - # Precision-specific flags --------------------------------------------------# use_fp32_acc = False use_explicit_typing = False if args.precision == "FP16": @@ -254,33 +290,33 @@ def compile_lm_torchtrt( ) -> torch.nn.Module: """ Compiles the Language Model (LLM) component of the VLM using Torch-TensorRT. - - This function acts as a dispatcher, delegating to the appropriate routine - (e.g., `_compile_eagle2_lm`) based on the target model. """ torch_dtype = { "FP16": torch.float16, "BF16": torch.bfloat16, }.get(args.precision, torch.float32) - model_constants = MODEL_CONSTANTS[args.model] + lm_model = model.model if "qwen" in args.model.lower() else model.language_model + + model_constants = MODEL_CONSTANTS.get( + args.model, {"EXAMPLE_SEQLEN": args.num_tokens} + ) example_seq_len = model_constants["EXAMPLE_SEQLEN"] example_embeds = torch.randn( - 1, + args.batch_size, example_seq_len, - model.language_model.config.hidden_size, + lm_model.config.hidden_size, dtype=torch_dtype, device=DEVICE, ) - if args.model == "nvidia/Eagle2-2B": - return _compile_eagle2_lm(model.language_model, example_embeds, args) - - msg = ( - f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B']" - ) - raise ValueError(msg) + # All supported models use the same compilation helper. + if args.model in ["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"]: + return _compile_lm(lm_model, example_embeds, args) + else: + msg = f"Unsupported model: '{args.model}'. 
Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" + raise ValueError(msg) def _compile_eagle2_vision( @@ -337,6 +373,13 @@ def compile_vision_torchtrt( """ if args.model == "nvidia/Eagle2-2B": return _compile_eagle2_vision(model.vision_model, example_pixel_values, args) + elif args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + # TODO: Vision model compilation for Qwen2.5-VL is currently skipped. + # The model's `get_window_index` method uses dynamic Python list operations + # (e.g., .tolist(), .extend()) to process variable-sized image grids for + # windowed attention. These operations are incompatible with torch.export's + # static graph tracing, preventing successful compilation. + return model.visual else: raise ValueError(f"Unsupported model: {args.model}") @@ -363,7 +406,12 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): parser = argparse.ArgumentParser( description="Run VLM inference (PyTorch & TensorRT back-ends)" ) - parser.add_argument("--model", default="nvidia/Eagle2-2B", help="VLM model name") + parser.add_argument( + "--model", + default="Qwen/Qwen2.5-VL-3B-Instruct", + choices=["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"], + help="VLM model name", + ) parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") parser.add_argument( "--precision", @@ -418,25 +466,46 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): # -------------------------------------------------------------------------# pyt_gen_tokens = pyt_timings = pyt_stats = None if args.enable_pytorch_run: - pyt_gen_tokens = generate_mm( - model, - inputs["pixel_values"], - inputs["input_ids"], - max_output_len, - processor.tokenizer.eos_token_id, - emb_layer, - ) - print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) - if args.benchmark: - pyt_timings = time_generate_mm( - generate_mm, + if "qwen" in args.model.lower(): + pyt_gen_tokens = generate_mm_qwen2_5_vl( model, - inputs["pixel_values"].clone(), - inputs["input_ids"].clone(), + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], max_output_len, processor.tokenizer.eos_token_id, emb_layer, - iterations=args.iterations, + ) + else: # eagle2 + pyt_gen_tokens = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + ) + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + if args.benchmark: + # Prepare args for the timing function + time_generate_args = { + "model": model, + "pixel_values": inputs["pixel_values"].clone(), + "input_ids": inputs["input_ids"].clone(), + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + + # Select the correct generation function and add model-specific args + if "qwen" in args.model.lower(): + generate_fn_for_timing = generate_mm_qwen2_5_vl + time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] + else: # eagle2 + generate_fn_for_timing = generate_mm + + pyt_timings = time_generate_mm( + generate_fn_for_timing, iterations=args.iterations, **time_generate_args ) pyt_stats = record_stats( "PyTorch", @@ -455,7 +524,10 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): # --- Add vision model compilation --- # example_pixel_values = inputs["pixel_values"] trt_vision = compile_vision_torchtrt(model, args, example_pixel_values) - trt_model.vision_model = trt_vision + if "qwen" in args.model.lower(): + 
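+        # Qwen2.5-VL exposes its vision tower as `model.visual`, while Eagle2
+        # uses `model.vision_model`, so attach the compiled module accordingly.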
trt_model.visual = trt_vision + else: + trt_model.vision_model = trt_vision # -------------------------------------------------------------------------# # 4.2. Language model compilation @@ -466,36 +538,63 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): import static_cache_v1 # noqa: F401 trt_lm = compile_lm_torchtrt(model, args) - trt_model.language_model = trt_lm + if "qwen" in args.model.lower(): + trt_model.model = trt_lm + else: + trt_model.language_model = trt_lm emb_layer = emb_layer.to(DEVICE) + if "qwen" in args.model.lower(): + trt_model.lm_head = trt_model.lm_head.to(DEVICE) if args.cache == "static_v1": - trt_generate = generate_mm_with_static_cache + if "qwen" in args.model.lower(): + trt_generate = generate_mm_qwen2_5_vl_with_static_cache + else: # eagle2 + trt_generate = generate_mm_with_static_cache else: - trt_generate = generate_mm - - trt_gen_tokens = trt_generate( - trt_model, - inputs["pixel_values"], - inputs["input_ids"], - max_output_len, - processor.tokenizer.eos_token_id, - emb_layer, - DEVICE if args.cache == "static_v1" else None, # device arg only for static_v1 - ) + if "qwen" in args.model.lower(): + trt_generate = generate_mm_qwen2_5_vl + else: # eagle2 + trt_generate = generate_mm + + # Prepare args for generate function + generate_args = { + "model": trt_model, + "pixel_values": inputs["pixel_values"], + "input_ids": inputs["input_ids"], + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + if "qwen" in args.model.lower(): + generate_args["image_grid_thw"] = inputs["image_grid_thw"] + if args.cache == "static_v1": + generate_args["device"] = DEVICE + + trt_gen_tokens = trt_generate(**generate_args) if args.benchmark: + # Prepare args for the timing function + time_generate_args = { + "model": trt_model, + "pixel_values": inputs["pixel_values"].clone(), + "input_ids": inputs["input_ids"].clone(), + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + + # Add model-specific args + if "qwen" in args.model.lower(): + time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] + if args.cache == "static_v1": + time_generate_args["device"] = DEVICE + trt_timings = time_generate_mm( trt_generate, - trt_model, - inputs["pixel_values"].clone(), - inputs["input_ids"].clone(), - max_output_len, - processor.tokenizer.eos_token_id, - emb_layer, iterations=args.iterations, - device=DEVICE if args.cache == "static_v1" else None, + **time_generate_args, ) trt_stats = record_stats( "TensorRT", diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py index b60396c08b..161d02fe14 100644 --- a/tools/llm/static_cache_v1.py +++ b/tools/llm/static_cache_v1.py @@ -202,11 +202,25 @@ def insert_kv_slicing_before_sdpa( kwargs={}, ) # =============================================== # + # This prevents the cache tensor from growing when padded inputs are used. 
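+                # Only end_idx - start_idx new tokens are valid here; slice the incoming K/V to that window before concatenation.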
+ update_window_size = gm.graph.create_node( + "call_function", + torch.ops.aten.sub.Tensor, + args=(end_idx_input, start_idx_input), + kwargs={}, + ) + sliced_new_kv = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(current_key_or_value_node, 2, 0, update_window_size), + kwargs={}, + ) + # Concatenate the sliced tensors to build KV cache cat = gm.graph.create_node( "call_function", torch.ops.aten.cat.default, - args=([slice_4, current_key_or_value_node, slice_8], 2), + args=([slice_4, sliced_new_kv, slice_8], 2), kwargs={}, ) # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py index 60482bf22d..1c1366bd0c 100644 --- a/tools/llm/test_qwen2.5_components.py +++ b/tools/llm/test_qwen2.5_components.py @@ -16,7 +16,7 @@ # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from register_sdpa import * +from torchtrt_ext import register_sdpa ATOL = 1e-5 RTOL = 1e-5 diff --git a/tools/llm/utils.py b/tools/llm/utils.py index a5b8662f27..ee8c852d6f 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -679,26 +679,358 @@ def generate_mm_with_static_cache_timing( def time_generate_mm( generate_fn, - model, - pixel_values, - input_ids, - output_seq_length, - eos_token_id, - emb_layer, iterations=10, - device="cuda:0", + **kwargs, ): """ - Measure the time for generating a sentence over certain number of iterations + Measure the time for generating a sentence over certain number of iterations. + Accepts generation function arguments via kwargs. """ timings = [] for _ in range(iterations): start_time = timeit.default_timer() - _ = generate_fn( - model, pixel_values, input_ids, output_seq_length, eos_token_id, emb_layer - ) + _ = generate_fn(**kwargs) torch.cuda.synchronize() end_time = timeit.default_timer() timings.append(end_time - start_time) return timings + + +def generate_mm_qwen2_5_vl( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model. + Performs greedy decoding without caching, using inputs_embeds instead of input_ids. + """ + # 1. Calculate image embeddings (if pixel_values are provided) + image_embeds = None + if pixel_values is not None: + image_embeds = model.visual(pixel_values, image_grid_thw) + + # 2. Create initial sequence embeddings + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + # 3. Insert image embeddings at image token positions + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + # 5. Greedy generation loop + generated = 0 + while generated < max_output_seq_length: + # 5.1. 
Calculate position_ids + position_ids = ( + torch.arange( + 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device + ) + .unsqueeze(0) + .expand(seq_embeds.size(0), seq_embeds.size(1)) + ) + + # 5.2. Call the language model + with torch.no_grad(): + outputs = model.model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + ) + hidden_states = ( + outputs + if isinstance(outputs, torch.Tensor) + else outputs.last_hidden_state + ) + + # 5.3. Calculate logits for the last token + logits = model.lm_head(hidden_states[:, -1, :]) + + # 5.4. Select the next token (greedy decoding) + next_tok = torch.argmax(logits, dim=-1) + + # 5.5. Append token and embedding to the sequence + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + next_emb = emb_layer(next_tok)[:, None, :] + seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + # 6. Return generated tokens (only the part after the prompt) + return seq_tokens[:, input_ids.shape[1] :] + + +def generate_mm_qwen2_5_vl_with_static_cache( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: + """ + Greedy Decoder for Qwen-2.5-VL using static KV-cache. + Identical to `generate_mm_with_static_cache` but adapted for Qwen-2.5-VL's + specific architecture (e.g., separate visual encoder call, lm_head). + """ + # 1. Vision encoding + image_embeds = None + if pixel_values is not None: + image_embeds = model.visual(pixel_values, image_grid_thw) + + # 2. Text embedding & image token replacement + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match " + f"number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + # 3. KV-cache initialization + kv_cache = get_zeroed_static_cache_inputs(model.model) + start_idx = 0 + end_idx = seq_embeds.size(1) + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + # 4. Greedy loop + while output_tokens.size(1) < max_total_len: + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + # For the prefill step, the relevant logit is the very last one. + logit_pos = -1 + else: + # --- RUNTIME PADDING FIX for KV Cache Decode --- + # The compiled TensorRT engine has a minimum sequence length requirement (e.g., 16), + # as determined by its optimization profile. The decode step uses a sequence length + # of 1, which violates this profile. + # To resolve this, we manually pad the input tensors to the minimum length (16) + # at runtime before feeding them to the engine. 
+ pad_len = 15 # Pad from 1 to 16 (1 + 15) + + # Pad cur_embeds tensor + padding_tensor_embeds = torch.zeros( + cur_embeds.size(0), + pad_len, + cur_embeds.size(2), + dtype=cur_embeds.dtype, + device=cur_embeds.device, + ) + cur_embeds = torch.cat([cur_embeds, padding_tensor_embeds], dim=1) + + # Pad position_ids tensor + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + padding_tensor_ids = torch.zeros( + position_ids.size(0), + pad_len, + dtype=position_ids.dtype, + device=position_ids.device, + ) + position_ids = torch.cat([position_ids, padding_tensor_ids], dim=1) + + # Since we padded the sequence, the logit for our actual token is now at position 0. + logit_pos = 0 + + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + ) + + outputs_and_kv = model.model(*input_signature) + # With the fix in static_cache_v1.py, the model output is now clean: + # (hidden_state, updated_kv_cache[72]) + hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:] + + # Use logit_pos to get the correct logit based on whether we padded or not. + logits = model.lm_head(hidden_states[:, logit_pos, :]) + + next_tok = logits.argmax(dim=-1) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed + + generated += 1 + start_idx = end_idx + end_idx += 1 + + if (next_tok == eos_token_id).all(): + break + + return output_tokens + + +def generate_mm_paligemma( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + vit_embeds = None + if pixel_values is not None: + vit_out = model.vision_tower(pixel_values) + vit_embeds = model.multi_modal_projector(vit_out.last_hidden_state) + vit_embeds = vit_embeds / (model.config.text_config.hidden_size**0.5) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.config.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + B = seq_tokens.size(0) + cache_position = torch.arange(seq_tokens.size(1), device=seq_tokens.device) + position_ids = cache_position.unsqueeze(0) + 1 + + generated = 0 + while generated < max_output_seq_length: + causal_mask = model.model._update_causal_mask( + attention_mask=None, + token_type_ids=None, + past_key_values=None, + cache_position=cache_position, + input_tensor=seq_embeds, + is_training=False, + ) + + with torch.no_grad(): + out = model.language_model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + attention_mask=causal_mask, + use_cache=False, + ) + logits = out.last_hidden_state if hasattr(out, "last_hidden_state") else out + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + position_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=1) + cache_position = torch.arange(seq_tokens.size(1), device=seq_tokens.device) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + return seq_tokens + + +@torch.inference_mode() +def generate_mm_paligemma_with_static_cache( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + 
eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: + vit_embeds = None + if pixel_values is not None: + vit_latent = model.vision_tower(pixel_values) + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + vit_embeds = model.multi_modal_projector(vit_embeds) + vit_embeds = vit_embeds / (model.config.text_config.hidden_size**0.5) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device) + start_idx = 0 + end_idx = seq_embeds.size(1) + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + while output_tokens.size(1) < max_total_len: + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed + + generated += 1 + start_idx = end_idx + end_idx += 1 + is_causal = True + + if (next_tok == eos_token_id).all(): + break + + return output_tokens From cfe1b237291001653669857782fa06a3497d7bef Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 28 Jul 2025 15:37:42 +0000 Subject: [PATCH 5/9] fix: align ISL/OSL with arguments and remove padding in language model --- tools/llm/run_vlm.py | 30 +++----- tools/llm/static_cache_v1.py | 16 +--- tools/llm/utils.py | 140 ++++++++++++++++++++++++----------- 3 files changed, 107 insertions(+), 79 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index 02ffec7875..be91ac6efe 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -37,6 +37,7 @@ from torchtrt_ext import register_sdpa from transformers import AutoConfig, AutoModel, AutoProcessor from utils import ( + export_llm, generate_mm, generate_mm_qwen2_5_vl, generate_mm_qwen2_5_vl_with_static_cache, @@ -74,7 +75,7 @@ }, "Qwen/Qwen2.5-VL-3B-Instruct": { "EXAMPLE_SEQLEN": 2560, - "IMAGE_TOKENS": 391, + "IMAGE_TOKENS": 1426, "PROMPT_WRAPPER_TOKENS": 21, }, } @@ -153,7 +154,7 @@ def load_inputs(args: argparse.Namespace, processor, device: torch.device): """ Loads and constructs the input dictionary for the specified VLM model. 
""" - url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" + url = "https://www.ilankelman.org/stopsigns/australia.jpg" image = Image.open(requests.get(url, stream=True).raw) if args.benchmark: @@ -197,6 +198,7 @@ def load_inputs(args: argparse.Namespace, processor, device: torch.device): padding=True, return_tensors="pt", ).to(device) + return inputs @@ -234,18 +236,8 @@ def _compile_lm( lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() max_seq_len = input_embeds.shape[1] + args.num_tokens - # --- Model-specific dynamic shape definition --- - if "qwen" in args.model.lower(): - _seq = torch.export.Dim("_seq", min=1, max=512) - seq_len = 8 * _seq - position_ids = ( - torch.arange(input_embeds.shape[1], device=DEVICE, dtype=torch.long) - .unsqueeze(0) - .expand(input_embeds.size(0), input_embeds.size(1)) - ) - else: # eagle2 - seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) - position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} @@ -260,13 +252,9 @@ def _compile_lm( else: # FP32 enabled_precisions = {torch.float32} - with torch.inference_mode(): - exported_program = torch.export.export( - lm_wrap, - (input_embeds, position_ids), - dynamic_shapes=dyn_shapes, - strict=False, - ) + exported_program = export_llm( + lm_wrap, input_embeds, min_seq_len=1, max_seq_len=2560 + ) with torch_tensorrt.logging.debug() if args.debug else nullcontext(): trt_mod = torch_tensorrt.dynamo.compile( diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py index 161d02fe14..58daacedf5 100644 --- a/tools/llm/static_cache_v1.py +++ b/tools/llm/static_cache_v1.py @@ -201,26 +201,12 @@ def insert_kv_slicing_before_sdpa( args=(slice_7, 3), kwargs={}, ) - # =============================================== # - # This prevents the cache tensor from growing when padded inputs are used. 
- update_window_size = gm.graph.create_node( - "call_function", - torch.ops.aten.sub.Tensor, - args=(end_idx_input, start_idx_input), - kwargs={}, - ) - sliced_new_kv = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(current_key_or_value_node, 2, 0, update_window_size), - kwargs={}, - ) # Concatenate the sliced tensors to build KV cache cat = gm.graph.create_node( "call_function", torch.ops.aten.cat.default, - args=([slice_4, sliced_new_kv, slice_8], 2), + args=([slice_4, current_key_or_value_node, slice_8], 2), kwargs={}, ) # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph diff --git a/tools/llm/utils.py b/tools/llm/utils.py index ee8c852d6f..09fa662299 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -448,9 +448,9 @@ def generate_mm_with_timing( model, pixel_values: torch.Tensor | None, input_ids: torch.Tensor, - max_output_seq_length: int, eos_token_id: int, emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, use_cache: bool = False, ): # Create timing events @@ -505,7 +505,7 @@ def generate_mm_with_timing( generated = 0 past_key_values = None - while generated < max_output_seq_length: + while generated < max_new_tokens: lm_start.record() cur_embeds = seq_embeds position_ids = ( @@ -527,8 +527,6 @@ def generate_mm_with_timing( seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) generated += 1 - if (next_tok == eos_token_id).all(): - break overall_end.record() torch.cuda.synchronize() @@ -732,9 +730,10 @@ def generate_mm_qwen2_5_vl( mask_expanded, image_embeds.to(seq_embeds.dtype) ) + osl = max_output_seq_length - seq_tokens.shape[1] # 5. Greedy generation loop generated = 0 - while generated < max_output_seq_length: + while generated < osl: # 5.1. Calculate position_ids position_ids = ( torch.arange( @@ -768,8 +767,6 @@ def generate_mm_qwen2_5_vl( seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) generated += 1 - if (next_tok == eos_token_id).all(): - break # 6. Return generated tokens (only the part after the prompt) return seq_tokens[:, input_ids.shape[1] :] @@ -817,52 +814,22 @@ def generate_mm_qwen2_5_vl_with_static_cache( start_idx = 0 end_idx = seq_embeds.size(1) generated = 0 - max_total_len = max_output_seq_length + osl = max_output_seq_length - seq_tokens.shape[1] output_tokens = seq_tokens.clone() # 4. Greedy loop - while output_tokens.size(1) < max_total_len: + while generated < osl: cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] if generated == 0: position_ids = ( torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) ) - # For the prefill step, the relevant logit is the very last one. - logit_pos = -1 else: - # --- RUNTIME PADDING FIX for KV Cache Decode --- - # The compiled TensorRT engine has a minimum sequence length requirement (e.g., 16), - # as determined by its optimization profile. The decode step uses a sequence length - # of 1, which violates this profile. - # To resolve this, we manually pad the input tensors to the minimum length (16) - # at runtime before feeding them to the engine. 
- pad_len = 15 # Pad from 1 to 16 (1 + 15) - - # Pad cur_embeds tensor - padding_tensor_embeds = torch.zeros( - cur_embeds.size(0), - pad_len, - cur_embeds.size(2), - dtype=cur_embeds.dtype, - device=cur_embeds.device, - ) - cur_embeds = torch.cat([cur_embeds, padding_tensor_embeds], dim=1) - # Pad position_ids tensor position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( cur_embeds.device ) - padding_tensor_ids = torch.zeros( - position_ids.size(0), - pad_len, - dtype=position_ids.dtype, - device=position_ids.device, - ) - position_ids = torch.cat([position_ids, padding_tensor_ids], dim=1) - - # Since we padded the sequence, the logit for our actual token is now at position 0. - logit_pos = 0 input_signature = ( cur_embeds, @@ -878,7 +845,7 @@ def generate_mm_qwen2_5_vl_with_static_cache( hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:] # Use logit_pos to get the correct logit based on whether we padded or not. - logits = model.lm_head(hidden_states[:, logit_pos, :]) + logits = model.lm_head(hidden_states[:, -1, :]) next_tok = logits.argmax(dim=-1) output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) @@ -890,9 +857,6 @@ def generate_mm_qwen2_5_vl_with_static_cache( start_idx = end_idx end_idx += 1 - if (next_tok == eos_token_id).all(): - break - return output_tokens @@ -1034,3 +998,93 @@ def generate_mm_paligemma_with_static_cache( break return output_tokens + + +@torch.inference_mode() +def generate_mm_qwen2_5_vl_with_timing( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model with timing. + """ + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + vision_time = 0.0 + image_embeds = None + if pixel_values is not None: + vision_start.record() + image_embeds = model.visual(pixel_values, image_grid_thw) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." 
+ ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + step_times = [] + generated = 0 + while generated < max_new_tokens: + lm_start.record() + position_ids = ( + torch.arange( + 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device + ) + .unsqueeze(0) + .expand(seq_embeds.size(0), seq_embeds.size(1)) + ) + + with torch.no_grad(): + outputs = model.model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + ) + hidden_states = ( + outputs + if isinstance(outputs, torch.Tensor) + else outputs.last_hidden_state + ) + + logits = model.lm_head(hidden_states[:, -1, :]) + next_tok = torch.argmax(logits, dim=-1) + + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + next_emb = emb_layer(next_tok)[:, None, :] + seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) + + generated += 1 + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0. + return seq_tokens, step_times, overall_time, vision_time, 0.0 From 7bb5efdd0eb5910d2ae00e48c233c91058371174 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Fri, 8 Aug 2025 15:39:55 +0000 Subject: [PATCH 6/9] Improve usability and visibility of arguments, README, and tutorial --- docsrc/tutorials/compile_hf_models.rst | 63 ++++++++- tools/llm/README.md | 25 +++- tools/llm/run_vlm.py | 189 ++++++++++++++++--------- tools/llm/utils.py | 10 +- 4 files changed, 213 insertions(+), 74 deletions(-) diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst index f6da87b145..406cbb9140 100644 --- a/docsrc/tutorials/compile_hf_models.rst +++ b/docsrc/tutorials/compile_hf_models.rst @@ -18,6 +18,7 @@ Overview of tools/llm Directory The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface: * **run_llm.py**: Main entry point for model compilation, generating outputs, and benchmarking +* **run_vlm.py**: Entry point for compiling and benchmarking Visual Language Models (VLMs) * **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization * **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering scaled dot-product attention converter and lowering pass. * **Testing Components**: Model-specific test files for validation @@ -60,6 +61,30 @@ We have officially verified support for the following LLM families: - FP16, FP32 - Yes +Supported VLM Models +-------------------- +We have officially verified support for the following Visual Language Models (VLMs): + +.. list-table:: + :widths: 20 40 20 20 20 + :header-rows: 1 + + * - Model Series + - HuggingFace Model Card + - Precision + - KV Cache Support ? 
+ - Component Support + * - Qwen 2.5 VL + - Qwen/Qwen2.5-VL-3B-Instruct + - FP16, FP32 + - Yes (static_v1 only) + - Language Model only (Image Encoder not supported) + * - Eagle2 + - nvidia/Eagle2-2B + - FP16, FP32 + - Yes (static_v1 only) + - Language Model and Image Encoder both supported + Getting Started with run_llm.py ------------------------------- @@ -112,6 +137,36 @@ Other Usage Examples python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark +Getting Started with run_vlm.py +------------------------------- + +For Visual Language Models (VLMs), use ``run_vlm.py`` to compile and benchmark models that process both text and images. + +Basic Usage +^^^^^^^^^^^ + +.. code-block:: bash + + python tools/llm/run_vlm.py \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --precision FP16 \ + --num_tokens 128 \ + --cache static_v1 \ + --enable_pytorch_run \ + --benchmark + +Key Arguments +^^^^^^^^^^^^^ + +* ``--model``: Name or path of the HuggingFace VLM +* ``--prompt``: Input prompt for generation +* ``--image_path``: (Optional) Path to input image file. If not provided, will use a sample image +* ``--precision``: Precision mode (``FP16``, ``FP32``) +* ``--num_tokens``: Number of output tokens to generate +* ``--cache``: KV cache type (``static_v1`` or empty for no KV caching) +* ``--benchmark``: Enable benchmarking mode +* ``--enable_pytorch_run``: Also run and compare PyTorch baseline + KV Caching in Torch-TensorRT --------------------------------- @@ -122,7 +177,7 @@ The length of KV cache = input sequence length + output sequence length (specifi Static Cache v1 ^^^^^^^^^^^^^^^^ -The ``static_cache_v1.py`` implements KV cache in the model graph as follows: +The ``static_cache_v1.py`` implements KV cache in the model graph as follows: .. code-block:: python @@ -210,9 +265,13 @@ Limitations and Known Issues * Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported * Some model architectures (e.g. Phi-4) have issues with exporting the torch model. +* For VLMs, Qwen2.5-VL image encoder compilation is not supported due to dynamic operations incompatible with torch.export. Requirements ^^^^^^^^^^^^ * Torch-TensorRT 2.8.0 or later -* Transformers v4.52.3 \ No newline at end of file +* Transformers v4.52.3 +* For VLM models (run_vlm.py): + - ``pip install qwen-vl-utils`` (for Qwen2.5-VL-3B-Instruct model) + - ``pip install flash-attn --no-build-isolation -v`` (for Eagle2-2B model) \ No newline at end of file diff --git a/tools/llm/README.md b/tools/llm/README.md index a141505517..c0a88635f6 100644 --- a/tools/llm/README.md +++ b/tools/llm/README.md @@ -1,10 +1,11 @@ # Optimizing LLMs in Torch-TensorRT -This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry point is `run_llm.py`, which demonstrates how to export, compile, and run LLMs with various caching strategies and precision modes. Note that this is an **experimental release** and APIs may change in future versions. +This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) and Visual Language Models (VLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry points are `run_llm.py` for text-only LLMs and `run_vlm.py` for vision-language models. Note that this is an **experimental release** and APIs may change in future versions. 
### Key Features - **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc. +- **VLM Support:** Supports Visual Language Models like Qwen2.5-VL and Eagle2. - **Precision Modes:** Supports FP16, BF16, and FP32. - **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding. - **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends. @@ -24,20 +25,33 @@ We have officially verified support for the following models: | Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct
Qwen/Qwen2.5-1.5B-Instruct
Qwen/Qwen2.5-3B-Instruct
Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes | | Qwen 3 | Qwen/Qwen3-0.6B
Qwen/Qwen3-1.7B
Qwen/Qwen3-4B
Qwen/Qwen3-8B | FP16, FP32 | Yes | +### Supported VLM Models + +| Model Series | HF Model Card | Precision | KV Cache Supported ? | +|--------------|---------------|-----------|-------------------| +| Qwen 2.5 VL | Qwen/Qwen2.5-VL-3B-Instruct | FP16, FP32 | Yes | +| Eagle2 | nvidia/Eagle2-2B | FP16, FP32 | Yes | ### Usage -The main entry point is : `run_llm.py` +#### Text-only LLMs: `run_llm.py` ```bash python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark ``` +#### Vision Language Models: `run_vlm.py` + +```bash +python run_vlm.py --model Qwen/Qwen2.5-VL-3B-Instruct --precision FP16 --num_tokens 128 --cache static_v1 --enable_pytorch_run --benchmark +``` + #### Key Arguments -- `--model`: Name or path of the HuggingFace LLM. +- `--model`: Name or path of the HuggingFace LLM/VLM. - `--tokenizer`: (Optional) Tokenizer name; defaults to model. - `--prompt`: Input prompt for generation. +- `--image_path`: (Optional) Path to input image file for VLM models. If not provided, will use a sample image. - `--precision`: Precision mode (`FP16`, `FP32`). - `--num_tokens`: Number of output tokens to generate. - `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching). @@ -64,4 +78,7 @@ This codebase can be extended to ## Requirements - Torch-TensorRT 2.8.0 -- Transformers v4.52.3 \ No newline at end of file +- Transformers v4.52.3 +- For VLM models (run_vlm.py): + - `pip install qwen-vl-utils` (for Qwen2.5-VL-3B-Instruct model) + - `pip install flash-attn --no-build-isolation -v` (for Eagle2-2B model) \ No newline at end of file diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index be91ac6efe..8914d245c9 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -21,6 +21,10 @@ This tool mirrors the style and structure of `run_llm.py`, providing a clear workflow for VLM optimization and analysis. + +Dependencies: +- For Qwen VLM models: pip install qwen-vl-utils +- For Eagle2 models: pip install flash-attn --no-build-isolation -v """ import argparse @@ -33,6 +37,12 @@ import requests import torch import torch_tensorrt + +# we "monkey-patch" the global attention function map for Qwen2. +# This ensures that any part of the code (including torch.export) requesting +# "flash_attention_2" will receive the "sdpa" implementation instead. +# This patch is global for the script's execution context. +import transformers.models.qwen2.modeling_qwen2 as mq from PIL import Image from torchtrt_ext import register_sdpa from transformers import AutoConfig, AutoModel, AutoProcessor @@ -46,21 +56,11 @@ time_generate_mm, ) -# -----------------------------------------------------------------------------# -# Global configuration -# -----------------------------------------------------------------------------# -DEVICE = torch.device("cuda:0") - # --- WORKAROUND FOR EAGLE2 SDPA COMPILATION --- # Eagle2's language model (Qwen2) implicitly defaults to "flash_attention_2" # due to settings in its remote code and config.json. This prevents direct # compilation with SDPA. To work around this without modifying the library, -# we "monkey-patch" the global attention function map for Qwen2. -# This ensures that any part of the code (including torch.export) requesting -# "flash_attention_2" will receive the "sdpa" implementation instead. -# This patch is global for the script's execution context. 
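[Editorial aside, not part of the patch] The workaround described above relies on the one-line remap of the Qwen2 attention registry that the script already applies. As a quick sanity check, the effect can be verified at import time with a sketch like the one below (assumption: `ALL_ATTENTION_FUNCTIONS` behaves like a plain mapping, as the script's own usage suggests).

```python
import transformers.models.qwen2.modeling_qwen2 as mq

# Route any request for "flash_attention_2" to the SDPA implementation,
# mirroring the patch applied in run_vlm.py.
mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"]

# After the remap, both keys should resolve to the same callable.
assert (
    mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
    is mq.ALL_ATTENTION_FUNCTIONS["sdpa"]
)
```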
-import transformers.models.qwen2.modeling_qwen2 as mq mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] # --- END WORKAROUND --- @@ -96,17 +96,25 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): The model, its processor and the language-model input embedding layer. """ model_id = "nvidia/Eagle2-2B" - with torch.no_grad(): - model = ( - AutoModel.from_pretrained( - model_id, - trust_remote_code=True, - torch_dtype=torch_dtype, - # attn_implementation="sdpa" is ignored due to the model's remote code. + try: + with torch.no_grad(): + model = ( + AutoModel.from_pretrained( + model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + # attn_implementation="sdpa" is ignored due to the model's remote code. + ) + .eval() + .to(device) ) - .eval() - .to(device) - ) + except ImportError as e: + if "flash_attn" in str(e): + raise ImportError( + "FlashAttention2 is required for Eagle2 models but not installed. " + "Please install it using: pip install flash-attn --no-build-isolation -v" + ) from e + raise processor = AutoProcessor.from_pretrained( model_id, trust_remote_code=True, use_fast=True @@ -154,8 +162,14 @@ def load_inputs(args: argparse.Namespace, processor, device: torch.device): """ Loads and constructs the input dictionary for the specified VLM model. """ - url = "https://www.ilankelman.org/stopsigns/australia.jpg" - image = Image.open(requests.get(url, stream=True).raw) + # Load image from local path if provided, otherwise use default URL + if args.image_path is not None: + # Use local image file + image = Image.open(args.image_path) + else: + # Use default URL image + url = "https://www.ilankelman.org/stopsigns/australia.jpg" + image = Image.open(requests.get(url, stream=True).raw) if args.benchmark: model_constants = MODEL_CONSTANTS[args.model] @@ -184,8 +198,14 @@ def load_inputs(args: argparse.Namespace, processor, device: torch.device): ] # --- Model-specific vision processing --- - if "qwen" in args.model.lower(): - from qwen_vl_utils import process_vision_info + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + try: + from qwen_vl_utils import process_vision_info + except ImportError: + raise ImportError( + "The 'qwen-vl-utils' package is required for Qwen VLM models. 
" + "Please install it using: pip install qwen-vl-utils" + ) image_inputs, video_inputs = process_vision_info(messages) else: # eagle2 @@ -229,15 +249,16 @@ def _compile_lm( language_model: torch.nn.Module, input_embeds: torch.Tensor, args: argparse.Namespace, + device: torch.device, ) -> torch.nn.Module: """ Compile the language model component of a VLM with Torch-TensorRT """ - lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() + lm_wrap = _LMNoCache(language_model).to(device).eval() max_seq_len = input_embeds.shape[1] + args.num_tokens seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) - position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(device) dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} @@ -247,8 +268,6 @@ def _compile_lm( enabled_precisions = {torch.float32} use_fp32_acc = True use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} else: # FP32 enabled_precisions = {torch.float32} @@ -263,18 +282,17 @@ def _compile_lm( enabled_precisions=enabled_precisions, use_explicit_typing=use_explicit_typing, use_fp32_acc=use_fp32_acc, - device=DEVICE, - disable_tf32=True, - use_python_runtime=True, - debug=args.debug, - offload_module_to_cpu=True, + device=device, + disable_tf32=args.disable_tf32, + use_python_runtime=args.use_python_runtime, + offload_module_to_cpu=args.offload_module_to_cpu, min_block_size=args.min_block_size, ) return trt_mod def compile_lm_torchtrt( - model: torch.nn.Module, args: argparse.Namespace + model: torch.nn.Module, args: argparse.Namespace, device: torch.device ) -> torch.nn.Module: """ Compiles the Language Model (LLM) component of the VLM using Torch-TensorRT. @@ -284,7 +302,11 @@ def compile_lm_torchtrt( "BF16": torch.bfloat16, }.get(args.precision, torch.float32) - lm_model = model.model if "qwen" in args.model.lower() else model.language_model + lm_model = ( + model.model + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct" + else model.language_model + ) model_constants = MODEL_CONSTANTS.get( args.model, {"EXAMPLE_SEQLEN": args.num_tokens} @@ -296,12 +318,12 @@ def compile_lm_torchtrt( example_seq_len, lm_model.config.hidden_size, dtype=torch_dtype, - device=DEVICE, + device=device, ) # All supported models use the same compilation helper. if args.model in ["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"]: - return _compile_lm(lm_model, example_embeds, args) + return _compile_lm(lm_model, example_embeds, args, device) else: msg = f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" raise ValueError(msg) @@ -311,6 +333,7 @@ def _compile_eagle2_vision( vision_model: torch.nn.Module, example_pixel_values: torch.Tensor, args: argparse.Namespace, + device: torch.device, ) -> torch.nn.Module: """ Compile Eagle2 vision model with Torch-TensorRT. 
@@ -341,11 +364,10 @@ def _compile_eagle2_vision( enabled_precisions=enabled_precisions, use_explicit_typing=use_explicit_typing, use_fp32_acc=use_fp32_acc, - device=DEVICE, - disable_tf32=True, - use_python_runtime=True, - debug=args.debug, - offload_module_to_cpu=True, + device=device, + disable_tf32=args.disable_tf32, + use_python_runtime=args.use_python_runtime, + offload_module_to_cpu=args.offload_module_to_cpu, min_block_size=args.min_block_size, ) return trt_mod @@ -355,12 +377,15 @@ def compile_vision_torchtrt( model: torch.nn.Module, args: argparse.Namespace, example_pixel_values: torch.Tensor, + device: torch.device, ) -> torch.nn.Module: """ Dispatcher function for vision model compilation. """ if args.model == "nvidia/Eagle2-2B": - return _compile_eagle2_vision(model.vision_model, example_pixel_values, args) + return _compile_eagle2_vision( + model.vision_model, example_pixel_values, args, device + ) elif args.model == "Qwen/Qwen2.5-VL-3B-Instruct": # TODO: Vision model compilation for Qwen2.5-VL is currently skipped. # The model's `get_window_index` method uses dynamic Python list operations @@ -396,7 +421,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): ) parser.add_argument( "--model", - default="Qwen/Qwen2.5-VL-3B-Instruct", + default="nvidia/Eagle2-2B", choices=["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"], help="VLM model name", ) @@ -404,7 +429,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): parser.add_argument( "--precision", default="FP16", - choices=["FP16", "BF16", "FP32"], + choices=["FP16", "FP32"], help="Computation precision", ) parser.add_argument("--iterations", type=int, default=5, help="# iterations") @@ -429,9 +454,43 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): parser.add_argument( "--benchmark", action="store_true", help="Enable benchmarking mode" ) + parser.add_argument( + "--image_path", + type=str, + default=None, + help="Path to local image file. If not provided, uses default URL image.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Device to run inference on (e.g., 'cuda:0', 'cuda:1')", + ) + parser.add_argument( + "--disable_tf32", + action="store_false", + default=True, + help="Disable TF32 precision for TensorRT compilation (default: True)", + ) + parser.add_argument( + "--use_python_runtime", + action="store_false", + default=True, + help="Use Python runtime for TensorRT compilation (default: True)", + ) + parser.add_argument( + "--offload_module_to_cpu", + action="store_false", + default=True, + help="Offload module to CPU for TensorRT compilation (default: True)", + ) args = parser.parse_args() + device = torch.device(args.device) + if device.type == "cuda": + torch.cuda.set_device(device) + # -------------------------------------------------------------------------# # 1. Model / processor / embeddings # -------------------------------------------------------------------------# @@ -440,12 +499,12 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): "BF16": torch.bfloat16, }.get(args.precision, torch.float32) - model, processor, emb_layer = load_model(args.model, DEVICE, dtype) + model, processor, emb_layer = load_model(args.model, device, dtype) # -------------------------------------------------------------------------# # 2. 
Input construction (image + text prompt) # -------------------------------------------------------------------------# - inputs = load_inputs(args, processor, DEVICE) + inputs = load_inputs(args, processor, device) max_output_len = inputs["input_ids"].shape[1] + args.num_tokens @@ -454,7 +513,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): # -------------------------------------------------------------------------# pyt_gen_tokens = pyt_timings = pyt_stats = None if args.enable_pytorch_run: - if "qwen" in args.model.lower(): + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": pyt_gen_tokens = generate_mm_qwen2_5_vl( model, inputs["pixel_values"], @@ -473,7 +532,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): processor.tokenizer.eos_token_id, emb_layer, ) - print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + if args.benchmark: # Prepare args for the timing function time_generate_args = { @@ -486,7 +545,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): } # Select the correct generation function and add model-specific args - if "qwen" in args.model.lower(): + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": generate_fn_for_timing = generate_mm_qwen2_5_vl time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] else: # eagle2 @@ -511,8 +570,8 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): # 4.1. Vision model compilation # --- Add vision model compilation --- # example_pixel_values = inputs["pixel_values"] - trt_vision = compile_vision_torchtrt(model, args, example_pixel_values) - if "qwen" in args.model.lower(): + trt_vision = compile_vision_torchtrt(model, args, example_pixel_values, device) + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": trt_model.visual = trt_vision else: trt_model.vision_model = trt_vision @@ -524,24 +583,28 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): # Cache is not applied to vision model. if args.cache == "static_v1": import static_cache_v1 # noqa: F401 + elif args.cache not in ("", None): + raise ValueError( + f"Cache mode '{args.cache}' is not supported. Only 'static_v1' is supported." 
+ ) - trt_lm = compile_lm_torchtrt(model, args) - if "qwen" in args.model.lower(): + trt_lm = compile_lm_torchtrt(model, args, device) + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": trt_model.model = trt_lm else: trt_model.language_model = trt_lm - emb_layer = emb_layer.to(DEVICE) - if "qwen" in args.model.lower(): - trt_model.lm_head = trt_model.lm_head.to(DEVICE) + emb_layer = emb_layer.to(device) + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + trt_model.lm_head = trt_model.lm_head.to(device) if args.cache == "static_v1": - if "qwen" in args.model.lower(): + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": trt_generate = generate_mm_qwen2_5_vl_with_static_cache else: # eagle2 trt_generate = generate_mm_with_static_cache else: - if "qwen" in args.model.lower(): + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": trt_generate = generate_mm_qwen2_5_vl else: # eagle2 trt_generate = generate_mm @@ -555,10 +618,10 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): "eos_token_id": processor.tokenizer.eos_token_id, "emb_layer": emb_layer, } - if "qwen" in args.model.lower(): + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": generate_args["image_grid_thw"] = inputs["image_grid_thw"] if args.cache == "static_v1": - generate_args["device"] = DEVICE + generate_args["device"] = device trt_gen_tokens = trt_generate(**generate_args) @@ -574,10 +637,10 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): } # Add model-specific args - if "qwen" in args.model.lower(): + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] if args.cache == "static_v1": - time_generate_args["device"] = DEVICE + time_generate_args["device"] = device trt_timings = time_generate_mm( trt_generate, diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 09fa662299..dd9997b34b 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -47,7 +47,7 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): return ep -def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): +def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule, device: str = "cuda:0"): """ Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2. @@ -71,7 +71,7 @@ def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): torch.zeros( input.meta["val"].shape, dtype=input.meta["val"].dtype, - device=torch.device("cuda:0"), + device=torch.device(device), ) ) @@ -386,7 +386,7 @@ def generate_mm_with_static_cache( seq_embeds = flat.view(B, N, C) # ───────────────────── KV-cache initialization ───────────────────── - kv_cache = get_zeroed_static_cache_inputs(model.language_model) + kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device) start_idx = 0 # First token index end_idx = seq_embeds.size(1) # Prompt length generated = 0 @@ -608,7 +608,7 @@ def generate_mm_with_static_cache_timing( seq_embeds = flat.view(B, N, C) # ───────────────────── KV-cache initialization ───────────────────── - kv_cache = get_zeroed_static_cache_inputs(model.language_model) + kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device) start_idx = 0 # First token index end_idx = seq_embeds.size(1) # Prompt length generated = 0 @@ -810,7 +810,7 @@ def generate_mm_qwen2_5_vl_with_static_cache( ) # 3. 
KV-cache initialization - kv_cache = get_zeroed_static_cache_inputs(model.model) + kv_cache = get_zeroed_static_cache_inputs(model.model, device=device) start_idx = 0 end_idx = seq_embeds.size(1) generated = 0 From da438a8010c04568a3cbd8c86ab6911d94e6dd8f Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Sat, 9 Aug 2025 06:10:27 +0000 Subject: [PATCH 7/9] refactoring utils for vision inputs and timings --- tools/llm/run_vlm.py | 134 ++++---- tools/llm/utils.py | 787 +++++++++++++++---------------------------- 2 files changed, 344 insertions(+), 577 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index 8914d245c9..e26d971f32 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -53,7 +53,6 @@ generate_mm_qwen2_5_vl_with_static_cache, generate_mm_with_static_cache, record_stats, - time_generate_mm, ) # --- WORKAROUND FOR EAGLE2 SDPA COMPILATION --- @@ -513,54 +512,68 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): # -------------------------------------------------------------------------# pyt_gen_tokens = pyt_timings = pyt_stats = None if args.enable_pytorch_run: - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": - pyt_gen_tokens = generate_mm_qwen2_5_vl( - model, - inputs["pixel_values"], - inputs["input_ids"], - inputs["image_grid_thw"], - max_output_len, - processor.tokenizer.eos_token_id, - emb_layer, - ) - else: # eagle2 - pyt_gen_tokens = generate_mm( - model, - inputs["pixel_values"], - inputs["input_ids"], - max_output_len, - processor.tokenizer.eos_token_id, - emb_layer, - ) - + # For benchmarking, we run the generation with timing enabled. + # For regular runs, we run without timing for a single output. if args.benchmark: - # Prepare args for the timing function - time_generate_args = { - "model": model, - "pixel_values": inputs["pixel_values"].clone(), - "input_ids": inputs["input_ids"].clone(), - "max_output_seq_length": max_output_len, - "eos_token_id": processor.tokenizer.eos_token_id, - "emb_layer": emb_layer, - } - - # Select the correct generation function and add model-specific args if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": - generate_fn_for_timing = generate_mm_qwen2_5_vl - time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] + ( + pyt_gen_tokens, + _, + overall_time, + _, + _, + ) = generate_mm_qwen2_5_vl( + model, + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], + processor.tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + with_timing=True, + ) else: # eagle2 - generate_fn_for_timing = generate_mm - - pyt_timings = time_generate_mm( - generate_fn_for_timing, iterations=args.iterations, **time_generate_args - ) + ( + pyt_gen_tokens, + _, + overall_time, + _, + _, + ) = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + processor.tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + with_timing=True, + ) pyt_stats = record_stats( "PyTorch", - pyt_timings, + [overall_time / 1000], # time_generate returns seconds args.precision, batch_size=args.batch_size, - compile_time_s=None, ) + else: + if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + pyt_gen_tokens = generate_mm_qwen2_5_vl( + model, + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], + processor.tokenizer.eos_token_id, + emb_layer, + max_new_tokens=args.num_tokens, + ) + else: # eagle2 + pyt_gen_tokens = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + processor.tokenizer.eos_token_id, + emb_layer, + 
max_new_tokens=args.num_tokens, + ) # -------------------------------------------------------------------------# # 4. Torch-TensorRT compile & run @@ -614,46 +627,33 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): "model": trt_model, "pixel_values": inputs["pixel_values"], "input_ids": inputs["input_ids"], - "max_output_seq_length": max_output_len, "eos_token_id": processor.tokenizer.eos_token_id, "emb_layer": emb_layer, + "max_new_tokens": args.num_tokens, } if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": generate_args["image_grid_thw"] = inputs["image_grid_thw"] + + if args.cache == "static_v1" or args.benchmark: + generate_args["with_timing"] = True + if args.cache == "static_v1": generate_args["device"] = device - trt_gen_tokens = trt_generate(**generate_args) - - if args.benchmark: - # Prepare args for the timing function - time_generate_args = { - "model": trt_model, - "pixel_values": inputs["pixel_values"].clone(), - "input_ids": inputs["input_ids"].clone(), - "max_output_seq_length": max_output_len, - "eos_token_id": processor.tokenizer.eos_token_id, - "emb_layer": emb_layer, - } + # Run TRT generation + trt_output = trt_generate(**generate_args) - # Add model-specific args - if args.model == "Qwen/Qwen2.5-VL-3B-Instruct": - time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] - if args.cache == "static_v1": - time_generate_args["device"] = device - - trt_timings = time_generate_mm( - trt_generate, - iterations=args.iterations, - **time_generate_args, - ) + # Unpack results + if args.benchmark or args.cache == "static_v1": + trt_gen_tokens, _, overall_time, _, _ = trt_output trt_stats = record_stats( "TensorRT", - trt_timings, + [overall_time / 1000], # time is in ms, convert to s args.precision, batch_size=args.batch_size, - compile_time_s=None, ) + else: + trt_gen_tokens = trt_output # -------------------------------------------------------------------------# # 5. Reporting diff --git a/tools/llm/utils.py b/tools/llm/utils.py index dd9997b34b..e62fcced35 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -244,43 +244,35 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None) return stats -def generate_mm( +def _prepare_mm_inputs( model, pixel_values: torch.Tensor | None, input_ids: torch.Tensor, - max_output_seq_length: int, - eos_token_id: int, emb_layer: torch.nn.Embedding, - device: str = "cuda:0", + with_timing: bool = False, ): - """Greedy decode for Eagle2-style VLM. - - Parameters - ---------- - model : nn.Module - Must expose vision_model, mlp1, language_model, pixel_shuffle, downsample_ratio, image_token_index. - pixel_values : Tensor | None - Input image batch (B,C,H,W) or None. - input_ids : LongTensor (B, N_prompt) - Text prompt token ids including [IMG] placeholder(s). - max_output_seq_length : int - Maximum tokens to generate **in addition to** the prompt. - eos_token_id : int - Stop generation when all sequences emit EOS. - emb_layer : nn.Embedding - Embedding layer for input_ids. """ - + Prepares multimodal inputs for Eagle2-style VLMs by encoding images and merging with text embeddings. + Optionally times the vision and MLP parts. 
+ """ + vision_time = 0.0 + mlp_time = 0.0 vit_embeds = None if pixel_values is not None: - # --- Vision encoder timing --- - vis_s = torch.cuda.Event(enable_timing=True) - vis_e = torch.cuda.Event(enable_timing=True) - vis_s.record() - vit_out = model.vision_model(pixel_values) - vis_e.record() - torch.cuda.synchronize() + if with_timing: + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + + vision_start.record() + vit_out = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + else: + vit_out = model.vision_model(pixel_values) vit_embeds = ( vit_out.last_hidden_state @@ -288,6 +280,9 @@ def generate_mm( else vit_out ) + if with_timing: + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) vit_embeds = model.pixel_shuffle( @@ -296,14 +291,17 @@ def generate_mm( vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = model.mlp1(vit_embeds) - # 2) Text token embeddings + if with_timing: + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + seq_tokens = input_ids.clone() seq_embeds = emb_layer(seq_tokens) if vit_embeds is not None: B, N, C = seq_embeds.shape flat_emb = seq_embeds.view(B * N, C) - mask = seq_tokens.view(B * N) == model.image_token_index try: flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] @@ -312,201 +310,79 @@ def generate_mm( flat_emb[mask] = vit_embeds.reshape(-1, C)[: mask.sum()].to(flat_emb.dtype) seq_embeds = flat_emb.view(B, N, C) - # ───────────────────────────────── Greedy loop ─────────────────────────────────────────────────── - isl = seq_tokens.shape[1] - osl = max_output_seq_length - isl - - generated = 0 + if with_timing: + return seq_tokens, seq_embeds, vision_time, mlp_time + else: + return seq_tokens, seq_embeds - while generated < osl: - cur_embeds = seq_embeds # full seq first step or cache off - position_ids = ( - torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) - ) - with torch.no_grad(): - logits = model.language_model( - inputs_embeds=cur_embeds, position_ids=position_ids - ) - if hasattr(logits, "logits"): - logits = logits.logits - next_tok = torch.argmax(logits[:, -1, :], dim=-1) # (B,) - # append token & embed - seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) - seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) - - generated += 1 - if (next_tok == eos_token_id).all(): - break - - return seq_tokens[:, input_ids.shape[1] :] - - -@torch.inference_mode() -def generate_mm_with_static_cache( - model, # Complete VLM module - pixel_values: torch.Tensor | None, - input_ids: torch.Tensor, # (B, N_prompt) - max_output_seq_length: int, - eos_token_id: int, - emb_layer: torch.nn.Embedding, - device: str = "cuda:0", -) -> torch.LongTensor: # (B, N_prompt + new) - """ - Greedy Decoder for multimodal VLM (using static KV-cache v1). 
- Basic structure is identical to LM version (generate_with_static_cache) but - * Input is `inputs_embeds` - * Vision tokens are sent together only in the first step - """ - - # ───────────────────── Vision encoding ───────────────────── - vit_embeds = None - if pixel_values is not None: - vit_latent = model.vision_model(pixel_values) - vit_embeds = ( - vit_latent.last_hidden_state - if hasattr(vit_latent, "last_hidden_state") - else vit_latent - ) - h = w = int(vit_embeds.shape[1] ** 0.5) - vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) - vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) - vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) - vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) - - # ───────────────────── Text embedding & [IMG] replacement ───────────── - seq_tokens = input_ids.clone() # (B, N_txt) - seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) - - if vit_embeds is not None: - B, N, C = seq_embeds.shape - flat = seq_embeds.view(B * N, C) - mask = seq_tokens.view(B * N) == model.image_token_index - flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] - seq_embeds = flat.view(B, N, C) - - # ───────────────────── KV-cache initialization ───────────────────── - kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device) - start_idx = 0 # First token index - end_idx = seq_embeds.size(1) # Prompt length - generated = 0 - max_total_len = max_output_seq_length - output_tokens = seq_tokens.clone() - - # ───────────────────── Greedy loop ─────────────────────── - while output_tokens.size(1) < max_total_len: - - # When using static cache: - # - First step: Use full prompt embedding - # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) - cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] - - # position_ids: Same pattern as generate_with_static_cache - # - First step: Position of entire sequence - # - Subsequent steps: Position of current token only - if generated == 0: - position_ids = ( - torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) - ) - else: - position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( - cur_embeds.device - ) - - # is_causal = True if cur_embeds.shape[1] > 1 else False - input_signature = ( - cur_embeds, - position_ids, - *kv_cache, - start_idx, - end_idx, - # is_causal, - ) - - logits_and_kv = model.language_model(*input_signature) - logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] - - next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) - output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) - - # Prepare for next step - Static cache only needs new token - next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) - seq_embeds = next_embed # Next step uses only new token - - generated += 1 - start_idx = end_idx - end_idx += 1 - # is_causal = True # Causal mask active from now on - - if (next_tok == eos_token_id).all(): - break - - return output_tokens - - -def generate_mm_with_timing( +def generate_mm( model, pixel_values: torch.Tensor | None, input_ids: torch.Tensor, eos_token_id: int, emb_layer: torch.nn.Embedding, max_new_tokens: int = 64, - use_cache: bool = False, + device: str = "cuda:0", + with_timing: bool = False, ): - # Create timing events - overall_start = torch.cuda.Event(enable_timing=True) - overall_end = torch.cuda.Event(enable_timing=True) - vision_start = torch.cuda.Event(enable_timing=True) - vision_end = torch.cuda.Event(enable_timing=True) - mlp_start 
= torch.cuda.Event(enable_timing=True) - mlp_end = torch.cuda.Event(enable_timing=True) - lm_start = torch.cuda.Event(enable_timing=True) - lm_end = torch.cuda.Event(enable_timing=True) - - overall_start.record() + """Greedy decode for Eagle2-style VLM, with optional detailed timing. - vit_embeds = None - if pixel_values is not None: - vision_start.record() - vit_out = model.vision_model(pixel_values) - vision_end.record() - torch.cuda.synchronize() - vision_time = vision_start.elapsed_time(vision_end) - - vit_embeds = ( - vit_out.last_hidden_state - if hasattr(vit_out, "last_hidden_state") - else vit_out + Parameters + ---------- + model : nn.Module + Must expose vision_model, mlp1, language_model, pixel_shuffle, downsample_ratio, image_token_index. + pixel_values : Tensor | None + Input image batch (B,C,H,W) or None. + input_ids : LongTensor (B, N_prompt) + Text prompt token ids including [IMG] placeholder(s). + eos_token_id : int + Stop generation when all sequences emit EOS. + emb_layer : nn.Embedding + Embedding layer for input_ids. + max_new_tokens : int + Maximum number of new tokens to generate. + with_timing : bool + If True, returns detailed timing information. + + Returns + ------- + if with_timing is False: + torch.LongTensor: Generated token sequence (only new tokens). + if with_timing is True: + tuple: ( + seq_tokens: Full generated token sequence, + step_times: List of latencies for each generation step, + overall_time: Total generation time, + vision_time: Vision encoder latency, + mlp_time: MLP latency ) - - mlp_start.record() - h = w = int(vit_embeds.shape[1] ** 0.5) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = model.pixel_shuffle( - vit_embeds, scale_factor=model.downsample_ratio + """ + if with_timing: + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + overall_start.record() + + # --- Input preparation --- + if with_timing: + seq_tokens, seq_embeds, vision_time, mlp_time = _prepare_mm_inputs( + model, pixel_values, input_ids, emb_layer, with_timing=True + ) + else: + seq_tokens, seq_embeds = _prepare_mm_inputs( + model, pixel_values, input_ids, emb_layer, with_timing=False ) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) - vit_embeds = model.mlp1(vit_embeds) - mlp_end.record() - torch.cuda.synchronize() - mlp_time = mlp_start.elapsed_time(mlp_end) - - seq_tokens = input_ids.clone() - seq_embeds = emb_layer(seq_tokens) - - if vit_embeds is not None: - B, N, C = seq_embeds.shape - flat_emb = seq_embeds.view(B * N, C) - mask = seq_tokens.view(B * N) == model.image_token_index - flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] - seq_embeds = flat_emb.view(B, N, C) + # ───────────────────────────────── Greedy loop ─────────────────────────────────────────────────── step_times = [] generated = 0 - past_key_values = None while generated < max_new_tokens: - lm_start.record() + if with_timing: + lm_start.record() + cur_embeds = seq_embeds position_ids = ( torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) @@ -519,115 +395,84 @@ def generate_mm_with_timing( logits = logits.logits next_tok = torch.argmax(logits[:, -1, :], dim=-1) - lm_end.record() - torch.cuda.synchronize() - step_times.append(lm_start.elapsed_time(lm_end)) + + if with_timing: + lm_end.record() + torch.cuda.synchronize() + 
step_times.append(lm_start.elapsed_time(lm_end)) seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) generated += 1 + if (next_tok == eos_token_id).all(): + break - overall_end.record() - torch.cuda.synchronize() - overall_time = overall_start.elapsed_time(overall_end) - - return seq_tokens, step_times, overall_time, vision_time, mlp_time + if with_timing: + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + return ( + seq_tokens[:, input_ids.shape[1] :], + step_times, + overall_time, + vision_time, + mlp_time, + ) + else: + return seq_tokens[:, input_ids.shape[1] :] @torch.inference_mode() -def generate_mm_with_static_cache_timing( - model, # Complete VLM module +def generate_mm_with_static_cache( + model, pixel_values: torch.Tensor | None, - input_ids: torch.Tensor, # (B, N_prompt) + input_ids: torch.Tensor, eos_token_id: int, emb_layer: torch.nn.Embedding, max_new_tokens: int = 64, device: str = "cuda:0", -) -> tuple: # (seq_tokens, step_times, overall_time, vision_time, mlp_time) + with_timing: bool = False, +): """ - Greedy Decoder for multimodal VLM (using static KV-cache v1) + detailed timing measurement. - - Returns: - seq_tokens: Generated token sequence - step_times: Language model inference time for each step (ms) - overall_time: Total execution time (ms) - vision_time: Vision encoding time (ms) - mlp_time: MLP processing time (ms) + Greedy Decoder for multimodal VLM (using static KV-cache v1), with optional timing. + Basic structure is identical to LM version (generate_with_static_cache) but + * Input is `inputs_embeds` + * Vision tokens are sent together only in the first step """ - - # ───────────────────── Create timing events ───────────────────── - overall_start = torch.cuda.Event(enable_timing=True) - overall_end = torch.cuda.Event(enable_timing=True) - vision_start = torch.cuda.Event(enable_timing=True) - vision_end = torch.cuda.Event(enable_timing=True) - mlp_start = torch.cuda.Event(enable_timing=True) - mlp_end = torch.cuda.Event(enable_timing=True) - lm_start = torch.cuda.Event(enable_timing=True) - lm_end = torch.cuda.Event(enable_timing=True) - - overall_start.record() - - # ───────────────────── Vision encoding ───────────────────── - vit_embeds = None - vision_time = 0.0 - mlp_time = 0.0 - - if pixel_values is not None: - vision_start.record() - vit_latent = model.vision_model(pixel_values) - vision_end.record() - torch.cuda.synchronize() - vision_time = vision_start.elapsed_time(vision_end) - - vit_embeds = ( - vit_latent.last_hidden_state - if hasattr(vit_latent, "last_hidden_state") - else vit_latent + if with_timing: + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + overall_start.record() + vision_time, mlp_time = 0.0, 0.0 + + if with_timing: + seq_tokens, seq_embeds, vision_time, mlp_time = _prepare_mm_inputs( + model, pixel_values, input_ids, emb_layer, with_timing=True + ) + else: + seq_tokens, seq_embeds = _prepare_mm_inputs( + model, pixel_values, input_ids, emb_layer, with_timing=False ) - - mlp_start.record() - h = w = int(vit_embeds.shape[1] ** 0.5) - vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) - vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) - vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) 
- vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) - mlp_end.record() - torch.cuda.synchronize() - mlp_time = mlp_start.elapsed_time(mlp_end) - - # ───────────────────── Text embedding & [IMG] replacement ───────────── - seq_tokens = input_ids.clone() # (B, N_txt) - seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) - - if vit_embeds is not None: - B, N, C = seq_embeds.shape - flat = seq_embeds.view(B * N, C) - mask = seq_tokens.view(B * N) == model.image_token_index - flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] - seq_embeds = flat.view(B, N, C) # ───────────────────── KV-cache initialization ───────────────────── kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device) - start_idx = 0 # First token index - end_idx = seq_embeds.size(1) # Prompt length + start_idx = 0 + end_idx = seq_embeds.size(1) generated = 0 max_total_len = end_idx + max_new_tokens output_tokens = seq_tokens.clone() - step_times = [] # Timing for each step + step_times = [] # ───────────────────── Greedy loop ─────────────────────── while output_tokens.size(1) < max_total_len: - lm_start.record() + if with_timing: + lm_start.record() - # When using static cache: - # - First step: Use full prompt embedding - # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] - # position_ids: Same pattern as generate_with_static_cache - # - First step: Position of entire sequence - # - Subsequent steps: Position of current token only if generated == 0: position_ids = ( torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) @@ -637,87 +482,75 @@ def generate_mm_with_static_cache_timing( cur_embeds.device ) - # is_causal = True if cur_embeds.shape[1] > 1 else False input_signature = ( cur_embeds, position_ids, *kv_cache, start_idx, end_idx, - # is_causal, ) logits_and_kv = model.language_model(*input_signature) logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] - next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + next_tok = logits[:, -1, :].argmax(dim=-1) output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) - # Prepare for next step - Static cache only needs new token - next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) - seq_embeds = next_embed # Next step uses only new token + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed generated += 1 start_idx = end_idx end_idx += 1 - lm_end.record() - torch.cuda.synchronize() - step_times.append(lm_start.elapsed_time(lm_end)) + if with_timing: + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) if (next_tok == eos_token_id).all(): break - overall_end.record() - torch.cuda.synchronize() - overall_time = overall_start.elapsed_time(overall_end) - - return output_tokens, step_times, overall_time, vision_time, mlp_time - - -def time_generate_mm( - generate_fn, - iterations=10, - **kwargs, -): - """ - Measure the time for generating a sentence over certain number of iterations. - Accepts generation function arguments via kwargs. 
- """ - timings = [] - for _ in range(iterations): - start_time = timeit.default_timer() - _ = generate_fn(**kwargs) + if with_timing: + overall_end.record() torch.cuda.synchronize() - end_time = timeit.default_timer() - timings.append(end_time - start_time) - - return timings + overall_time = overall_start.elapsed_time(overall_end) + return output_tokens, step_times, overall_time, vision_time, mlp_time + else: + return output_tokens -def generate_mm_qwen2_5_vl( +def _prepare_qwen_mm_inputs( model, pixel_values: torch.Tensor | None, input_ids: torch.Tensor, image_grid_thw: torch.Tensor, - max_output_seq_length: int, - eos_token_id: int, emb_layer: torch.nn.Embedding, + with_timing: bool = False, ): """ - Custom generation function for the Qwen2_5_VLForConditionalGeneration model. - Performs greedy decoding without caching, using inputs_embeds instead of input_ids. + Prepares multimodal inputs for Qwen2.5-VL by encoding images and merging with text embeddings. + Optionally times the vision part. """ - # 1. Calculate image embeddings (if pixel_values are provided) + vision_time = 0.0 image_embeds = None + if pixel_values is not None: + if with_timing: + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + vision_start.record() + image_embeds = model.visual(pixel_values, image_grid_thw) - # 2. Create initial sequence embeddings + if with_timing: + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + seq_tokens = input_ids.clone() seq_embeds = emb_layer(seq_tokens) - # 3. Insert image embeddings at image token positions if image_embeds is not None: mask = seq_tokens == model.config.image_token_id num_image_tokens = mask.sum().item() @@ -730,11 +563,57 @@ def generate_mm_qwen2_5_vl( mask_expanded, image_embeds.to(seq_embeds.dtype) ) - osl = max_output_seq_length - seq_tokens.shape[1] - # 5. Greedy generation loop + if with_timing: + return seq_tokens, seq_embeds, vision_time + else: + return seq_tokens, seq_embeds + + +def generate_mm_qwen2_5_vl( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + with_timing: bool = False, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model, with optional timing. + """ + if with_timing: + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + overall_start.record() + + if with_timing: + seq_tokens, seq_embeds, vision_time = _prepare_qwen_mm_inputs( + model, + pixel_values, + input_ids, + image_grid_thw, + emb_layer, + with_timing=True, + ) + else: + seq_tokens, seq_embeds = _prepare_qwen_mm_inputs( + model, + pixel_values, + input_ids, + image_grid_thw, + emb_layer, + with_timing=False, + ) + + step_times = [] generated = 0 - while generated < osl: - # 5.1. Calculate position_ids + while generated < max_new_tokens: + if with_timing: + lm_start.record() + position_ids = ( torch.arange( 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device @@ -743,7 +622,6 @@ def generate_mm_qwen2_5_vl( .expand(seq_embeds.size(0), seq_embeds.size(1)) ) - # 5.2. Call the language model with torch.no_grad(): outputs = model.model( inputs_embeds=seq_embeds, @@ -755,21 +633,35 @@ def generate_mm_qwen2_5_vl( else outputs.last_hidden_state ) - # 5.3. 
Calculate logits for the last token logits = model.lm_head(hidden_states[:, -1, :]) - - # 5.4. Select the next token (greedy decoding) next_tok = torch.argmax(logits, dim=-1) - # 5.5. Append token and embedding to the sequence + if with_timing: + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) next_emb = emb_layer(next_tok)[:, None, :] seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) generated += 1 + if (next_tok == eos_token_id).all(): + break - # 6. Return generated tokens (only the part after the prompt) - return seq_tokens[:, input_ids.shape[1] :] + if with_timing: + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + return ( + seq_tokens[:, input_ids.shape[1] :], + step_times, + overall_time, + vision_time, + 0.0, + ) + else: + return seq_tokens[:, input_ids.shape[1] :] def generate_mm_qwen2_5_vl_with_static_cache( @@ -777,48 +669,53 @@ def generate_mm_qwen2_5_vl_with_static_cache( pixel_values: torch.Tensor | None, input_ids: torch.Tensor, image_grid_thw: torch.Tensor, - max_output_seq_length: int, eos_token_id: int, emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, device: str = "cuda:0", + with_timing: bool = False, ) -> torch.LongTensor: """ - Greedy Decoder for Qwen-2.5-VL using static KV-cache. - Identical to `generate_mm_with_static_cache` but adapted for Qwen-2.5-VL's - specific architecture (e.g., separate visual encoder call, lm_head). + Greedy Decoder for Qwen-2.5-VL using static KV-cache, with optional timing. """ - # 1. Vision encoding - image_embeds = None - if pixel_values is not None: - image_embeds = model.visual(pixel_values, image_grid_thw) - - # 2. Text embedding & image token replacement - seq_tokens = input_ids.clone() - seq_embeds = emb_layer(seq_tokens) - - if image_embeds is not None: - mask = seq_tokens == model.config.image_token_id - num_image_tokens = mask.sum().item() - if num_image_tokens != image_embeds.shape[0]: - raise ValueError( - f"Number of image tokens ({num_image_tokens}) does not match " - f"number of image embeddings ({image_embeds.shape[0]})." - ) - mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) - seq_embeds = seq_embeds.masked_scatter( - mask_expanded, image_embeds.to(seq_embeds.dtype) + if with_timing: + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + overall_start.record() + + if with_timing: + seq_tokens, seq_embeds, vision_time = _prepare_qwen_mm_inputs( + model, + pixel_values, + input_ids, + image_grid_thw, + emb_layer, + with_timing=True, + ) + else: + seq_tokens, seq_embeds = _prepare_qwen_mm_inputs( + model, + pixel_values, + input_ids, + image_grid_thw, + emb_layer, + with_timing=False, ) - # 3. KV-cache initialization kv_cache = get_zeroed_static_cache_inputs(model.model, device=device) start_idx = 0 end_idx = seq_embeds.size(1) generated = 0 - osl = max_output_seq_length - seq_tokens.shape[1] + max_total_len = end_idx + max_new_tokens output_tokens = seq_tokens.clone() + step_times = [] + + while output_tokens.size(1) < max_total_len: + if with_timing: + lm_start.record() - # 4. 
Greedy loop - while generated < osl: cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] if generated == 0: @@ -826,7 +723,6 @@ def generate_mm_qwen2_5_vl_with_static_cache( torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) ) else: - position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( cur_embeds.device ) @@ -840,13 +736,9 @@ def generate_mm_qwen2_5_vl_with_static_cache( ) outputs_and_kv = model.model(*input_signature) - # With the fix in static_cache_v1.py, the model output is now clean: - # (hidden_state, updated_kv_cache[72]) hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:] - # Use logit_pos to get the correct logit based on whether we padded or not. logits = model.lm_head(hidden_states[:, -1, :]) - next_tok = logits.argmax(dim=-1) output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) @@ -857,147 +749,22 @@ def generate_mm_qwen2_5_vl_with_static_cache( start_idx = end_idx end_idx += 1 - return output_tokens - - -def generate_mm_paligemma( - model, - pixel_values: torch.Tensor | None, - input_ids: torch.Tensor, - max_output_seq_length: int, - eos_token_id: int, - emb_layer: torch.nn.Embedding, -): - vit_embeds = None - if pixel_values is not None: - vit_out = model.vision_tower(pixel_values) - vit_embeds = model.multi_modal_projector(vit_out.last_hidden_state) - vit_embeds = vit_embeds / (model.config.text_config.hidden_size**0.5) - - seq_tokens = input_ids.clone() - seq_embeds = emb_layer(seq_tokens) - - if vit_embeds is not None: - B, N, C = seq_embeds.shape - flat = seq_embeds.view(B * N, C) - mask = seq_tokens.view(B * N) == model.config.image_token_index - flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] - seq_embeds = flat.view(B, N, C) - - B = seq_tokens.size(0) - cache_position = torch.arange(seq_tokens.size(1), device=seq_tokens.device) - position_ids = cache_position.unsqueeze(0) + 1 - - generated = 0 - while generated < max_output_seq_length: - causal_mask = model.model._update_causal_mask( - attention_mask=None, - token_type_ids=None, - past_key_values=None, - cache_position=cache_position, - input_tensor=seq_embeds, - is_training=False, - ) - - with torch.no_grad(): - out = model.language_model( - inputs_embeds=seq_embeds, - position_ids=position_ids, - attention_mask=causal_mask, - use_cache=False, - ) - logits = out.last_hidden_state if hasattr(out, "last_hidden_state") else out - - next_tok = torch.argmax(logits[:, -1, :], dim=-1) - seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) - seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) - - position_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=1) - cache_position = torch.arange(seq_tokens.size(1), device=seq_tokens.device) - - generated += 1 - if (next_tok == eos_token_id).all(): - break - - return seq_tokens - - -@torch.inference_mode() -def generate_mm_paligemma_with_static_cache( - model, - pixel_values: torch.Tensor | None, - input_ids: torch.Tensor, - max_output_seq_length: int, - eos_token_id: int, - emb_layer: torch.nn.Embedding, - device: str = "cuda:0", -) -> torch.LongTensor: - vit_embeds = None - if pixel_values is not None: - vit_latent = model.vision_tower(pixel_values) - vit_embeds = ( - vit_latent.last_hidden_state - if hasattr(vit_latent, "last_hidden_state") - else vit_latent - ) - vit_embeds = model.multi_modal_projector(vit_embeds) - vit_embeds = vit_embeds / (model.config.text_config.hidden_size**0.5) - - seq_tokens = input_ids.clone() - seq_embeds = 
emb_layer(seq_tokens)
-
-    if vit_embeds is not None:
-        B, N, C = seq_embeds.shape
-        flat = seq_embeds.view(B * N, C)
-        mask = seq_tokens.view(B * N) == model.image_token_index
-        flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()]
-        seq_embeds = flat.view(B, N, C)
-
-    kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device)
-    start_idx = 0
-    end_idx = seq_embeds.size(1)
-    generated = 0
-    max_total_len = max_output_seq_length
-    output_tokens = seq_tokens.clone()
-
-    while output_tokens.size(1) < max_total_len:
-        cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :]
-        if generated == 0:
-            position_ids = (
-                torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
-            )
-        else:
-            position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to(
-                cur_embeds.device
-            )
-        is_causal = True if cur_embeds.shape[1] > 1 else False
-        input_signature = (
-            cur_embeds,
-            position_ids,
-            *kv_cache,
-            start_idx,
-            end_idx,
-            is_causal,
-        )
-
-        logits_and_kv = model.language_model(*input_signature)
-        logits, kv_cache = logits_and_kv[0], logits_and_kv[1:]
-
-        next_tok = logits[:, -1, :].argmax(dim=-1)
-        output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1)
-
-        next_embed = emb_layer(next_tok)[:, None, :]
-        seq_embeds = next_embed
-
-        generated += 1
-        start_idx = end_idx
-        end_idx += 1
-        is_causal = True
+        if with_timing:
+            lm_end.record()
+            torch.cuda.synchronize()
+            step_times.append(lm_start.elapsed_time(lm_end))
 
         if (next_tok == eos_token_id).all():
             break
 
-    return output_tokens
+    if with_timing:
+        overall_end.record()
+        torch.cuda.synchronize()
+        overall_time = overall_start.elapsed_time(overall_end)
+        # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0.
+        return output_tokens, step_times, overall_time, vision_time, 0.0
+    else:
+        return output_tokens
 
 
 @torch.inference_mode()

From 7cef8b2045d779a137e42c97411d92211d5ed521 Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Mon, 11 Aug 2025 14:49:26 +0000
Subject: [PATCH 8/9] chore: slicing output token

---
 tools/llm/utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tools/llm/utils.py b/tools/llm/utils.py
index e62fcced35..9701385208 100644
--- a/tools/llm/utils.py
+++ b/tools/llm/utils.py
@@ -515,9 +515,9 @@ def generate_mm_with_static_cache(
         overall_end.record()
         torch.cuda.synchronize()
         overall_time = overall_start.elapsed_time(overall_end)
-        return output_tokens, step_times, overall_time, vision_time, mlp_time
+        return output_tokens[:, input_ids.shape[1]:], step_times, overall_time, vision_time, mlp_time
     else:
-        return output_tokens
+        return output_tokens[:, input_ids.shape[1]:]
 
 
 def _prepare_qwen_mm_inputs(
@@ -762,9 +762,9 @@ def generate_mm_qwen2_5_vl_with_static_cache(
         torch.cuda.synchronize()
         overall_time = overall_start.elapsed_time(overall_end)
         # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0.
- return output_tokens, step_times, overall_time, vision_time, 0.0 + return output_tokens[:, input_ids.shape[1]:], step_times, overall_time, vision_time, 0.0 else: - return output_tokens + return output_tokens[:, input_ids.shape[1]:] @torch.inference_mode() From 3a58d2b15087dcbff251bcd6f147351823e6f17c Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 11 Aug 2025 14:56:53 +0000 Subject: [PATCH 9/9] chore: minor linting --- tools/llm/utils.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 9701385208..77ef26b33a 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -515,9 +515,15 @@ def generate_mm_with_static_cache( overall_end.record() torch.cuda.synchronize() overall_time = overall_start.elapsed_time(overall_end) - return output_tokens[:, input_ids.shape[1]:], step_times, overall_time, vision_time, mlp_time + return ( + output_tokens[:, input_ids.shape[1] :], + step_times, + overall_time, + vision_time, + mlp_time, + ) else: - return output_tokens[:, input_ids.shape[1]:] + return output_tokens[:, input_ids.shape[1] :] def _prepare_qwen_mm_inputs( @@ -762,9 +768,15 @@ def generate_mm_qwen2_5_vl_with_static_cache( torch.cuda.synchronize() overall_time = overall_start.elapsed_time(overall_end) # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0. - return output_tokens[:, input_ids.shape[1]:], step_times, overall_time, vision_time, 0.0 + return ( + output_tokens[:, input_ids.shape[1] :], + step_times, + overall_time, + vision_time, + 0.0, + ) else: - return output_tokens[:, input_ids.shape[1]:] + return output_tokens[:, input_ids.shape[1] :] @torch.inference_mode()
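
For reference, a minimal usage sketch of the refactored generation helpers follows. It is illustrative only, not part of the patches above: `model`, `processor`, `emb_layer`, and `inputs` are assumed to come from the loading/preprocessing code elsewhere in this PR, and the static-cache call assumes `model.language_model` has already been compiled with the static KV-cache input signature.

# Minimal usage sketch (illustrative; names below are assumptions, not part of the patch).
# Assumes this snippet lives next to tools/llm/utils.py, `model`/`processor`/`emb_layer`
# are already loaded, and `inputs` holds processor outputs on the GPU
# ("input_ids", "pixel_values").
import torch

from utils import generate_mm, generate_mm_with_static_cache

eos_token_id = processor.tokenizer.eos_token_id  # assumed tokenizer attribute

with torch.inference_mode():
    # No-cache path: returns only the newly generated token ids (prompt is sliced off).
    new_tokens = generate_mm(
        model,
        inputs["pixel_values"],
        inputs["input_ids"],
        eos_token_id,
        emb_layer,
        max_new_tokens=64,
    )

    # Static-cache path with timing enabled: assumes model.language_model was compiled
    # with static KV-cache inputs. Returns a 5-tuple; all latencies are in milliseconds.
    new_tokens, step_ms, overall_ms, vision_ms, mlp_ms = generate_mm_with_static_cache(
        model,
        inputs["pixel_values"],
        inputs["input_ids"],
        eos_token_id,
        emb_layer,
        max_new_tokens=64,
        device="cuda:0",
        with_timing=True,
    )

print(processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True))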