@@ -1385,8 +1385,11 @@ def forward(
13851385 )
13861386 else :
13871387 assert isinstance (attn_metadata .prefill , TRTLLMPrefill )
1388- # prefill_query may be non-contiguous
1389- prefill_query = prefill_query .contiguous ()
1388+ # prefill_query may be non-contiguous or have degenerate strides
1389+ # First ensure memory contiguity, then fix degenerate strides
1390+ # with reshape. contiguous() alone doesn't fix degenerate
1391+ # strides when a dimension has size 1.
1392+ prefill_query = prefill_query .contiguous ().reshape (prefill_query .shape )
13901393 workspace_buffer = _get_trtllm_gen_workspace_buffer ()
13911394 block_tables_prefill = attn_metadata .prefill .block_tables
13921395 seq_lens_prefill = attn_metadata .prefill .seq_lens
@@ -1495,9 +1498,12 @@ def forward(
14951498 out = output [:num_decode_tokens ],
14961499 )
14971500 else :
1498- # decode_query may be non-contiguous
1501+ # decode_query may be non-contiguous or have degenerate strides
14991502 assert isinstance (attn_metadata .decode , TRTLLMDecode )
1500- decode_query = decode_query .contiguous ()
1503+ # First ensure memory contiguity, then fix degenerate strides
1504+ # with reshape. contiguous() alone doesn't fix degenerate
1505+ # strides when a dimension has size 1.
1506+ decode_query = decode_query .contiguous ().reshape (decode_query .shape )
15011507 workspace_buffer = _get_trtllm_gen_workspace_buffer ()
15021508 block_tables_decode = attn_metadata .decode .block_tables
15031509 seq_lens_decode = attn_metadata .decode .seq_lens
0 commit comments