Skip to content

Commit 6ab959e

Browse files
authored
Merge branch 'ggml-org:master' into master
2 parents a2fb081 + 961660b commit 6ab959e

File tree

8 files changed

+541
-143
lines changed

8 files changed

+541
-143
lines changed

CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
/src/llama-model-loader.* @slaren
9090
/src/llama-model.* @CISC
9191
/src/llama-vocab.* @CISC
92+
/src/models/ @CISC
9293
/tests/ @ggerganov
9394
/tests/test-backend-ops.cpp @slaren
9495
/tests/test-thread-safety.cpp @slaren

common/arg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2030,7 +2030,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
20302030
params.system_prompt.pop_back();
20312031
}
20322032
}
2033-
).set_examples({LLAMA_EXAMPLE_MAIN}));
2033+
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
20342034
add_opt(common_arg(
20352035
{"--in-file"}, "FNAME",
20362036
"an input file (repeat to specify multiple files)",

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,26 +190,28 @@ static __global__ void mul_mat_vec_q(
190190

191191
const uint32_t channel_bias = ids ? channel_x : channel_dst;
192192

193-
float x_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
194-
float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
193+
float x_biases[ncols_dst] = { 0.0f };
194+
float gate_biases[ncols_dst] = { 0.0f };
195195
if constexpr (has_fusion) {
196196
if (use_bias) {
197197
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
198198
// 1. Hide latency by prefetching bias and gate here
199199
// 2. load only on threads that won't die after partial sum calculation
200200
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
201201
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
202+
#pragma unroll
202203
for (int j = 0; j < ncols_dst; ++j) {
203-
x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
204+
x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
204205
}
205206
}
206207
}
207208
if (use_gate_bias) {
208209
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
209210
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
210211
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
212+
#pragma unroll
211213
for (int j = 0; j < ncols_dst; ++j) {
212-
gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
214+
gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
213215
}
214216
}
215217
}
@@ -299,12 +301,12 @@ static __global__ void mul_mat_vec_q(
299301
float result = tmp[j][threadIdx.x];
300302
if constexpr (has_fusion) {
301303
if (use_bias) {
302-
result += x_biases[j][threadIdx.x];
304+
result += x_biases[j];
303305
}
304306
if (use_gate) {
305307
float gate_value = tmp_gate[j][threadIdx.x];
306308
if (use_gate_bias) {
307-
gate_value += gate_biases[j][threadIdx.x];
309+
gate_value += gate_biases[j];
308310
}
309311
switch (active_glu) {
310312
case GGML_GLU_OP_SWIGLU:

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 393 additions & 124 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
2828
#endif
2929

3030
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
31+
32+
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
33+
3134
#ifdef MUL_MAT_ID
32-
layout (binding = 3) readonly buffer IDS {int data_ids[];};
35+
layout (binding = 4) readonly buffer IDS {int data_ids[];};
3336
#endif
3437

3538
#include "dequant_funcs.glsl"
@@ -45,6 +48,8 @@ layout (push_constant) uniform parameter
4548
uint batch_stride_b;
4649
uint batch_stride_d;
4750

51+
uint enable_bias;
52+
4853
#ifdef MUL_MAT_ID
4954
uint nei0;
5055
uint ne11;
@@ -56,6 +61,10 @@ layout (push_constant) uniform parameter
5661
#endif
5762
} p;
5863

64+
#ifdef MUL_MAT_ID
65+
uint expert_id;
66+
#endif
67+
5968
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
6069
#ifdef MUL_MAT_ID
6170
const uint expert_idx = gl_GlobalInvocationID.y;
@@ -75,7 +84,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
7584
batch_idx_a = i03 * p.ne02 + i02;
7685
}
7786
#else
78-
const uint expert_id = data_ids[expert_idx];
87+
expert_id = data_ids[expert_idx];
7988
#endif
8089

8190
a_offset =
@@ -113,6 +122,13 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
113122
if (tid == 0) {
114123
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
115124
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
125+
if (p.enable_bias != 0) {
126+
#ifdef MUL_MAT_ID
127+
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
128+
#else
129+
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
130+
#endif
131+
}
116132
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
117133
}
118134
}
@@ -148,6 +164,13 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
148164
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
149165
temp[j][n] += tmpsh[j][n][s];
150166
}
167+
if (p.enable_bias != 0) {
168+
#ifdef MUL_MAT_ID
169+
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
170+
#else
171+
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
172+
#endif
173+
}
151174
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
152175
}
153176
}
@@ -173,6 +196,13 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
173196
if (tid == 0) {
174197
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
175198
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
199+
if (p.enable_bias != 0) {
200+
#ifdef MUL_MAT_ID
201+
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
202+
#else
203+
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
204+
#endif
205+
}
176206
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
177207
}
178208
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
1515
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
1616
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
1717

