
Commit 85d5aa7

[None][feat] Support kv_cache_reuse for HyperCLOVAX-Vision model (#7789)
Signed-off-by: yechank <[email protected]>
1 parent 984d4fe commit 85d5aa7
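
For orientation, KV cache block reuse is usually switched on through the LLM API's KvCacheConfig. The snippet below is a minimal sketch, not part of this commit; the checkpoint id and the assumption that enable_block_reuse is the relevant switch for this model are illustrative assumptions.

# Minimal sketch, not taken from this commit. Assumptions: the checkpoint id below
# and that KvCacheConfig(enable_block_reuse=True) is the relevant reuse switch.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",  # assumed example id
    kv_cache_config=KvCacheConfig(enable_block_reuse=True),
)
outputs = llm.generate(["Describe the attached image."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)

With block reuse enabled, requests that share a prompt prefix can reuse previously computed KV cache blocks instead of recomputing them.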

File tree

7 files changed (+312 / -212 lines)


docs/source/models/supported-models.md

Lines changed: 19 additions & 18 deletions
@@ -10,15 +10,16 @@ The following is a table of supported models for the PyTorch backend:
 | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3` |
 | `Exaone4ForCausalLM` | EXAONE 4.0 | `LGAI-EXAONE/EXAONE-4.0-32B` |
 | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` |
+| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b` |
 | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA | `meta-llama/Meta-Llama-3.1-70B` |
 | `Llama4ForConditionalGeneration` | Llama 4 | `meta-llama/Llama-4-Scout-17B-16E-Instruct` |
 | `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` |
 | `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` |
 | `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` |
 | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base` |
 | `NemotronNASForCausalLM` | NemotronNAS | `nvidia/Llama-3_3-Nemotron-Super-49B-v1` |
-| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/Qwen2-7B-Instruct` |
 | `Phi3ForCausalLM` | Phi-4 | `microsoft/Phi-4` |
+| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/Qwen2-7B-Instruct` |
 | `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B` |
 | `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B` |
 | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B` |
@@ -32,30 +33,30 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl

 | Model Architecture/Feature | Overlap Scheduler | CUDA Graph | Attention Data Parallelism | Disaggregated Serving | Chunked Prefill | MTP | EAGLE-3(One Model Engine) | EAGLE-3(Two Model Engine) | Torch Sampler | TLLM C++ Sampler | KV Cache Reuse | Sliding Window Attention | Logits Post Processor | Guided Decoding |
 | ------------------------------ | ----------------- | ---------- | -------------------------- | --------------------- | --------------- | --- | ------------------------- | ------------------------- | ------------- | ---------------- | -------------- | ------------------------ | --------------------- | --------------- |
-| DeepseekV3ForCausalLM | Yes | Yes | Yes | Yes | Yes [^1] | Yes | No | No | Yes | Yes | Yes [^2] | N/A | Yes | Yes |
-| Qwen3MoeForCausalLM | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | N/A | Yes | Yes |
-| Qwen3NextForCausalLM | Yes | Yes | No | Untested | Yes | No | No | No | Yes | Yes | No | No | Untested | Untested |
-| Llama4ForConditionalGeneration | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Untested | N/A | Yes | Yes |
-| GPT-OSS | Yes | Yes | Yes | Yes | No | No | Yes | No | Yes | Yes | No | N/A | Yes | Yes |
+| `DeepseekV3ForCausalLM` | Yes | Yes | Yes | Yes | Yes [^1] | Yes | No | No | Yes | Yes | Yes [^2] | N/A | Yes | Yes |
+| `Qwen3MoeForCausalLM` | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | N/A | Yes | Yes |
+| `Qwen3NextForCausalLM` | Yes | Yes | No | Untested | Yes | No | No | No | Yes | Yes | No | No | Untested | Untested |
+| `Llama4ForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Untested | N/A | Yes | Yes |
+| `GptOssForCausalLM` | Yes | Yes | Yes | Yes | No | No | Yes | No | Yes | Yes | No | N/A | Yes | Yes |

 [^1]: Chunked Prefill for MLA can only be enabled on SM100.
 [^2]: KV cache reuse for MLA can only be enabled on SM90/SM100 and in BF16/FP8 KV cache dtype.


 # Multimodal Feature Support Matrix (PyTorch Backend)

-| Model Architecture/Feature | Overlap Scheduler | CUDA Graph | Chunked Prefill | Torch Sampler | TLLM C++ Sampler | KV Cache Reuse | Logits Post Processor | EPD Disaggregated Serving | Modality |
-| ---------------------------------- | ----------------- | ---------- | --------------- | ------------- | ---------------- | -------------- | --------------------- | ------------------------- | -------- |
-| Gemma3ForConditionalGeneration | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
-| HCXVisionForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
-| LlavaLlamaModel (VILA) | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + V |
-| LlavaNextForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
-| Llama4ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
-| Mistral3ForConditionalGeneration | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I |
-| NemotronH_Nano_VL_V2 | Yes | Yes | Yes | Yes | Yes | No | Yes | No | L + I + V |
-| Phi4MMForCausalLM | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I + A |
-| Qwen2VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
-| Qwen2_5_VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
+| Model Architecture/Feature | Overlap Scheduler | CUDA Graph | Chunked Prefill | Torch Sampler | TLLM C++ Sampler | KV Cache Reuse | Logits Post Processor | EPD Disaggregated Serving | Modality |
+| ------------------------------------ | ----------------- | ---------- | --------------- | ------------- | ---------------- | -------------- | --------------------- | ------------------------- | --------- |
+| `Gemma3ForConditionalGeneration` | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
+| `HCXVisionForCausalLM` | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I |
+| `LlavaLlamaModel (VILA)` | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + V |
+| `LlavaNextForConditionalGeneration` | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
+| `Llama4ForConditionalGeneration` | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
+| `Mistral3ForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I |
+| `NemotronH_Nano_VL_V2` | Yes | Yes | Yes | Yes | Yes | No | Yes | No | L + I + V |
+| `Phi4MMForCausalLM` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I + A |
+| `Qwen2VLForConditionalGeneration` | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
+| `Qwen2_5_VLForConditionalGeneration` | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |

 Note:
 - L: Language

tensorrt_llm/_torch/models/modeling_clip.py

Lines changed: 23 additions & 13 deletions
@@ -182,29 +182,39 @@ def __init__(self, model_config: ModelConfig[CLIPVisionConfig]):
         self.model_config = model_config
         self.config = self.model_config.pretrained_config  # HF Vision Config
         self.vision_model = CLIPVisionTransformer(self.model_config)
+
+        # Needed for prepare_attn_metadata
+        self.image_size = self.config.image_size
+        self.patch_size = self.config.patch_size
+
         self.metadata_cls = get_attention_backend(
             model_config.attn_backend).Metadata
+        self.attn_metadata = self.metadata_cls(
+            max_num_requests=
+            8192,  #TODO(yechank-nvidia): Make this along with the LLM's max_num_requests
+            max_num_tokens=model_config.max_num_tokens,
+            kv_cache_manager=None,
+        )

     def prepare_attn_metadata(self, batch_size):
         """
         To simplify the usage of the model, this function aims to fill the metadata for Attention
         Call this function before forward pass
         """
-        seq_len = (self.config.image_size // self.config.patch_size)**2 + 1
+        seq_len = (self.image_size // self.patch_size)**2 + 1
         request_ids = list(range(1, batch_size + 1))
         prompt_lens = [seq_len] * batch_size
-        attn_metadata = self.metadata_cls(
-            seq_lens=torch.tensor([seq_len] * batch_size, dtype=torch.int),
-            num_contexts=batch_size,
-            max_num_requests=batch_size,
-            max_num_tokens=seq_len * batch_size,
-            kv_cache_manager=None,
-            request_ids=request_ids,
-            prompt_lens=prompt_lens,
-        )
-        attn_metadata.max_seq_len = seq_len
-        attn_metadata.prepare()
-        return attn_metadata
+        seq_lens = torch.tensor([seq_len] * batch_size,
+                                dtype=torch.int,
+                                pin_memory=True)
+
+        self.attn_metadata.num_contexts = batch_size
+        self.attn_metadata.request_ids = request_ids
+        self.attn_metadata.prompt_lens = prompt_lens
+        self.attn_metadata.seq_lens = seq_lens
+        self.attn_metadata.max_seq_len = seq_len
+        self.attn_metadata.prepare()
+        return self.attn_metadata

     @property
     def dtype(self):
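
For a sense of the sequence length this metadata describes: each image contributes (image_size // patch_size)**2 patch tokens plus one class token. The numbers below assume a hypothetical CLIP ViT-L/14-style tower at 336 px, chosen only for illustration and not taken from this commit.

# Hypothetical CLIP-style configuration (image_size=336 and patch_size=14 are assumptions).
image_size, patch_size = 336, 14
seq_len = (image_size // patch_size) ** 2 + 1   # 24 * 24 patches + 1 CLS token = 577
batch_size = 4
total_vision_tokens = seq_len * batch_size      # 2308 tokens routed through the vision encoder

# With the metadata object now cached on the module, each forward pass only
# refreshes the per-batch fields:
#     attn_metadata = vision_model.prepare_attn_metadata(batch_size)
# and the same object is returned for use in the vision transformer's forward.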
