Commit 8fb910e

Revert "Merge branch 'upstream' into concedo_experimental"
This reverts commit f8ee5d9, reversing changes made to a9f5c2d.
1 parent 841b749 · commit 8fb910e

22 files changed: +324 −721 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 146 additions & 265 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/add.comp

Lines changed: 1 addition & 41 deletions
@@ -1,69 +1,29 @@
 #version 450
 
 #extension GL_EXT_shader_16bit_storage : require
-#if ADD_RMS
-#extension GL_KHR_shader_subgroup_arithmetic : enable
-#extension GL_KHR_shader_subgroup_basic : enable
-#endif
 
 #include "types.comp"
 #include "generic_binary_head.comp"
 
 const uint num_threads = 256;
 
-layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
-
 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
 
-#if ADD_RMS
-// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
-shared FLOAT_TYPE sumsh[num_threads];
-#endif
-
 void main() {
     uint idx = get_idx();
-    uint orig_idx = idx;
 
     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;
 
-    FLOAT_TYPE sum_sq = 0;
-
     [[unroll]] for (uint i = 0; i < num_iter; ++i) {
         if (idx >= p.ne) {
             continue;
         }
         uint i00, i01, i02, i03;
         get_indices(idx, i00, i01, i02, i03);
 
-        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
-        sum_sq += sum*sum;
-
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
 
         idx += num_threads;
     }
-
-#if ADD_RMS
-    if (p.param3 != 0) {
-        // reduce the sum within each subgroup, then across subgroups
-        const uint NumSubgroups = num_threads / gl_SubgroupSize;
-        sum_sq = subgroupAdd(sum_sq);
-        if (gl_SubgroupInvocationID == 0) {
-            sumsh[gl_SubgroupID] = sum_sq;
-        }
-        barrier();
-        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
-            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
-                sum_sq += sumsh[gl_SubgroupID + s];
-                sumsh[gl_SubgroupID] = sum_sq;
-            }
-            barrier();
-        }
-
-        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
-            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
-        }
-    }
-#endif
 }
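
For context on what the removed ADD_RMS path computed: when the following op was an rms_norm, this shader also accumulated a per-thread sum of squares of its outputs and reduced it in two stages, first a subgroupAdd within each subgroup, then a shared-memory tree across subgroups, publishing one partial sum per workgroup at binding 3 for the fused norm to consume. Below is a minimal, self-contained sketch of that two-stage reduction pattern; the In/Out block names, the bindings, and the one-element-per-thread load are placeholders for this sketch (the real shader accumulates over its strided loop and derives the partial index from orig_idx), and it assumes the dispatch covers the input exactly.

#version 450

#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable

const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

// Placeholder bindings for this sketch only.
layout(binding = 0, std430) readonly buffer In  { float data_in[]; };
layout(binding = 1, std430) buffer Out { float partial_sums[]; };

// One slot per subgroup would suffice, but gl_SubgroupSize is not a
// compile-time constant, so the array is sized for the worst case.
shared float sumsh[num_threads];

void main() {
    float v = data_in[gl_GlobalInvocationID.x];
    float sum_sq = v * v;

    // Stage 1: reduce within each subgroup, entirely in registers.
    sum_sq = subgroupAdd(sum_sq);
    if (gl_SubgroupInvocationID == 0) {
        sumsh[gl_SubgroupID] = sum_sq;
    }
    barrier();

    // Stage 2: tree-reduce the per-subgroup sums through shared memory.
    // num_threads / gl_SubgroupSize is a power of two, so halving works.
    for (uint s = (num_threads / gl_SubgroupSize) / 2; s > 0; s >>= 1) {
        if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
            sumsh[gl_SubgroupID] += sumsh[gl_SubgroupID + s];
        }
        barrier();
    }

    // One partial sum per workgroup; a following pass finishes the norm.
    if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
        partial_sums[gl_WorkGroupID.x] = sumsh[0];
    }
}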

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 0 additions & 4 deletions
@@ -9,10 +9,6 @@ layout (constant_id = 4) const uint32_t HSV = 32;
 layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 
-// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
-const uint32_t HSK_pad = (HSK + 15) & ~15;
-const uint32_t HSV_pad = (HSV + 15) & ~15;
-
 layout (push_constant) uniform parameter {
     uint32_t N;
     uint32_t KV;
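
An aside on the arithmetic being removed: (x + 15) & ~15 rounds x up to the next multiple of 16 by adding 15 and clearing the four low bits. For example, HSK = 40 gives (40 + 15) & ~15 = 55 & ~15 = 48, while an already-aligned HSK = 64 maps to itself. With the pad constants gone, the coopmat1/coopmat2 paths on this branch work on the raw HSK/HSV and implicitly assume head sizes that are already multiples of 16.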

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 4 additions & 19 deletions
@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
 shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
 shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
 
-const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
+const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
 // Avoid padding for hsk==256 to make it fit in 48KB shmem.
 const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
 shared ACC_TYPE sfsh[Bc * sfshstride];
 
-const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
+const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
 shared f16vec4 ksh[Bc * kshstride];
 
 shared float slope[Br];
@@ -74,21 +74,6 @@ void main() {
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
-    // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
-    if ((HSK % 16) != 0) {
-        [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
-            if (i + tid < Br * qstride) {
-                Qf[i + tid] = f16vec4(0);
-            }
-        }
-        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
-            if (i + tid < Bc * kshstride) {
-                ksh[i + tid] = f16vec4(0);
-            }
-        }
-        barrier();
-    }
-
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
@@ -166,14 +151,14 @@
     }
     barrier();
 
-    // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
+    // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
     // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
     // This is written transposed in order to allow for N being 8 if implementations need it
     coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
     coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
     coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
 
-    for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
+    for (uint32_t d = 0; d < HSK / 16; ++d) {
         coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
 
         uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
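
Why that zero-fill existed in the first place (an inference from the code, not stated in the commit): the cooperative-matrix loop below walks Qf and ksh in 16-column tiles, so when HSK was padded up to HSK_pad the tail columns had to be zeroed to keep them from adding garbage into the K * Q^T accumulation. With qstride and kshstride defined directly from HSK again, there is no pad region to initialize, and the HSK / 16 loop covers the whole head exactly when HSK is a multiple of 16.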

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 18 additions & 18 deletions
@@ -104,16 +104,16 @@ void main() {
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
     tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
 
-    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
-    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
+    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
+    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
 
     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
-    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
+    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
 
-    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
+    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
     Qf16 *= float16_t(p.scale);
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
 
     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
 
@@ -140,10 +140,10 @@
 
         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
+        coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
 
         uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
+        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
         S = coopMatMulAdd(Qf16, K_T, S);
 
         if (p.logit_softcap != 0.0f) {
@@ -208,31 +208,31 @@
         rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
         rowsum = coopMatMulAdd(P_A, One, rowsum);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
+        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
         uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
-        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
+        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
 
         L = eM*L + rowsum;
 
         // This is the "diagonal" matrix in the paper, but since we do componentwise
         // multiply rather than matrix multiply it has the diagonal element smeared
         // across the row
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
 
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
         // multiply with fp16 accumulation, then add to O.
-        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
         PV = coopMatMulAdd(P_A, V, PV);
 
-        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
+        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
 
         uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
@@ -243,16 +243,16 @@
         return;
     }
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
 
     // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
     if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
         coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
 
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
 
         // resize M by using smear/reduce
         coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -285,7 +285,7 @@
 
     uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
-    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
+    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
     if (p.gqa_ratio > 1) {
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
     } else {
@@ -295,6 +295,6 @@
         // permute dimensions
         tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
 
-        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
+        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
     }
 }
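
For orientation, the accumulator updates visible in these hunks (L = eM*L + rowsum and O = eMdiag * O + PV) are the standard online-softmax recurrence used by flash attention; in the shader's variable names, with S_j the score block for KV tile j:

\begin{aligned}
M_{\mathrm{new}} &= \max(M_{\mathrm{old}},\ \mathrm{rowmax}(S_j)) \\
e_M &= \exp(M_{\mathrm{old}} - M_{\mathrm{new}}), \qquad P = \exp(S_j - M_{\mathrm{new}}) \\
L_{\mathrm{new}} &= e_M \, L_{\mathrm{old}} + \mathrm{rowsum}(P) \\
O_{\mathrm{new}} &= \mathrm{diag}(e_M) \, O_{\mathrm{old}} + P \, V_j
\end{aligned}

The division of O by L happens once at the end, or in the split-k resolve shader when p.k_num > 1, as the comment in the hunk notes. The "diagonal" remark in the code refers to diag(e_M): the shader multiplies componentwise instead of materializing a diagonal matrix, so the smear/reduce spreads each row's scale factor across that row.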

ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp

Lines changed: 3 additions & 46 deletions
@@ -3,10 +3,6 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_nonuniform_qualifier : enable
 #extension GL_EXT_control_flow_attributes : require
-#if ADD_RMS
-#extension GL_KHR_shader_subgroup_arithmetic : enable
-#extension GL_KHR_shader_subgroup_basic : enable
-#endif
 
 #include "rte.comp"
 #include "types.comp"
@@ -18,18 +14,11 @@ layout (push_constant) uniform parameter2
     uint ne20; uint ne21; uint ne22; uint ne23;
 
     // strides for srcs+dst
-    uint nb[12][4];
-
-    uint rms_partials;
+    uint nb[8][4];
 } p;
 
-// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
-// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
-// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
-layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
-layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
-
-layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
+layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
 
 layout(constant_id = 0) const uint num_srcs = 2;
 
@@ -53,22 +42,14 @@ const uint num_threads = 256;
 
 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
 
-#if ADD_RMS
-// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
-shared FLOAT_TYPE sumsh[num_threads];
-#endif
-
 void main() {
     uint idx = get_idx();
-    uint orig_idx = idx;
 
     uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
 
     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;
 
-    FLOAT_TYPE sum_sq = 0;
-
     [[unroll]] for (uint i = 0; i < num_iter; ++i) {
         if (idx >= ne) {
             continue;
@@ -80,32 +61,8 @@ void main() {
         [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
             sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
         }
-        sum_sq += sum*sum;
         d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
 
         idx += num_threads;
     }
-
-#if ADD_RMS
-    if (p.rms_partials != 0) {
-        // reduce the sum within each subgroup, then across subgroups
-        const uint NumSubgroups = num_threads / gl_SubgroupSize;
-        sum_sq = subgroupAdd(sum_sq);
-        if (gl_SubgroupInvocationID == 0) {
-            sumsh[gl_SubgroupID] = sum_sq;
-        }
-        barrier();
-        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
-            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
-                sum_sq += sumsh[gl_SubgroupID + s];
-                sumsh[gl_SubgroupID] = sum_sq;
-            }
-            barrier();
-        }
-
-        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
-            partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
-        }
-    }
-#endif
 }
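
One detail of the restored declarations worth spelling out: a[] and d[] are unsized arrays of buffer views that alias the same binding = 0 descriptor array, indexed per source via GL_EXT_nonuniform_qualifier, with slot num_srcs serving as the destination (hence d[num_srcs].data_d[...]). The reverted upstream code had dropped the readonly/writeonly qualifiers from those views to work around a MoltenVK bug (llama.cpp issue 15498, per the removed comment) and aliased a third view, partials[], onto the same binding so that slot num_srcs + 1 could expose the partial-sums buffer for the fused-RMS path; this branch keeps the plain qualified pair and the 8-entry stride array.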
