Skip to content

Commit 6101a26

Browse files
authored
[BUGFIX] Fix degenerate strides in TRTLLM query tensors for FlashInfer backend. Fixes issue vllm-project#32353 (vllm-project#32417)
Signed-off-by: Vadim Gimpelson <[email protected]>
1 parent f5d1740 commit 6101a26

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,8 +1385,11 @@ def forward(
13851385
)
13861386
else:
13871387
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
1388-
# prefill_query may be non-contiguous
1389-
prefill_query = prefill_query.contiguous()
1388+
# prefill_query may be non-contiguous or have degenerate strides
1389+
# First ensure memory contiguity, then fix degenerate strides
1390+
# with reshape. contiguous() alone doesn't fix degenerate
1391+
# strides when a dimension has size 1.
1392+
prefill_query = prefill_query.contiguous().reshape(prefill_query.shape)
13901393
workspace_buffer = _get_trtllm_gen_workspace_buffer()
13911394
block_tables_prefill = attn_metadata.prefill.block_tables
13921395
seq_lens_prefill = attn_metadata.prefill.seq_lens
@@ -1495,9 +1498,12 @@ def forward(
14951498
out=output[:num_decode_tokens],
14961499
)
14971500
else:
1498-
# decode_query may be non-contiguous
1501+
# decode_query may be non-contiguous or have degenerate strides
14991502
assert isinstance(attn_metadata.decode, TRTLLMDecode)
1500-
decode_query = decode_query.contiguous()
1503+
# First ensure memory contiguity, then fix degenerate strides
1504+
# with reshape. contiguous() alone doesn't fix degenerate
1505+
# strides when a dimension has size 1.
1506+
decode_query = decode_query.contiguous().reshape(decode_query.shape)
15011507
workspace_buffer = _get_trtllm_gen_workspace_buffer()
15021508
block_tables_decode = attn_metadata.decode.block_tables
15031509
seq_lens_decode = attn_metadata.decode.seq_lens

0 commit comments

Comments
 (0)