4 files changed, +15 -6 lines changed

@@ -269,8 +269,8 @@ def chunked_prefill_paged_decode(
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert key_cache.dtype == torch.uint8
-        assert value_cache.dtype == torch.uint8
+        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()
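The relaxed assertion accepts a KV cache handed over either as raw uint8 storage or already in the platform's native FP8 dtype. A minimal standalone sketch of that check and the zero-copy reinterpretation the surrounding code performs (the concrete dtype torch.float8_e4m3fn and the .view() call are illustrative assumptions, not taken from this diff):

import torch

# Assumed platform FP8 dtype; vLLM resolves this via current_platform.fp8_dtype().
fp8_dtype = torch.float8_e4m3fn

key_cache = torch.zeros(4, 16, dtype=torch.uint8)  # raw uint8 storage

# Accept either representation, mirroring the relaxed assertion above.
assert key_cache.dtype in [torch.uint8, fp8_dtype]

# Reinterpret the uint8 bytes as FP8 without copying when needed.
if key_cache.dtype == torch.uint8:
    key_cache = key_cache.view(fp8_dtype)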

@@ -749,8 +749,8 @@ def context_attention_fwd(q,
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert (k_cache.dtype == torch.uint8)
-        assert (v_cache.dtype == torch.uint8)
+        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()
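Both touched call sites then resolve the Triton-visible dtype from the kv_cache_dtype string. A hedged sketch of that resolution step, with torch.float8_e4m3fn standing in for whatever current_platform.fp8_dtype() returns on the running hardware:

import torch

# Stand-in for current_platform.fp8_dtype(); e4m3fn is an assumption here.
PLATFORM_FP8 = torch.float8_e4m3fn

def resolve_target_dtype(kv_cache_dtype: str) -> torch.dtype:
    # Mirrors the branch shown in both hunks: "fp8" and "fp8_e4m3"
    # map to the platform FP8 dtype used to reinterpret the cache.
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return PLATFORM_FP8
    raise ValueError(f"unhandled kv_cache_dtype: {kv_cache_dtype}")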

@@ -20,6 +20,7 @@
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
     VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -331,6 +332,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in
              ("true", "1")),
 
+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Internal flag to enable/disable Inductor standalone compile
     "VLLM_TEST_STANDALONE_COMPILE":
     lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
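With the entry registered, the flag can be toggled from the environment before the attention backend is selected. A small usage sketch; the variable name and the "true"/"1" parsing are taken from the diff, while the exact moment vllm.envs evaluates the lambda is an assumption:

import os

# Equivalent to `export VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1` in the shell;
# "1" or "true" (any case) is parsed as True by the lambda above.
os.environ["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "1"

from vllm import envs

assert envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION is True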

@@ -5,6 +5,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -167,8 +168,8 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
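The new condition keeps the existing fallback: the split prefill/decode path is taken when the env flag forces it, or when the GQA ratio (query heads per KV head) is not a power of two, which is what the (n & (n - 1)) != 0 bit trick detects. A small illustration of that selection logic; the helper name is made up for the example:

def use_prefill_decode_attn(num_queries_per_kv: int, force: bool = False) -> bool:
    # Split path when forced via VLLM_V1_USE_PREFILL_DECODE_ATTENTION, or
    # when the GQA ratio is not a power of two.
    return force or (num_queries_per_kv & (num_queries_per_kv - 1)) != 0

assert use_prefill_decode_attn(8) is False   # power of two: unified Triton kernel
assert use_prefill_decode_attn(6) is True    # not a power of two: split kernels
assert use_prefill_decode_attn(8, force=True) is True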