@@ -23,16 +23,100 @@ layout (push_constant) uniform parameter2
2323 uint rms_partials;
2424} p;
2525
26- // Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
27- // layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
28- // layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
29- layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
30- layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
31-
32- layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
26+ // No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
27+ layout (binding = 0) buffer A0 {A_TYPE data_a[];} a0;
28+ layout (binding = 1) buffer A1 {A_TYPE data_a[];} a1;
29+ layout (binding = 2) buffer A2 {A_TYPE data_a[];} a2;
30+ layout (binding = 3) buffer A3 {A_TYPE data_a[];} a3;
31+ layout (binding = 4) buffer A4 {A_TYPE data_a[];} a4;
32+ layout (binding = 5) buffer A5 {A_TYPE data_a[];} a5;
33+ layout (binding = 6) buffer A6 {A_TYPE data_a[];} a6;
34+ layout (binding = 7) buffer A7 {A_TYPE data_a[];} a7;
35+ layout (binding = 8) buffer A8 {A_TYPE data_a[];} a8;
36+ layout (binding = 9) buffer A9 {A_TYPE data_a[];} a9;
37+ layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;
38+ layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;
39+ layout (binding = 0) buffer D0 {D_TYPE data_d[];} d0;
40+ layout (binding = 1) buffer D1 {D_TYPE data_d[];} d1;
41+ layout (binding = 2) buffer D2 {D_TYPE data_d[];} d2;
42+ layout (binding = 3) buffer D3 {D_TYPE data_d[];} d3;
43+ layout (binding = 4) buffer D4 {D_TYPE data_d[];} d4;
44+ layout (binding = 5) buffer D5 {D_TYPE data_d[];} d5;
45+ layout (binding = 6) buffer D6 {D_TYPE data_d[];} d6;
46+ layout (binding = 7) buffer D7 {D_TYPE data_d[];} d7;
47+ layout (binding = 8) buffer D8 {D_TYPE data_d[];} d8;
48+ layout (binding = 9) buffer D9 {D_TYPE data_d[];} d9;
49+ layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10;
50+ layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;
51+ layout (binding = 0, std430) buffer PartialBuf0 {float partial_sums[];} partials0;
52+ layout (binding = 1, std430) buffer PartialBuf1 {float partial_sums[];} partials1;
53+ layout (binding = 2, std430) buffer PartialBuf2 {float partial_sums[];} partials2;
54+ layout (binding = 3, std430) buffer PartialBuf3 {float partial_sums[];} partials3;
55+ layout (binding = 4, std430) buffer PartialBuf4 {float partial_sums[];} partials4;
56+ layout (binding = 5, std430) buffer PartialBuf5 {float partial_sums[];} partials5;
57+ layout (binding = 6, std430) buffer PartialBuf6 {float partial_sums[];} partials6;
58+ layout (binding = 7, std430) buffer PartialBuf7 {float partial_sums[];} partials7;
59+ layout (binding = 8, std430) buffer PartialBuf8 {float partial_sums[];} partials8;
60+ layout (binding = 9, std430) buffer PartialBuf9 {float partial_sums[];} partials9;
61+ layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;
62+ layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;
3363
3464layout(constant_id = 0) const uint num_srcs = 2;
3565
66+ FLOAT_TYPE load_a(uint b, uint i) {
67+ switch (b) {
68+ case 0: return FLOAT_TYPE(a0.data_a[i]);
69+ case 1: return FLOAT_TYPE(a1.data_a[i]);
70+ case 2: return FLOAT_TYPE(a2.data_a[i]);
71+ case 3: return FLOAT_TYPE(a3.data_a[i]);
72+ case 4: return FLOAT_TYPE(a4.data_a[i]);
73+ case 5: return FLOAT_TYPE(a5.data_a[i]);
74+ case 6: return FLOAT_TYPE(a6.data_a[i]);
75+ case 7: return FLOAT_TYPE(a7.data_a[i]);
76+ case 8: return FLOAT_TYPE(a8.data_a[i]);
77+ case 9: return FLOAT_TYPE(a9.data_a[i]);
78+ case 10: return FLOAT_TYPE(a10.data_a[i]);
79+ case 11: return FLOAT_TYPE(a11.data_a[i]);
80+ default: return FLOAT_TYPE(0);
81+ }
82+ }
83+
84+ void store_d(uint b, uint i, FLOAT_TYPE v) {
85+ switch (b) {
86+ case 0: d0.data_d[i] = D_TYPE(v); break;
87+ case 1: d1.data_d[i] = D_TYPE(v); break;
88+ case 2: d2.data_d[i] = D_TYPE(v); break;
89+ case 3: d3.data_d[i] = D_TYPE(v); break;
90+ case 4: d4.data_d[i] = D_TYPE(v); break;
91+ case 5: d5.data_d[i] = D_TYPE(v); break;
92+ case 6: d6.data_d[i] = D_TYPE(v); break;
93+ case 7: d7.data_d[i] = D_TYPE(v); break;
94+ case 8: d8.data_d[i] = D_TYPE(v); break;
95+ case 9: d9.data_d[i] = D_TYPE(v); break;
96+ case 10: d10.data_d[i] = D_TYPE(v); break;
97+ case 11: d11.data_d[i] = D_TYPE(v); break;
98+ default: break;
99+ }
100+ }
101+
102+ void store_partial(uint b, uint i, float v) {
103+ switch (b) {
104+ case 0: partials0.partial_sums[i] = v; break;
105+ case 1: partials1.partial_sums[i] = v; break;
106+ case 2: partials2.partial_sums[i] = v; break;
107+ case 3: partials3.partial_sums[i] = v; break;
108+ case 4: partials4.partial_sums[i] = v; break;
109+ case 5: partials5.partial_sums[i] = v; break;
110+ case 6: partials6.partial_sums[i] = v; break;
111+ case 7: partials7.partial_sums[i] = v; break;
112+ case 8: partials8.partial_sums[i] = v; break;
113+ case 9: partials9.partial_sums[i] = v; break;
114+ case 10: partials10.partial_sums[i] = v; break;
115+ case 11: partials11.partial_sums[i] = v; break;
116+ default: break;
117+ }
118+ }
119+
36120uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
37121 return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
38122}
@@ -78,10 +162,10 @@ void main() {
78162
79163 FLOAT_TYPE sum = FLOAT_TYPE(0);
80164 [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
81- sum += FLOAT_TYPE(a[s].data_a[ src_idx(s, i00, i01, i02, i03)] );
165+ sum += load_a(s, src_idx(s, i00, i01, i02, i03));
82166 }
83167 sum_sq += sum*sum;
84- d[ num_srcs].data_d[ dst_idx(i00, i01, i02, i03)] = D_TYPE( sum);
168+ store_d( num_srcs, dst_idx(i00, i01, i02, i03), sum);
85169
86170 idx += num_threads;
87171 }
@@ -104,7 +188,7 @@ void main() {
104188 }
105189
106190 if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
107- partials[ num_srcs + 1].partial_sums[ orig_idx / (num_iter * num_threads)] = sum_sq;
191+ store_partial( num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq) ;
108192 }
109193 }
110194#endif
0 commit comments