vulkan: optimize rms_norm, and allow the work to spread across multiple SMs #15281
Conversation
There are really two parts to this change:

(1) Some optimizations similar to what we have in soft_max, to unroll with different numbers of iterations.

(2) A fusion optimization where we detect add followed by rms_norm, and make the add shader atomically accumulate the values^2 into memory. Then the rms_norm shader can just load that sum. This allows the rms_norm to be parallelized across multiple workgroups; it just becomes a simple per-element multiply.

The fusion optimization is currently only applied when the rms_norm is on a single vector, a case that previously always ran on a single SM. It could apply more broadly, but when there are other dimensions the work can already spread across SMs, and there would be some complexity to tracking multiple atomic sums.

Set to draft because there will be an interaction with #15252 when it's merged.
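A minimal sketch of the fused add path described in (2), for readers following along. The buffer layout, bindings, and the `fuse_rms` push-constant flag are illustrative, not the PR's actual interface:

```glsl
#version 450
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_EXT_shader_atomic_float : enable

layout(local_size_x = 128) in;

layout(binding = 0) readonly  buffer SrcA { float a[]; };
layout(binding = 1) readonly  buffer SrcB { float b[]; };
layout(binding = 2) writeonly buffer Dst  { float d[]; };
// single accumulator the following rms_norm dispatch reads back
layout(binding = 3) buffer Accum { float data_atom; };

layout(push_constant) uniform Params { uint n; uint fuse_rms; } p;

void main() {
    const uint i = gl_GlobalInvocationID.x;
    float sum_sq = 0.0;
    if (i < p.n) {
        const float v = a[i] + b[i]; // the add being fused
        d[i] = v;
        sum_sq = v * v;              // this element's v^2 contribution
    }
    if (p.fuse_rms != 0) {
        // reduce within the subgroup first, so only one lane per subgroup
        // touches the atomic
        sum_sq = subgroupAdd(sum_sq);
        if (gl_SubgroupInvocationID == 0 && sum_sq != 0.0) {
            atomicAdd(data_atom, sum_sq); // needs VK_EXT_shader_atomic_float
        }
    }
}
```

With the sum of squares already in memory, the rms_norm dispatch reduces to a per-element multiply and can spread across as many workgroups as the tensor needs.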
if (p.param3 != 0) {
    sum_sq = subgroupAdd(sum_sq);
    if (sum_sq != 0 && gl_SubgroupInvocationID == 0) {
        atomicAdd(data_atom, sum_sq);
Just want to point out that this potentially introduces a bit of nondeterminism, because floating point addition is not associative: the result can depend on the order in which workgroups reach the atomicAdd (for example, in float, (1e20 + -1e20) + 1 is 1, while 1e20 + (-1e20 + 1) is 0). I don't expect it to be a problem; just mentioning it in case anybody is concerned.
Hm, it's not a good idea to introduce nondeterminism in the computations. Are there alternatives?
Second commit changes this to write out a partial sum for each workgroup, and the rms_norm shader adds them up, so it's a deterministic order now.

Force-pushed from e0b01db to 075dac2: write out a partial sum per workgroup rather than using atomic add, to make it deterministic. The rms_norm shader fetches a subgroup's worth in parallel and uses subgroupAdd to add them up.
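A sketch of how that deterministic variant could look. These are fragments from the two dispatches, not a complete shader; the function and buffer names are mine, GL_KHR_shader_subgroup_arithmetic is assumed enabled, and the PR's actual shaders will differ:

```glsl
// Partial-sum buffer shared between the dispatches: written by the fused
// add shader, read by rms_norm.
layout(binding = 3) buffer Partial { float partial_sums[]; };

// --- fused add shader: store one partial sum per workgroup ---
// (assumes one subgroup per workgroup for brevity; with more, the
// subgroup sums would first be combined through shared memory)
void store_partial(float sum_sq) {
    sum_sq = subgroupAdd(sum_sq);
    if (gl_SubgroupInvocationID == 0) {
        partial_sums[gl_WorkGroupID.x] = sum_sq;
    }
}

// --- rms_norm shader: reduce the partial sums in a fixed order ---
float reduce_partials(const uint num_partials) {
    float total = 0.0;
    // Fetch a subgroup's worth at a time and combine with subgroupAdd.
    // The grouping is fixed by the loop, so the result no longer depends
    // on the order in which workgroups of the add dispatch finished.
    for (uint base = 0; base < num_partials; base += gl_SubgroupSize) {
        const uint idx = base + gl_SubgroupInvocationID;
        total += subgroupAdd(idx < num_partials ? partial_sums[idx] : 0.0);
    }
    return total;
}
```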
Perf results below. As expected, bigger gains on a bigger GPU, because the serial cost of rms_norm is more pronounced.