18+
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
19+
1820
layout (push_constant) uniform parameter
1921
{
2022
uint ncols_x;
@@ -29,6 +31,7 @@ layout (push_constant) uniform parameter
2931
uint nb03;
3032
uint nb13;
3133
uint nb23;
34+
uint enable_bias;
3235
} p;
3336

3437
shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -117,6 +120,9 @@ void main() {
117120
}
118121

119122
if (tid == 0) {
123+
if (p.enable_bias != 0) {
124+
tmp[0] += FLOAT_TYPE(data_bias[idst]);
125+
}
120126
dst[idst] = tmp[0];
121127
}
122128
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
1717
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
1818
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
1919

20+
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
21+
2022
layout(constant_id = 0) const int BLOCK_SIZE = 32;
2123
// gqa_ratio is in the range [1,8]
2224
layout(constant_id = 1) const uint gqa_ratio = 1;
@@ -29,6 +31,7 @@ layout (push_constant) uniform parameter
2931
uint nchannels_y;
3032
uint b_offset;
3133
uint d_offset;
34+
uint enable_bias;
3235
} p;
3336

3437
#if !USE_SUBGROUP_ADD
@@ -148,6 +151,9 @@ void main() {
148151
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
149152
// dst is not transposed and not permuted
150153
const uint idst = (channel + c)*nrows_dst + row_dst;
154+
if (p.enable_bias != 0) {
155+
temp[c] += FLOAT_TYPE(data_bias[idst]);
156+
}
151157
dst[idst] = temp[c];
152158
}
153159
}

ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,100 @@ layout (push_constant) uniform parameter2
2323
uint rms_partials;
2424
} p;
2525

26-
// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
27-
// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
28-
// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
29-
layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
30-
layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
31-
32-
layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
26+
// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
27+
layout (binding = 0) buffer A0 {A_TYPE data_a[];} a0;
28+
layout (binding = 1) buffer A1 {A_TYPE data_a[];} a1;
29+
layout (binding = 2) buffer A2 {A_TYPE data_a[];} a2;
30+
layout (binding = 3) buffer A3 {A_TYPE data_a[];} a3;
31+
layout (binding = 4) buffer A4 {A_TYPE data_a[];} a4;
32+
layout (binding = 5) buffer A5 {A_TYPE data_a[];} a5;
33+
layout (binding = 6) buffer A6 {A_TYPE data_a[];} a6;
34+
layout (binding = 7) buffer A7 {A_TYPE data_a[];} a7;
35+
layout (binding = 8) buffer A8 {A_TYPE data_a[];} a8;
36+
layout (binding = 9) buffer A9 {A_TYPE data_a[];} a9;
37+
layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;
38+
layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;
39+
layout (binding = 0) buffer D0 {D_TYPE data_d[];} d0;
40+
layout (binding = 1) buffer D1 {D_TYPE data_d[];} d1;
41+
layout (binding = 2) buffer D2 {D_TYPE data_d[];} d2;
42+
layout (binding = 3) buffer D3 {D_TYPE data_d[];} d3;
43+
layout (binding = 4) buffer D4 {D_TYPE data_d[];} d4;
44+
layout (binding = 5) buffer D5 {D_TYPE data_d[];} d5;
45+
layout (binding = 6) buffer D6 {D_TYPE data_d[];} d6;
46+
layout (binding = 7) buffer D7 {D_TYPE data_d[];} d7;
47+
layout (binding = 8) buffer D8 {D_TYPE data_d[];} d8;
48+
layout (binding = 9) buffer D9 {D_TYPE data_d[];} d9;
49+
layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10;
50+
layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;
51+
layout (binding = 0, std430) buffer PartialBuf0 {float partial_sums[];} partials0;
52+
layout (binding = 1, std430) buffer PartialBuf1 {float partial_sums[];} partials1;
53+
layout (binding = 2, std430) buffer PartialBuf2 {float partial_sums[];} partials2;
54+
layout (binding = 3, std430) buffer PartialBuf3 {float partial_sums[];} partials3;
55+
layout (binding = 4, std430) buffer PartialBuf4 {float partial_sums[];} partials4;
56+
layout (binding = 5, std430) buffer PartialBuf5 {float partial_sums[];} partials5;
57+
layout (binding = 6, std430) buffer PartialBuf6 {float partial_sums[];} partials6;
58+
layout (binding = 7, std430) buffer PartialBuf7 {float partial_sums[];} partials7;
59+
layout (binding = 8, std430) buffer PartialBuf8 {float partial_sums[];} partials8;
60+
layout (binding = 9, std430) buffer PartialBuf9 {float partial_sums[];} partials9;
61+
layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;
62+
layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;
3363

