Skip to content

Commit 414e382

Browse files
authored
Merge branch 'ggml-org:master' into master
2 parents e2decd6 + 710dfc4 commit 414e382

18 files changed

+625
-53
lines changed

convert_hf_to_gguf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5854,6 +5854,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
58545854
return [(self.map_tensor_name(name), data_torch)]
58555855

58565856

5857+
@ModelBase.register("SeedOssForCausalLM")
5858+
class SeedOssModel(TextModel):
5859+
model_arch = gguf.MODEL_ARCH.SEED_OSS
5860+
5861+
58575862
@ModelBase.register("Olmo2ForCausalLM")
58585863
class Olmo2Model(TextModel):
58595864
model_arch = gguf.MODEL_ARCH.OLMO2

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ static __global__ void flash_attn_tile_ext_f16(
258258
const half val = hexp(sink - kqmax[j0/nwarps]);
259259
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
260260
if (threadIdx.x == 0) {
261-
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
261+
kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
262262
}
263263

264264
#pragma unroll

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 164 additions & 31 deletions
Large diffs are not rendered by default.
Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,69 @@
11
#version 450
22

33
#extension GL_EXT_shader_16bit_storage : require
4+
#if ADD_RMS
5+
#extension GL_KHR_shader_subgroup_arithmetic : enable
6+
#extension GL_KHR_shader_subgroup_basic : enable
7+
#endif
48

59
#include "types.comp"
610
#include "generic_binary_head.comp"
711

812
const uint num_threads = 256;
913

14+
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
15+
1016
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
1117

18+
#if ADD_RMS
19+
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
20+
shared FLOAT_TYPE sumsh[num_threads];
21+
#endif
22+
1223
void main() {
1324
uint idx = get_idx();
25+
uint orig_idx = idx;
1426

1527
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
1628
const uint num_iter = 2;
1729

30+
FLOAT_TYPE sum_sq = 0;
31+
1832
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
1933
if (idx >= p.ne) {
2034
continue;
2135
}
2236
uint i00, i01, i02, i03;
2337
get_indices(idx, i00, i01, i02, i03);
2438

25-
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
39+
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
40+
sum_sq += sum*sum;
41+
42+
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
2643

2744
idx += num_threads;
2845
}
46+
47+
#if ADD_RMS
48+
if (p.param3 != 0) {
49+
// reduce the sum within each subgroup, then across subgroups
50+
const uint NumSubgroups = num_threads / gl_SubgroupSize;
51+
sum_sq = subgroupAdd(sum_sq);
52+
if (gl_SubgroupInvocationID == 0) {
53+
sumsh[gl_SubgroupID] = sum_sq;
54+
}
55+
barrier();
56+
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
57+
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
58+
sum_sq += sumsh[gl_SubgroupID + s];
59+
sumsh[gl_SubgroupID] = sum_sq;
60+
}
61+
barrier();
62+
}
63+
64+
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
65+
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
66+
}
67+
}
68+
#endif
2969
}

ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#extension GL_EXT_shader_16bit_storage : require
44
#extension GL_EXT_nonuniform_qualifier : enable
55
#extension GL_EXT_control_flow_attributes : require
6+
#if ADD_RMS
7+
#extension GL_KHR_shader_subgroup_arithmetic : enable
8+
#extension GL_KHR_shader_subgroup_basic : enable
9+
#endif
610

711
#include "rte.comp"
812
#include "types.comp"
@@ -14,12 +18,16 @@ layout (push_constant) uniform parameter2
1418
uint ne20; uint ne21; uint ne22; uint ne23;
1519

1620
// strides for srcs+dst
17-
uint nb[8][4];
21+
uint nb[12][4];
22+
23+
uint rms_partials;
1824
} p;
1925

2026
layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
2127
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
2228

29+
layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
30+
2331
layout(constant_id = 0) const uint num_srcs = 2;
2432

2533
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
@@ -42,14 +50,22 @@ const uint num_threads = 256;
4250

4351
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
4452

53+
#if ADD_RMS
54+
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
55+
shared FLOAT_TYPE sumsh[num_threads];
56+
#endif
57+
4558
void main() {
4659
uint idx = get_idx();
60+
uint orig_idx = idx;
4761

4862
uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
4963

5064
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
5165
const uint num_iter = 2;
5266

67+
FLOAT_TYPE sum_sq = 0;
68+
5369
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
5470
if (idx >= ne) {
5571
continue;
@@ -61,8 +77,32 @@ void main() {
6177
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
6278
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
6379
}
80+
sum_sq += sum*sum;
6481
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
6582

6683
idx += num_threads;
6784
}
85+
86+
#if ADD_RMS
87+
if (p.rms_partials != 0) {
88+
// reduce the sum within each subgroup, then across subgroups
89+
const uint NumSubgroups = num_threads / gl_SubgroupSize;
90+
sum_sq = subgroupAdd(sum_sq);
91+
if (gl_SubgroupInvocationID == 0) {
92+
sumsh[gl_SubgroupID] = sum_sq;
93+
}
94+
barrier();
95+
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
96+
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
97+
sum_sq += sumsh[gl_SubgroupID + s];
98+
sumsh[gl_SubgroupID] = sum_sq;
99+
}
100+
barrier();
101+
}
102+
103+
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
104+
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
105+
}
106+
}
107+
#endif
68108
}

ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false;
1010

