Skip to content

Commit 7cef8b2

Browse files
committed
chore: slicing output token
1 parent da438a8 commit 7cef8b2

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tools/llm/utils.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -515,9 +515,9 @@ def generate_mm_with_static_cache(
515515
overall_end.record()
516516
torch.cuda.synchronize()
517517
overall_time = overall_start.elapsed_time(overall_end)
518-
return output_tokens, step_times, overall_time, vision_time, mlp_time
518+
return output_tokens[:, input_ids.shape[1]:], step_times, overall_time, vision_time, mlp_time
519519
else:
520-
return output_tokens
520+
return output_tokens[:, input_ids.shape[1]:]
521521

522522

523523
def _prepare_qwen_mm_inputs(
@@ -762,9 +762,9 @@ def generate_mm_qwen2_5_vl_with_static_cache(
762762
torch.cuda.synchronize()
763763
overall_time = overall_start.elapsed_time(overall_end)
764764
# For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0.
765-
return output_tokens, step_times, overall_time, vision_time, 0.0
765+
return output_tokens[:, input_ids.shape[1]:], step_times, overall_time, vision_time, 0.0
766766
else:
767-
return output_tokens
767+
return output_tokens[:, input_ids.shape[1]:]
768768

769769

770770
@torch.inference_mode()

0 commit comments

Comments
 (0)