Commit 234ae7d

vulkan: skip all-negative-inf blocks in FA (#17186)
1 parent: 38eaf32

File tree

4 files changed (+110 / -37 lines changed)

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 2 deletions
@@ -521,6 +521,7 @@ struct vk_device_struct {
     bool subgroup_shuffle;
     bool subgroup_ballot;
     bool subgroup_clustered;
+    bool subgroup_vote;
     bool multi_add;
     bool shader_int64;
     bool buffer_device_address;
@@ -4188,6 +4189,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                               (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
 
+    device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                            (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
+
     const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
     device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -13572,8 +13576,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         default:
             return false;
     }
-    if (!coopmat2 && !device->subgroup_shuffle) {
-        // scalar FA uses subgroupShuffle
+    if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
+        // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
         return false;
     }
     return true;
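
The hunks above add a subgroup_vote capability flag alongside subgroup_shuffle and subgroup_ballot, fill it in from the Vulkan 1.1 subgroup properties, and require it before advertising flash attention on the non-coopmat2 paths, since the new block-skip test relies on subgroupAll(). As a rough standalone illustration of the same capability query (the function name and structure are hypothetical, not part of the ggml code), a minimal Vulkan-Hpp sketch might look like this:

```cpp
// Minimal sketch, assuming Vulkan 1.1 and Vulkan-Hpp; not the ggml code.
#include <vulkan/vulkan.hpp>

bool supports_compute_subgroup_vote(vk::PhysicalDevice phys_dev) {
    vk::PhysicalDeviceSubgroupProperties subgroup_props{};
    vk::PhysicalDeviceProperties2 props2{};
    props2.pNext = &subgroup_props;      // chain the subgroup properties query
    phys_dev.getProperties2(&props2);

    // Same two conditions as device->subgroup_vote in the hunk above:
    // vote ops must be supported, and usable from compute shaders.
    const bool in_compute = static_cast<bool>(
        subgroup_props.supportedStages & vk::ShaderStageFlagBits::eCompute);
    const bool has_vote = static_cast<bool>(
        subgroup_props.supportedOperations & vk::SubgroupFeatureFlagBits::eVote);
    return in_compute && has_vote;
}
```

A device that exposes shuffle but not vote now fails the supports_op check for scalar/coopmat1 FA rather than reaching a shader that calls subgroupAll().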

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 33 additions & 15 deletions
@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #extension GL_KHR_shader_subgroup_shuffle : enable
+#extension GL_KHR_shader_subgroup_vote : enable
 
 #include "types.glsl"
 #include "flash_attn_base.glsl"
@@ -108,6 +109,38 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+            float max_mask = NEG_FLT_MAX_OVER_2;
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) % Bc;
+                uint32_t r = (idx + tid) / Bc;
+                if (idx + tid < Bc * Br) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                        masksh[c][r] = m;
+                        max_mask = max(max_mask, m);
+                    } else {
+                        masksh[c][r] = float(0);
+                    }
+                }
+            }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+            barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
+        }
+
         float Sf[Br][cols_per_thread];
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
@@ -153,21 +186,6 @@
         }
 
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
-                    } else {
-                        masksh[c][r] = float(0);
-                    }
-                }
-            }
-            barrier();
-
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     float mvf = masksh[c * cols_per_iter + col_tid][r];
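
In the scalar shader above, the mask tile for the current KV block is now read (with a per-invocation running maximum) before any of the QK^T work. subgroupAll() plus a short pass over tmpsh turns the per-subgroup results into a workgroup-wide verdict, and the loop continues past the block when every mask value is at or below NEG_FLT_MAX_OVER_2. A minimal CPU-side sketch of just that skip decision follows; the function name and the flat std::vector layout are illustrative, and NEG_FLT_MAX_OVER_2 is assumed to equal -FLT_MAX / 2, matching its name:

```cpp
// Hedged CPU reference of the per-block skip test; not the shader itself.
#include <algorithm>
#include <cfloat>
#include <cstdint>
#include <vector>

constexpr float NEG_FLT_MAX_OVER_2 = -FLT_MAX / 2.0f;  // assumed threshold

// mask: row-major rows x cols tile of the attention mask.
// Returns true when every element is effectively -inf, i.e. the softmax over
// this block contributes nothing and the whole KV block can be skipped.
bool block_is_all_neg_inf(const std::vector<float>& mask,
                          uint32_t rows, uint32_t cols) {
    float max_mask = NEG_FLT_MAX_OVER_2;
    for (uint32_t r = 0; r < rows; ++r) {
        for (uint32_t c = 0; c < cols; ++c) {
            max_mask = std::max(max_mask, mask[r * cols + c]);
        }
    }
    return max_mask <= NEG_FLT_MAX_OVER_2;
}
```

Blocks that are entirely masked out, as can happen with causal masks or padded batches, then cost only the mask load instead of the full score, softmax, and value-accumulation work.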

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 34 additions & 1 deletion
@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_vote : enable
 #extension GL_KHR_memory_scope_semantics : enable
 #extension GL_KHR_cooperative_matrix : enable
 
@@ -148,6 +149,37 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
+        float mask_cache[Bc * Br / WorkGroupSize];
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+            float max_mask = NEG_FLT_MAX_OVER_2;
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) % Bc;
+                uint32_t r = (idx + tid) / Bc;
+                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                        mask_cache[idx / WorkGroupSize] = m;
+                        max_mask = max(max_mask, m);
+                    }
+                }
+            }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+            barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
+        }
+
         [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
             uint32_t d = (idx + tid) % (HSK / 4);
             uint32_t c = (idx + tid) / (HSK / 4);
@@ -208,7 +240,8 @@
             uint32_t r = (idx + tid) / Bc;
             if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
                 if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                    sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
+                    float f = mask_cache[idx / WorkGroupSize];
+                    sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
                 }
             }
         }
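
The coopmat1 variant runs the same all-negative-infinity test, but because each invocation already reads its share of the mask to compute the maximum, it keeps those values in a per-thread mask_cache (one entry per pass of the strided loop), and the later mask-apply loop reuses mask_cache[idx / WorkGroupSize] instead of re-reading data_m from memory. A rough CPU sketch of that per-thread ownership and indexing pattern, with illustrative names and the KV/nem1 bounds checks omitted:

```cpp
// Hedged sketch of the strided per-thread caching pattern; not the shader.
#include <cstdint>
#include <vector>

// Each of the WorkGroupSize threads visits flat element idx + tid on every
// pass of the strided loop and stores what it read at cache[pass], so a
// second pass over the same elements can reuse cache[pass] directly.
void cache_mask_for_thread(const std::vector<float>& data_m, // mask, row stride m_stride
                           uint32_t m_stride,
                           uint32_t Br, uint32_t Bc,          // block rows / columns
                           uint32_t i, uint32_t j,            // block coordinates
                           uint32_t WorkGroupSize, uint32_t tid,
                           std::vector<float>& cache)         // Bc*Br / WorkGroupSize entries
{
    uint32_t pass = 0;
    for (uint32_t idx = 0; idx < Bc * Br; idx += WorkGroupSize, ++pass) {
        const uint32_t flat = idx + tid;
        if (flat < Bc * Br) {
            const uint32_t c = flat % Bc; // column within the block
            const uint32_t r = flat / Bc; // row within the block
            cache[pass] = data_m[(i * Br + r) * m_stride + (j * Bc + c)];
        }
    }
}
```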

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 37 additions & 19 deletions
@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return max(x, y);
 }
 
+float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
+    return max(x, y);
+}
+
 ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return x;
 }
@@ -142,49 +146,63 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
-
-        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
-
-        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
-        S = coopMatMulAdd(Qf16, K_T, S);
-
-        if (p.logit_softcap != 0.0f) {
-            [[unroll]]
-            for (int k = 0; k < S.length(); ++k) {
-                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
-            }
-        }
-
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
             if (nem1_bounds_check) {
                 tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
                 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
 
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
 
                 coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             } else {
                 tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
                 // Don't clamp against nem1 when GQA is enabled
                 uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
                 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
 
                 coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }
 
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+
+        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
+
+        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
+        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
+        S = coopMatMulAdd(Qf16, K_T, S);
+
+        if (p.logit_softcap != 0.0f) {
+            [[unroll]]
+            for (int k = 0; k < S.length(); ++k) {
+                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
+            }
+        }
+
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+        }
+
         // Clear padding elements to -inf, so they don't contribute to rowmax
         if (Clamp != 0 &&
             ((j + 1) * Bc > KV ||
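
The coopmat2 path expresses the same skip with cooperative-matrix primitives: the mask is loaded into mv up front, reduced to a single maximum by coopMatReduceNV with the new maxReduceFp16 helper, and the block is abandoned before the QK^T multiply, the softcap, and the mask add are ever issued. When nem1 bounds checking is active, the tensor layout's clamp value is set to 0xfc00 so that out-of-bounds mask elements read back as float16 negative infinity and therefore count toward the all-negative-infinity test. A small host-side check of that bit pattern (purely illustrative, not part of the change):

```cpp
// Decodes 0xfc00 as an IEEE 754 binary16 value: sign=1, exponent=0b11111,
// mantissa=0, which is -infinity. That is why it works as the clamp value.
#include <cstdint>
#include <cstdio>

int main() {
    const uint16_t bits = 0xfc00;
    const unsigned sign     = bits >> 15;          // 1  -> negative
    const unsigned exponent = (bits >> 10) & 0x1f; // 31 -> all ones
    const unsigned mantissa = bits & 0x3ff;        // 0  -> infinity (nonzero would be NaN)
    std::printf("sign=%u exponent=%u mantissa=%u -> %s\n",
                sign, exponent, mantissa,
                (exponent == 31 && mantissa == 0) ? (sign ? "-inf" : "+inf")
                                                  : "not an infinity");
    return 0;
}
```

Moving the S accumulation and softcap below the mask check means a fully masked block exits after a single cooperative-matrix load and reduction.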
