Describe the issue
In some models, softmax performance drops by 70% after PyTorch PR pytorch/pytorch#162447, which changes the number of warps used for reductions and causes register spilling.
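As a rough illustration of why the warp count matters here, the sketch below estimates how many values each thread has to keep live when a single program reduces one row. The row length, the warp counts, and the 32 threads per warp are assumptions chosen for illustration. They are not taken from the PR or from the reproducer, and this is not the actual Inductor heuristic.

```python
def elements_per_thread(row_len: int, num_warps: int, threads_per_warp: int = 32) -> int:
    """Estimate the values each thread holds when one program reduces one row.

    threads_per_warp=32 is an assumption; the real sub-group size depends on
    the device and on compiler settings.
    """
    threads = num_warps * threads_per_warp
    # Ceiling division: a thread that is only partly filled still needs registers.
    return -(-row_len // threads)


# Hypothetical row length and warp counts, for illustration only.
ROW_LEN = 8192
for num_warps in (32, 16, 8, 4):
    print(f"num_warps={num_warps:2d} -> {elements_per_thread(ROW_LEN, num_warps)} elements per thread")
```

Under these assumptions, every halving of the warp count doubles the per-thread working set, and once that set exceeds the available registers the compiler has to spill to memory.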
Reproducer:
softmax-new.py
Environment details
pytorch: f05e23e1bc1439e19145e43e8ffca0051cda2f33
triton: pytorch_triton_xpu-3.5.0+git1b0418a9