vulkan: Add State Space Model (SSM) Operations Support #16463
base: master
Thanks for this contribution!
warp_sdata[warp_offset + lane] = val;
barrier();

if (lane < 16) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 16];
This seems like it's assuming a subgroup size of 32 (also at line 37).
Do I understand correctly that this doesn't actually rely on a subgroup size of 32, but splits the workgroup into groups of 32 and just reduces those? (And it looks like some reduction across groups of 32 has already happened?)
Sorry, I missed this one. Yeah, I don't think it would work with a size != 32. I need to think this one through some more.
Do you have any suggestions on what I could do here?
I think this may work because you're not relying on SubgroupInvocationId or SubgroupID, you've just split the workgroup into groups of 32. Maybe we can just test it on AMD (with wave64) and Intel and verify that it works.
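To make the point about logical groups concrete, here is a minimal CPU-side C++ sketch of the indexing. It assumes `lane` and `warp_offset` are derived from the local invocation index as shown; that derivation and the `group_pos` helper are assumptions for illustration, since the shader's actual definitions are not quoted in this thread.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: position of a thread inside a logical group of
// `group_size` threads carved out of the workgroup.
struct GroupPos {
    uint32_t lane;        // index inside the logical group
    uint32_t warp_offset; // first shared-memory slot owned by the group
};

// `tid` stands in for gl_LocalInvocationIndex. Note that neither the
// hardware subgroup size nor any subgroup built-in appears here.
inline GroupPos group_pos(uint32_t tid, uint32_t group_size) {
    return { tid % group_size, (tid / group_size) * group_size };
}
```

The sketch only covers indexing. Whether every reduction step is separated by a barrier is a separate question: when the hardware subgroup is smaller than the logical group, the threads of one group are not guaranteed to run in lockstep.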
It works on Intel, but I am worried about all these settings that we made configurable. I haven't really tested how it behaves with different values of the constants we defined. Or is the assumption that these values should not be tweaked from vulkan-shaders-gen.cpp without also changing the implementation in the shader?
You could change this to a while loop that will handle any power-of-two value of WARP_SIZE. We do want to allow the spec constants to be changeable but it's fine to have limitations like "must be a power of two".
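A loop of the kind suggested here could look roughly like the following CPU simulation in C++. It is a sketch, not the PR's shader code: `reduce_groups` is a hypothetical name, and each pass of the outer loop stands in for one reduction step that a shader would separate with `barrier()`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU simulation of a shared-memory tree reduction over logical groups of
// `warp_size` threads. `warp_size` must be a power of two and must divide
// the workgroup size (sdata.size()).
inline void reduce_groups(std::vector<float> & sdata, size_t warp_size) {
    assert(warp_size > 0 && (warp_size & (warp_size - 1)) == 0);
    assert(sdata.size() % warp_size == 0);

    // Halve the active range each step: warp_size/2, warp_size/4, ..., 1.
    for (size_t offset = warp_size / 2; offset > 0; offset >>= 1) {
        for (size_t tid = 0; tid < sdata.size(); ++tid) {
            const size_t lane        = tid % warp_size;
            const size_t warp_offset = tid - lane;
            if (lane < offset) {
                sdata[warp_offset + lane] += sdata[warp_offset + lane + offset];
            }
        }
        // In the shader, barrier() would go here so that every step sees
        // the writes of the previous one.
    }
}
```

After the call, the first slot of each group holds that group's sum. The same loop works for group sizes of 8, 32, or 64 without hard-coding the offsets 16, 8, 4, 2, 1.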
Hmm wave64 AMD and wave8 llvmpipe are failing one test here, possibly due to this. All other tests are passing.
[SSM_SCAN] NMSE = 31335529439335960.000000000 > 0.000000100 SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4): FAIL
I think Intel also has a subgroup size of 32 so it wouldn't be a good test for this.
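For context on the failure line above: NMSE is a normalized mean squared error between the backend output and a reference result. The C++ sketch below assumes a common definition, the sum of squared differences divided by the sum of squared reference values; the exact formula used by the test harness is not shown in this thread.

```cpp
#include <cassert>
#include <cstddef>

// NMSE sketch (assumed definition): sum((ref - out)^2) / sum(ref^2).
// The caller must ensure the reference is not all zeros.
inline double nmse(const float * ref, const float * out, size_t n) {
    double err  = 0.0; // accumulated squared error
    double norm = 0.0; // accumulated squared reference magnitude
    for (size_t i = 0; i < n; ++i) {
        const double d = (double) ref[i] - (double) out[i];
        err  += d * d;
        norm += (double) ref[i] * (double) ref[i];
    }
    return err / norm;
}
```

Under this definition, a value many orders of magnitude above the 1e-7 threshold usually points to uninitialized or wrongly indexed data rather than to rounding error.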
return warp_sdata[warp_offset];
}

void main() {
Do all threads always load/store in bounds? In the host code there was some rounding up going on, which suggests maybe some threads don't correspond to in-bounds locations.
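The concern can be illustrated with a small CPU simulation in C++. It assumes one invocation per output element and a dispatch size rounded up to a multiple of the workgroup size; `ceil_div` and `run_guarded` are hypothetical names, not functions from the PR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Number of workgroups needed to cover n elements (rounds up).
inline uint32_t ceil_div(uint32_t n, uint32_t wg_size) {
    return (n + wg_size - 1) / wg_size;
}

// Simulates a rounded-up dispatch in which each invocation writes one
// element. Returns how many invocations the bounds guard skipped.
inline uint32_t run_guarded(std::vector<float> & dst, uint32_t wg_size) {
    const uint32_t n     = (uint32_t) dst.size();
    const uint32_t total = ceil_div(n, wg_size) * wg_size;
    uint32_t skipped = 0;
    for (uint32_t gid = 0; gid < total; ++gid) {
        if (gid >= n) { // invocation has no element to process
            ++skipped;
            continue;
        }
        dst[gid] = 1.0f;
    }
    return skipped;
}
```

In a shader that also calls `barrier()`, the out-of-bounds invocations should skip only the loads and stores rather than return early, because every invocation in the workgroup has to reach the barrier.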
Ping on this one. I don't really understand what this shader does and which locations it should be accessing.
I've tried to follow what the CUDA shader does. I'll spend more time on it and see if there is anything I can improve about memory access and make sure all the assumptions in the code are checked.
I've addressed the comments and pushed a new version. The results are even better now:
string_to_spv("ssm_scan_f32_d16", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_f32_d128", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_f32_d256", "ssm_scan.comp", {{"A_TYPE", "float"}});
These three are all identical now; you only need one.
warp_sdata[warp_offset + lane] = val;
barrier();

if (lane < 16) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 16];
You could change this to a while loop that will handle any power-of-two value of WARP_SIZE. We do want to allow the spec constants to be changeable but it's fine to have limitations like "must be a power of two".
I've completely replaced the code to reduce sum with In the last version I've also renamed
Be aware that not all devices support subgroup commands. If there's a performance advantage to using them, you can do that, but it would still need a fallback to using a shared memory reduction. If there isn't a performance advantage, just use the shared memory reduction for compatibility.
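The decision described above can be sketched in C++ as follows. `DeviceCaps`, `ReducePath`, and `pick_reduce_path` are hypothetical names for illustration; a real backend would fill the capability flag from the subgroup properties the device reports.

```cpp
#include <cassert>

// Hypothetical capability flag, filled from the device's reported
// subgroup properties in a real backend.
struct DeviceCaps {
    bool subgroup_arithmetic; // device supports subgroup arithmetic ops
};

enum class ReducePath { Subgroup, SharedMemory };

// Use the subgroup path only when it is both supported and measured to be
// faster; otherwise fall back to the shared-memory reduction, which works
// on every device.
inline ReducePath pick_reduce_path(const DeviceCaps & caps, bool subgroup_is_faster) {
    return (caps.subgroup_arithmetic && subgroup_is_faster)
        ? ReducePath::Subgroup
        : ReducePath::SharedMemory;
}
```

Keeping the shared-memory path as the default means a device without subgroup support never selects a pipeline it cannot run.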
With the subgroup code I get:
If I revert to the reduction loop, I get:
I've reverted to the version with a for loop. We can look at the subgroup optimization later.
742ebd7 to c467631
Add State Space Model scan operation to the Vulkan backend. Signed-off-by: Giuseppe Scrivano <[email protected]>
Add State Space Model conv operation to the Vulkan backend. Signed-off-by: Giuseppe Scrivano <[email protected]>
Implement SSM scan and SSM conv for Vulkan.
Tested on an NVIDIA L4:
master (4e0388a)
and db8b8bc: