Skip to content

Commit ae056e1

Browse files
committed
init
Signed-off-by: Sage Moore <[email protected]>
1 parent ae122b1 commit ae056e1

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

vllm/attention/backends/mla/common.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,7 @@ def _compute_prefill_context(
12821282
assert prefill_metadata.context_chunk_max_seq_lens is not None
12831283
assert prefill_metadata.context_lens_tensor is not None
12841284

1285+        has_context = prefill_metadata.context_lens_tensor.max() > 0
12851286
output = None
12861287
iters = len(prefill_metadata.context_chunk_seq_tot)
12871288

@@ -1322,7 +1323,8 @@ def _compute_prefill_context(
13221323
[0, q.shape[-1] - v.shape[-1]],
13231324
value=0)
13241325

1325-            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
1326+            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and \
1327+                    has_context is False:
13261328
attn_output, attn_softmax_lse = self.triton_fa_func(
13271329
q,
13281330
k,
@@ -1411,7 +1413,7 @@ def _forward_prefill(
14111413
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
14121414
value=0)
14131415

1414-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
1416+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and has_context is False:
14151417
output = self.triton_fa_func(
14161418
q,
14171419
k,

vllm/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3433,9 +3433,9 @@ def __post_init__(self):
34333433
self.compilation_config.level = CompilationLevel.NO_COMPILATION
34343434

34353435
if self.model_config and self.model_config.use_mla and \
3436-                not current_platform.is_cuda():
3436+                not (current_platform.is_cuda() or current_platform.is_rocm()):
34373437
logger.info(
3438-                "MLA is enabled on a non-cuda platform; forcing chunked "
3438+                "MLA is enabled on a non-GPU platform; forcing chunked "
34393439
"prefill and prefix caching to be disabled.")
34403440
self.scheduler_config.enable_chunked_prefill = False
34413441
self.scheduler_config.chunked_prefill_enabled = False

0 commit comments

Comments
 (0)