3464
layout(constant_id = 0) const uint num_srcs = 2;
3565

66+
FLOAT_TYPE load_a(uint b, uint i) {
67+
switch (b) {
68+
case 0: return FLOAT_TYPE(a0.data_a[i]);
69+
case 1: return FLOAT_TYPE(a1.data_a[i]);
70+
case 2: return FLOAT_TYPE(a2.data_a[i]);
71+
case 3: return FLOAT_TYPE(a3.data_a[i]);
72+
case 4: return FLOAT_TYPE(a4.data_a[i]);
73+
case 5: return FLOAT_TYPE(a5.data_a[i]);
74+
case 6: return FLOAT_TYPE(a6.data_a[i]);
75+
case 7: return FLOAT_TYPE(a7.data_a[i]);
76+
case 8: return FLOAT_TYPE(a8.data_a[i]);
77+
case 9: return FLOAT_TYPE(a9.data_a[i]);
78+
case 10: return FLOAT_TYPE(a10.data_a[i]);
79+
case 11: return FLOAT_TYPE(a11.data_a[i]);
80+
default: return FLOAT_TYPE(0);
81+
}
82+
}
83+
84+
void store_d(uint b, uint i, FLOAT_TYPE v) {
85+
switch (b) {
86+
case 0: d0.data_d[i] = D_TYPE(v); break;
87+
case 1: d1.data_d[i] = D_TYPE(v); break;
88+
case 2: d2.data_d[i] = D_TYPE(v); break;
89+
case 3: d3.data_d[i] = D_TYPE(v); break;
90+
case 4: d4.data_d[i] = D_TYPE(v); break;
91+
case 5: d5.data_d[i] = D_TYPE(v); break;
92+
case 6: d6.data_d[i] = D_TYPE(v); break;
93+
case 7: d7.data_d[i] = D_TYPE(v); break;
94+
case 8: d8.data_d[i] = D_TYPE(v); break;
95+
case 9: d9.data_d[i] = D_TYPE(v); break;
96+
case 10: d10.data_d[i] = D_TYPE(v); break;
97+
case 11: d11.data_d[i] = D_TYPE(v); break;
98+
default: break;
99+
}
100+
}
101+
102+
void store_partial(uint b, uint i, float v) {
103+
switch (b) {
104+
case 0: partials0.partial_sums[i] = v; break;
105+
case 1: partials1.partial_sums[i] = v; break;
106+
case 2: partials2.partial_sums[i] = v; break;
107+
case 3: partials3.partial_sums[i] = v; break;
108+
case 4: partials4.partial_sums[i] = v; break;
109+
case 5: partials5.partial_sums[i] = v; break;
110+
case 6: partials6.partial_sums[i] = v; break;
111+
case 7: partials7.partial_sums[i] = v; break;
112+
case 8: partials8.partial_sums[i] = v; break;
113+
case 9: partials9.partial_sums[i] = v; break;
114+
case 10: partials10.partial_sums[i] = v; break;
115+
case 11: partials11.partial_sums[i] = v; break;
116+
default: break;
117+
}
118+
}
119+
36120
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
37121
return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
38122
}
@@ -78,10 +162,10 @@ void main() {
78162

79163
FLOAT_TYPE sum = FLOAT_TYPE(0);
80164
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
81-
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
165+
sum += load_a(s, src_idx(s, i00, i01, i02, i03));
82166
}
83167
sum_sq += sum*sum;
84-
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
168+
store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);
85169

86170
idx += num_threads;
87171
}
@@ -104,7 +188,7 @@ void main() {
104188
}
105189

106190
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
107-
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
191+
store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq);
108192
}
109193
}
110194
#endif

0 commit comments

Comments (0)