Description
Qwen3-32B SFT performance regressed by 6x on NPU
Take Qwen3 as an example.
#154 patched some code in Qwen3ForCausalLM.forward to keep cu_seq_lens_x on CPU in the NPU case, but this led to a performance regression: based on profiling, a lot of communication can no longer be overlapped.
I am not sure whether the .cpu() call in Qwen3ForCausalLM.forward introduces a sync op, but everything is fine once I moved the .cpu() into Qwen3Attention.forward.
I opened #199 to patch the .cpu() code in the Attention class only; a rough sketch of the two placements is below.
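A minimal sketch of the two placements (not the actual repo code; the `cu_seq_lens_q` name follows the transformers flash-attention kwargs, and the signatures are simplified placeholders):

```python
# Placement from #154: the .cpu() call sits at the top of
# Qwen3ForCausalLM.forward. On NPU, .cpu() is a blocking device-to-host
# copy, so the host stalls here before any decoder layer is launched,
# which (per the profiling above) can keep communication from
# overlapping with compute.
def causal_lm_forward(self, input_ids, cu_seq_lens_q=None, **kwargs):
    if cu_seq_lens_q is not None:
        cu_seq_lens_q = cu_seq_lens_q.cpu()  # early, model-level sync point
    ...  # run the decoder layers

# Placement from #199: defer the .cpu() into Qwen3Attention.forward, so
# the device-to-host copy happens only right before the attention kernel
# that needs host-side lengths, and earlier ops can still be overlapped.
def attention_forward(self, hidden_states, cu_seq_lens_q=None, **kwargs):
    if cu_seq_lens_q is not None and cu_seq_lens_q.device.type != "cpu":
        cu_seq_lens_q = cu_seq_lens_q.cpu()  # sync only where it is required
    ...  # call the NPU attention kernel
```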
PS:
On the current main branch, in my case, memory usage increases by about 4 GB, which also makes empty_tensor hugely time-consuming. We can `export MULTI_STREAM_MEMORY_REUSE=2` and `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to optimize memory usage and avoid this problem, but I am still confused about why #154 adds an extra 4 GB of memory usage.
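For reference, a sketch of applying the same workaround from inside a Python entry point instead of the shell (assuming both variables are read when the NPU allocator initializes, so they must be set before importing torch / torch_npu):

```python
import os

# Set the allocator/stream config before any torch import, equivalent to
# exporting the variables in the shell before launching the job.
os.environ["MULTI_STREAM_MEMORY_REUSE"] = "2"
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402  # imported only after the allocator config is set
```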