Commit e7c9c64

IwakuraRein, gemini-code-assist[bot], mgoin, zyongye authored and committed
Fix trtllm-gen attention env and add attention sink (vllm-project#22378)
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Lain <[email protected]>
Signed-off-by: Yongye Zhu <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Yongye Zhu <[email protected]>
1 parent 24e18b2 commit e7c9c64

File tree: 5 files changed, +21 -28 lines

vllm/envs.py

Lines changed: 4 additions & 9 deletions

@@ -152,8 +152,7 @@
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
-    VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
-    VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
+    VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False

@@ -1043,13 +1042,9 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),

-    # If set to 1, use the TRTLLM Context Attention backend in flashinfer.
-    "VLLM_USE_TRTLLM_CONTEXT_ATTENTION":
-    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))),
-
-    # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
-    "VLLM_USE_TRTLLM_DECODE_ATTENTION":
-    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))),
+    # If set to 1, use the TRTLLM attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_ATTENTION":
+    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),

     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
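
For illustration, a minimal sketch of the new getter's behavior (read_trtllm_attention_env is a hypothetical stand-in for the envs.py lambda, not vLLM code): when the variable is unset the getter yields None, and any explicit value comes back as the raw string for downstream code to interpret.

import os
from typing import Optional

def read_trtllm_attention_env() -> Optional[str]:
    # Mirrors the new getter: raw string if set, None when unset.
    return os.getenv("VLLM_USE_TRTLLM_ATTENTION", None)

# Unset -> None: downstream code falls back to auto-detection.
os.environ.pop("VLLM_USE_TRTLLM_ATTENTION", None)
assert read_trtllm_attention_env() is None

# "0" forces the TRTLLM path off, "1" forces it on (see vllm/utils/flashinfer.py).
os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "1"
assert read_trtllm_attention_env() == "1"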

vllm/model_executor/models/gpt_oss.py

Lines changed: 2 additions & 3 deletions

@@ -70,9 +70,8 @@ def __init__(

         tp_size = get_tensor_model_parallel_world_size()

-        attention_sink_dtype = (
-            torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
-            or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16)
+        attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
+                                else torch.bfloat16)
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
                         dtype=attention_sink_dtype,
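
For background, a rough self-contained sketch (my own illustration, not vLLM or TRTLLM code; attention_with_sink is a made-up name) of how a learnable per-head sink logit enters attention: it adds one extra term to the softmax denominator without contributing a value, which is why the module above keeps one sink parameter per local attention head. The dtype switch presumably reflects what the trtllm-gen kernels expect for the sinks tensor (float32) versus the default path (bfloat16).

import torch

def attention_with_sink(q, k, v, sink_logit):
    # q: [heads, q_len, d], k/v: [heads, kv_len, d], sink_logit: [heads]
    scores = torch.einsum("hqd,hkd->hqk", q, k) / q.shape[-1] ** 0.5
    # The sink adds one extra logit per head to the softmax denominator
    # but contributes no value, softly absorbing attention mass.
    sink = sink_logit.view(-1, 1, 1).expand(-1, scores.shape[1], 1)
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)

h, d = 4, 8
out = attention_with_sink(torch.randn(h, 3, d), torch.randn(h, 5, d),
                          torch.randn(h, 5, d), torch.zeros(h))
print(out.shape)  # torch.Size([4, 3, 8])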

vllm/utils/flashinfer.py

Lines changed: 4 additions & 4 deletions

@@ -159,7 +159,7 @@ def use_trtllm_attention(

     # Check if the dimensions are supported by TRTLLM decode attention
     if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
-            or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
+            or num_qo_heads % num_kv_heads != 0):
         return False

     env_value = envs.VLLM_USE_TRTLLM_ATTENTION

@@ -169,10 +169,10 @@ def use_trtllm_attention(
         # Making the conditional check for zero because
         # the path is automatically enabled if the batch size condition
         # is satisfied.
-        no_use_trtllm = (env_value == "0")
-        if not no_use_trtllm:
+        use_trtllm = (env_value == "1")
+        if use_trtllm:
             logger.info_once("Using TRTLLM attention.")
-        return not no_use_trtllm
+        return use_trtllm
     else:
         # Environment variable not set - use auto-detection
         use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
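
Taken together, the decision now has three states, roughly as in this simplified sketch (should_use_trtllm is an illustrative name; the real use_trtllm_attention also validates head counts and KV-cache dtype, and its auto-detection condition continues beyond what this diff shows):

import os
from typing import Optional

def should_use_trtllm(num_tokens: int, max_seq_len: int) -> bool:
    env_value: Optional[str] = os.getenv("VLLM_USE_TRTLLM_ATTENTION", None)
    if env_value is not None:
        # Explicit override: "1" forces the TRTLLM path on, anything else off.
        return env_value == "1"
    # Unset: auto-detect from batch size and sequence length
    # (the real function applies further checks).
    return num_tokens <= 256 and max_seq_len < 131072

os.environ.pop("VLLM_USE_TRTLLM_ATTENTION", None)
print(should_use_trtllm(num_tokens=128, max_seq_len=4096))  # True (auto-detected)
os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "0"
print(should_use_trtllm(num_tokens=128, max_seq_len=4096))  # False (forced off)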

vllm/v1/attention/backends/flashinfer.py

Lines changed: 9 additions & 8 deletions

@@ -215,6 +215,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self._cascade_wrapper = None  # Wrapper for cascade attention

         # Global hyperparameters shared by all attention layers
+        # TODO: discard this for trtllm-gen backend
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

@@ -523,16 +524,12 @@ def build(self,
         head_dim = self.kv_cache_spec.head_size

         # currently prefill trtllm attention does not support fp8 kv cache
-        # trtllm may not support sliding window
-        prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
-                              and not cache_dtype.startswith("fp8")
-                              and use_trtllm_attention(
+        prefill_use_trtllm = use_trtllm_attention(
             num_prefill_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim))
-        decode_use_trtllm = (self.global_hyperparameters.window_left == -1
-                             and use_trtllm_attention(
+            num_qo_heads, num_kv_heads, head_dim)
+        decode_use_trtllm = use_trtllm_attention(
             num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim))
+            num_qo_heads, num_kv_heads, head_dim)

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,

@@ -793,6 +790,8 @@ def forward(
                 batch_size=attn_metadata.num_prefills,
                 cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                 cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
+                window_left=window_left,
+                sinks=self.sinks,
                 out=output[num_decode_tokens:],
             )

@@ -839,6 +838,8 @@ def forward(
                 max_seq_len=attn_metadata.max_seq_len,
                 bmm1_scale=layer._k_scale_float * self.scale,
                 bmm2_scale=layer._v_scale_float,
+                window_left=window_left,
+                sinks=self.sinks,
                 out=output[:num_decode_tokens],
             )
         return output_padded
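
The builder now asks the same question separately for the prefill and decode token counts, so the two phases can land on different kernels. A self-contained sketch of that pattern (choose_kernels and PhaseKernelChoice are illustrative names, and the inner heuristic stands in for the fuller checks in use_trtllm_attention):

from dataclasses import dataclass

@dataclass
class PhaseKernelChoice:
    prefill_use_trtllm: bool
    decode_use_trtllm: bool

def choose_kernels(num_prefill_tokens: int, num_decode_tokens: int,
                   max_seq_len: int) -> PhaseKernelChoice:
    # Simplified stand-in for use_trtllm_attention(): batch-size/seq-len
    # heuristic only; the real function also validates head counts, dtype, etc.
    def auto(num_tokens: int) -> bool:
        return num_tokens <= 256 and max_seq_len < 131072
    return PhaseKernelChoice(prefill_use_trtllm=auto(num_prefill_tokens),
                             decode_use_trtllm=auto(num_decode_tokens))

print(choose_kernels(num_prefill_tokens=1024, num_decode_tokens=64,
                     max_seq_len=8192))
# PhaseKernelChoice(prefill_use_trtllm=False, decode_use_trtllm=True)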

vllm/v1/attention/backends/utils.py

Lines changed: 2 additions & 4 deletions

@@ -254,8 +254,7 @@ def get_kv_cache_layout():
     # Override with format specified by the user.
     cache_layout = envs.VLLM_KV_CACHE_LAYOUT
     if cache_layout is None:
-        if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
-                or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+        if envs.VLLM_USE_TRTLLM_ATTENTION:
             cache_layout = "HND"
         else:
             cache_layout = get_kv_connector_cache_layout()

@@ -333,8 +332,7 @@ def infer_global_hyperparameters(
     global_params = param_sets[0]

     # trtllm attention doesn't need global hyper params so disable the check
-    if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
-            and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+    if not envs.VLLM_USE_TRTLLM_ATTENTION:
         for params in param_sets:
             if params.window_left != global_params.window_left:
                 raise ValueError(
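
The layout choice reduces to a small cascade, sketched below (resolve_kv_cache_layout is an illustrative name, and the KV-connector fallback is represented by a plain "NHD" default here):

import os
from typing import Optional

def resolve_kv_cache_layout(user_layout: Optional[str] = None) -> str:
    # An explicit user-specified layout always wins.
    if user_layout is not None:
        return user_layout
    # Any non-empty VLLM_USE_TRTLLM_ATTENTION value (including "0") is truthy
    # here, mirroring the `if envs.VLLM_USE_TRTLLM_ATTENTION:` check above.
    if os.getenv("VLLM_USE_TRTLLM_ATTENTION"):
        return "HND"
    # Otherwise defer to the KV connector's preference (plain default here).
    return "NHD"

os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "1"
print(resolve_kv_cache_layout())       # HND
print(resolve_kv_cache_layout("NHD"))  # NHD, user override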

0 commit comments