-
Notifications
You must be signed in to change notification settings - Fork 13.4k
vulkan: Optimize SSM_SCAN #16645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
vulkan: Optimize SSM_SCAN #16645
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
#version 450 | ||
|
||
#extension GL_EXT_control_flow_attributes : require | ||
#if USE_SUBGROUP_ADD | ||
#extension GL_KHR_shader_subgroup_arithmetic : enable | ||
#endif | ||
|
||
#include "types.glsl" | ||
|
||
|
@@ -84,35 +87,47 @@ void main() { | |
} | ||
|
||
barrier(); | ||
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) { | ||
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) { | ||
const uint k = (tid % (w >> 1)) + | ||
(D_STATE * (tid / (w >> 1))) + | ||
j * D_STATE * (D_STATE / (w >> 1)); | ||
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) { | ||
stateC[k] += stateC[k + (w >> 1)]; | ||
[[unroll]] | ||
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) { | ||
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) { | ||
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w); | ||
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) { | ||
stateC[k] += stateC[k + w]; | ||
} | ||
} | ||
barrier(); | ||
} | ||
|
||
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was one too many iterations most of the time, leading to extra subgroup ops or barriers. But when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for fixing this one! |
||
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) { | ||
const uint idx = (tid % SUBGROUP_SIZE) + | ||
D_STATE * (tid / SUBGROUP_SIZE) + | ||
j * D_STATE * (D_STATE / SUBGROUP_SIZE); | ||
const uint max_idx = SUBGROUP_SIZE - 1 + | ||
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) + | ||
j * D_STATE * (D_STATE / SUBGROUP_SIZE); | ||
|
||
uint lane = tid % SUBGROUP_SIZE; | ||
|
||
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { | ||
if (idx + offset < SPLIT_H * D_STATE) { | ||
stateC[idx] += stateC[idx + offset]; | ||
if (idx < SPLIT_H * D_STATE || | ||
max_idx < SPLIT_H * D_STATE) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This max_idx comparison should fold away and avoid the need for the branch most of the time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just curious, is this needed to get the subgroup to run the same code? |
||
float sc; | ||
#if USE_SUBGROUP_ADD | ||
sc = stateC[idx]; | ||
sc = subgroupAdd(sc); | ||
#else | ||
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { | ||
if (idx + offset < SPLIT_H * D_STATE) { | ||
stateC[idx] += stateC[idx + offset]; | ||
} | ||
barrier(); | ||
} | ||
barrier(); | ||
} | ||
if (tid % SUBGROUP_SIZE == 0) { | ||
sc = stateC[idx]; | ||
} | ||
#endif | ||
|
||
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) { | ||
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); | ||
d[y_base_idx + i * stride_y + k] = stateC[idx]; | ||
if (tid % SUBGROUP_SIZE == 0) { | ||
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); | ||
d[y_base_idx + i * stride_y + k] = sc; | ||
} | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shifted the values of
w
down by a factor of 2, rather than usingw>>1
everywhere. This seemed to be causing some weird code to be generated.