1111
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
1212

13-
shared FLOAT_TYPE sum[BLOCK_SIZE];
13+
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
1414

15-
void main() {
15+
void rms_norm(uint num_iters) {
1616
const uint ncols = p.ne00;
1717
const uint nrows = gl_NumWorkGroups.x;
1818
const uint nchannels = gl_NumWorkGroups.y;
@@ -30,38 +30,76 @@ void main() {
3030
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
3131
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
3232

33-
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
33+
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
3434

35-
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
36-
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
37-
sum[tid] += xi * xi;
35+
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
36+
FLOAT_TYPE xi = FLOAT_TYPE(0);
37+
if (col < ncols) {
38+
xi = FLOAT_TYPE(data_a[a_offset + col]);
39+
}
40+
sum += xi * xi;
3841
}
3942

43+
sumsh[tid] = sum;
4044
// sum up partial sums and write back result
4145
barrier();
4246
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
4347
if (tid < s) {
44-
sum[tid] += sum[tid + s];
48+
sum += sumsh[tid + s];
49+
sumsh[tid] = sum;
4550
}
4651
barrier();
4752
}
53+
sum = sumsh[0];
4854

49-
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
55+
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
5056
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
5157

5258
if (do_multiply) {
5359
if (ncols > p.ne10) {
54-
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
60+
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
61+
if (col >= ncols) {
62+
continue;
63+
}
5564
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
5665
}
5766
} else {
58-
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
67+
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
68+
if (col >= ncols) {
69+
continue;
70+
}
5971
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
6072
}
6173
}
6274
} else {
63-
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
75+
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
76+
if (col >= ncols) {
77+
continue;
78+
}
6479
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
6580
}
6681
}
6782
}
83+
84+
void main() {
85+
// instantiate the rms_norm function for several different
86+
// dimensions, to allow loop unrolling
87+
uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;
88+
if (num_blocks > 32) {
89+
rms_norm(num_blocks);
90+
} else if (num_blocks > 16) {
91+
rms_norm(32);
92+
} else if (num_blocks > 8) {
93+
rms_norm(16);
94+
} else if (num_blocks > 4) {
95+
rms_norm(8);
96+
} else if (num_blocks == 4) {
97+
rms_norm(4);
98+
} else if (num_blocks == 3) {
99+
rms_norm(3);
100+
} else if (num_blocks == 2) {
101+
rms_norm(2);
102+
} else if (num_blocks == 1) {
103+
rms_norm(1);
104+
}
105+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#version 450
2+
3+
#include "generic_binary_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
#extension GL_KHR_shader_subgroup_arithmetic : enable
8+
#extension GL_KHR_shader_subgroup_basic : enable
9+
10+
#define BLOCK_SIZE 128
11+
12+
layout (constant_id = 1) const bool do_multiply = false;
13+
14+
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
15+
16+
layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};
17+
18+
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
19+
20+
void main() {
21+
const uint ncols = p.ne00;
22+
const uint nrows = gl_NumWorkGroups.x;
23+
const uint nchannels = gl_NumWorkGroups.y;
24+
25+
const uint row = 0;
26+
const uint channel = gl_WorkGroupID.y;
27+
const uint samp = gl_WorkGroupID.z;
28+
// The work is split across multiple workgroups in the x dimension. Each invocation
29+
// processes one element
30+
const uint tid = gl_GlobalInvocationID.x;
31+
32+
const uint stride_row = p.nb01;
33+
const uint stride_channel = p.nb02;
34+
const uint stride_sample = p.nb03;
35+
36+
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
37+
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
38+
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
39+
40+
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
41+
42+
uint32_t num_partials = p.param3;
43+
for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
44+
sum += partial_sums[i];
45+
}
46+
sum = subgroupAdd(sum);
47+
48+
uint col = tid;
49+
if (col >= ncols) {
50+
return;
51+
}
52+
53+
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
54+
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
55+
56+
if (do_multiply) {
57+
if (ncols > p.ne10) {
58+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
59+
} else {
60+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
61+
}
62+
} else {
63+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
64+
}
65+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ void process_shaders() {
503503
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
504504
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
505505
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
506+
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
506507
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
507508
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
508509

@@ -538,13 +539,15 @@ void process_shaders() {
538539
s += std::string(dst_f16 ? "_f16" : "_f32");
539540
return s;
540541
};
541-
for (std::string op : {"add", "sub", "mul", "div"}) {
542+
for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
542543
for (auto src0_f16 : {false, true}) {
543544
for (auto src1_f16 : {false, true}) {
544545
for (auto dst_f16 : {false, true}) {
545546
for (auto rte : {false, true}) {
547+
auto source = op == "add_rms" ? std::string("add") : op;
546548
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
547-
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
549+
auto add_rms = op == "add_rms" ? "1" : "0";
550+
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
548551
}
549552
}
550553
}
@@ -687,7 +690,8 @@ void process_shaders() {
687690

688691
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
689692

690-
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
693+
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
694+
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
691695

692696
for (auto &c : compiles) {
693697
c.wait();
@@ -745,7 +749,7 @@ void write_output_files() {
745749
}
746750

747751
std::string suffixes[2] = {"_f32", "_f16"};
748-
for (const char *op : {"add", "sub", "mul", "div"}) {
752+
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
749753
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
750754
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
751755
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";

0 commit comments

Comments
 (0)