Commit b9ce940
vulkan: Fuse rope+set_rows (#16769)
This pattern appears in a lot of models: the rope operation is applied right before storing into the KV cache (usually on the K tensor).

- Add a path to some of the rope shaders that computes the destination address based on the set_rows tensor.
- Compile variants of the shaders with a D_TYPE of f16 (the usual KV cache type).
- Add a src3 operand to ggml_vk_op_f32 - sometimes rope uses three srcs and needs the fourth for the row indices.
- Add fused_ops_write_mask to indicate which intermediate tensors need to write their results to memory. Skipping the write of the roped K value allows more nodes to run concurrently.
- Add logic to ggml_vk_graph_optimize to make ROPE+VIEW+SET_ROWS consecutive; it rarely starts out that way in the graph.
- Add new backend tests.
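For reference, the graph pattern this fusion targets looks roughly like the following. This is a minimal sketch against the public ggml API, modeled on the new test_rope_set_rows::build_graph added below; the shapes and the 4096-row cache size are illustrative, not taken from the commit.

#include "ggml.h"

// Minimal sketch of the ROPE + VIEW + SET_ROWS pattern being fused.
// Illustrative shapes: 128 = head dim, 32 = K heads, 512 = tokens,
// 4096 = KV cache rows (assumed).
static ggml_tensor * build_fused_pattern(ggml_context * ctx) {
    ggml_tensor * k     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 32, 512);
    ggml_tensor * pos   = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
    ggml_tensor * roped = ggml_rope(ctx, k, pos, 128, GGML_ROPE_TYPE_NEOX);

    // view each token's roped heads as one flat row...
    ggml_tensor * view  = ggml_view_2d(ctx, roped, 128*32, 512, roped->nb[2], 0);

    // ...and scatter those rows to per-token slots in the f16 KV cache
    ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 128*32, 4096);
    ggml_tensor * ids   = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 512);
    return ggml_set_rows(ctx, cache, view, ids);
}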
1 parent 3464bda commit b9ce940

File tree

6 files changed: +371 -117 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 248 additions & 86 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer Y {int data_pos[];};
 layout (binding = 2) readonly buffer Z {float data_ff[];};
 layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows

 layout (push_constant) uniform parameter {
     uint ncols;
@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter {
     uint s2;
     int sections[4];
     uint is_back;
+    uint set_rows_stride;
 } p;

 float rope_yarn_ramp(const float low, const float high, const uint i0) {

ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp

Lines changed: 10 additions & 3 deletions
@@ -16,12 +16,19 @@ void main() {
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;

-    const uint idst = row_dst*ne0 + i0/2;
+    uint idst = row_dst*ne0 + i0/2;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;

+    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    if (p.set_rows_stride != 0) {
+        idst = row_x*ne0 + i0/2;
+        idst += data_i[channel_x].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
-        data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
-        data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+        data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
+        data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);

         return;
     }
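When set_rows_stride is nonzero, the fused branch replaces the contiguous output index with one derived from the set_rows indices: the token index channel_x selects an entry of data_i, and that row id times set_rows_stride (the cache's row stride in elements) relocates the write into the KV cache. The index tensor is I64 in the tests, so the shader binds it as uvec2 and uses only the low word (.x). A scalar model of the addressing, reconstructed from the shader above rather than taken from the commit:

#include <cstdint>

// Scalar reconstruction of the fused destination addressing; my sketch,
// not code from the commit. elem stands for i0/2 (neox) or i0 (norm),
// ne0 is the head size, row_ids are the set_rows indices (low 32 bits
// used, matching data_i[channel_x].x in the shader).
static uint32_t fused_idst(uint32_t row_x, uint32_t channel_x, uint32_t elem,
                           uint32_t ne0, uint32_t set_rows_stride,
                           const uint64_t * row_ids) {
    // unfused: idst = (channel_x*ne1 + row_x)*ne0 + elem, a contiguous store
    uint32_t idst = row_x*ne0 + elem;                        // offset within the token
    idst += (uint32_t)row_ids[channel_x] * set_rows_stride;  // indexed cache row
    return idst;
}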

ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp

Lines changed: 10 additions & 3 deletions
@@ -16,12 +16,19 @@ void main() {
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;

-    const uint idst = row_dst*ne0 + i0;
+    uint idst = row_dst*ne0 + i0;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;

+    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    if (p.set_rows_stride != 0) {
+        idst = row_x*ne0 + i0;
+        idst += data_i[channel_x].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
-        data_d[idst + 0] = data_a[ix + 0];
-        data_d[idst + 1] = data_a[ix + 1];
+        data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
+        data_d[idst + 1] = D_TYPE(data_a[ix + 1]);

         return;
     }

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 0 deletions
@@ -842,10 +842,14 @@ void process_shaders() {
     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});

     string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});

     string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

tests/test-backend-ops.cpp

Lines changed: 97 additions & 25 deletions
@@ -2125,6 +2125,34 @@ struct test_get_rows_back : public test_case {
     }
 };

+static void init_set_rows_row_ids(ggml_tensor * t, int num_rows) {
+    std::random_device rd;
+    std::default_random_engine rng(rd());
+    for (int i2 = 0; i2 < t->ne[2]; i2++) {
+        for (int i1 = 0; i1 < t->ne[1]; i1++) {
+            // generate a shuffled subset of row indices
+            std::vector<int64_t> data(num_rows);
+            for (int i = 0; i < num_rows; i++) {
+                data[i] = i;
+            }
+            std::shuffle(data.begin(), data.end(), rng);
+            data.resize(t->ne[0]);
+
+            const size_t offs = i1*t->nb[1] + i2*t->nb[2];
+            if (t->type == GGML_TYPE_I32) {
+                // TODO: Make a template or something
+                std::vector<int32_t> data_i32(t->ne[0]);
+                for (int i = 0; i < t->ne[0]; i++) {
+                    data_i32[i] = static_cast<int32_t>(data[i]);
+                }
+                ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
+            } else {
+                ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+            }
+        }
+    }
+}
+
 // GGML_OP_SET_ROWS
 struct test_set_rows : public test_case {
     const ggml_type type;
@@ -2168,37 +2196,13 @@ struct test_set_rows : public test_case {
     }

     void initialize_tensors(ggml_context * ctx) override {
-        std::random_device rd;
-        std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
                 if (ggml_is_view_op(t->op)) {
                     continue;
                 }

-                for (int i2 = 0; i2 < t->ne[2]; i2++) {
-                    for (int i1 = 0; i1 < t->ne[1]; i1++) {
-                        // generate a shuffled subset of row indices
-                        std::vector<int64_t> data(ne[1]);
-                        for (int i = 0; i < ne[1]; i++) {
-                            data[i] = i;
-                        }
-                        std::shuffle(data.begin(), data.end(), rng);
-                        data.resize(t->ne[0]);
-
-                        const size_t offs = i1*t->nb[1] + i2*t->nb[2];
-                        if (t->type == GGML_TYPE_I32) {
-                            // TODO: Make a template or something
-                            std::vector<int32_t> data_i32(t->ne[0]);
-                            for (int i = 0; i < t->ne[0]; i++) {
-                                data_i32[i] = static_cast<int32_t>(data[i]);
-                            }
-                            ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
-                        } else {
-                            ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
-                        }
-                    }
-                }
+                init_set_rows_row_ids(t, ne[1]);
             } else {
                 init_tensor_uniform(t);
             }
@@ -2227,6 +2231,67 @@ struct test_set_rows : public test_case {
     }
 };

+// GGML_OP_ROPE + GGML_OP_VIEW + GGML_OP_SET_ROWS
+struct test_rope_set_rows : public test_case {
+    const ggml_type type;
+    const ggml_type type_idx;
+    const std::array<int64_t, 4> ne;
+    int mode;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, type_idx, ne, mode);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "ROPE_SET_ROWS";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    test_rope_set_rows(ggml_type type,
+            ggml_type type_idx,
+            std::array<int64_t, 4> ne,
+            int mode)
+        : type(type), type_idx(type_idx), ne(ne), mode(mode) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+        ggml_set_name(src, "src");
+
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+
+        ggml_tensor * rope = ggml_rope(ctx, src, pos, ne[0], mode);
+
+        ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
+
+        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
+        ggml_set_name(dst, "dst");
+
+        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne[2], 1, 1);
+        ggml_set_name(row_idxs, "row_idxs");
+
+        ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) {
+                    continue;
+                }
+
+                init_set_rows_row_ids(t, ne[2]);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ARGMAX
 struct test_argmax : public test_case {
     const ggml_type type;
@@ -6163,6 +6228,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }

+    for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX }) {
+        for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+            test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 1, 100 }, mode));
+            test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 512, 1 }, mode));
+        }
+    }
+
     for (ggml_type type_input : {GGML_TYPE_F32}) {
         for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
             for (int k0 : {1, 3}) {
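As a reading aid, here is a scalar model of what the fused tests verify for GGML_ROPE_TYPE_NORMAL. It is a simplification, not code from the commit: plain theta = pos * base^(-i/n_dim) rotation with base 10000, no YaRN or frequency-factor handling, and one flat row per token.

#include <cmath>
#include <cstdint>

// Rotate each (even, odd) pair of a token's K row, then scatter the
// rotated row to cache row row_ids[t] (the set_rows step).
void rope_set_rows_ref(const float * k,         // [n_tokens][n_dim]
                       const int32_t * pos,     // [n_tokens]
                       const int64_t * row_ids, // [n_tokens]
                       float * cache,           // [n_cache_rows][n_dim]
                       int n_tokens, int n_dim) {
    for (int t = 0; t < n_tokens; t++) {
        float * dst = cache + row_ids[t]*n_dim; // indexed, not contiguous
        for (int i = 0; i < n_dim; i += 2) {
            const float theta = pos[t] * std::pow(10000.0f, -(float)i/n_dim);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            const float x0 = k[t*n_dim + i + 0];
            const float x1 = k[t*n_dim + i + 1];
            dst[i + 0] = x0*c - x1*s;
            dst[i + 1] = x0*s + x1*c;
        }
    }
}

Assuming the usual test-backend-ops CLI, the new cases should be runnable in isolation with something like "test-backend-ops test -o ROPE_SET_ROWS", matching the op_desc string above.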
