Skip to content

Commit 650d398

Browse files
authored
vulkan: Increase workgroup size for GLU, for performance (#14345)
* vulkan: Increase workgroup size for GLU, for performance
* vulkan: change GLU shaders to do one element per invocation rather than one row per workgroup
1 parent ab46d11 commit 650d398

File tree

3 files changed

+23
-23
lines changed

3 files changed

+23
-23
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,9 @@ struct vk_op_push_constants {
665665
};
666666

667667
struct vk_op_glu_push_constants {
668+
uint32_t N;
668669
uint32_t ne00;
670+
uint32_t ne20;
669671
uint32_t mode; // 0: default, 1: swapped, 2: split
670672
};
671673

@@ -2761,8 +2763,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27612763
#undef CREATE_UNARY
27622764

27632765
#define CREATE_GLU(name) \
2764-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
2765-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2766+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2767+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
27662768

27672769
CREATE_GLU(geglu)
27682770
CREATE_GLU(reglu)
@@ -6867,7 +6869,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
68676869
case GGML_OP_SOFT_MAX_BACK:
68686870
case GGML_OP_SUM_ROWS:
68696871
case GGML_OP_ARGMAX:
6870-
case GGML_OP_GLU:
68716872
{
68726873
const uint32_t nr = ggml_nrows(src0);
68736874
if (nr > 262144) {
@@ -6952,6 +6953,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
69526953
case GGML_OP_CONCAT:
69536954
case GGML_OP_UPSCALE:
69546955
case GGML_OP_UNARY:
6956+
case GGML_OP_GLU:
69556957
case GGML_OP_CONV_2D_DW:
69566958
{
69576959
uint32_t ne = ggml_nelements(dst);
@@ -7600,7 +7602,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
76007602

76017603
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
76027604

7603-
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun);
7605+
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
76047606
}
76057607

76067608
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#extension GL_EXT_shader_16bit_storage : require
22

3-
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
44

55
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
66
layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
77
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
88

9-
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
10-
119
layout (push_constant) uniform parameter
1210
{
11+
uint N;
1312
uint ne00;
13+
uint ne20;
1414
uint mode;
1515
} p;
Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
11
void main() {
2-
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
3-
const uint col = gl_LocalInvocationID.x;
2+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
3+
4+
if (i >= p.N) {
5+
return;
6+
}
7+
8+
const uint row = i / p.ne20;
9+
const uint col = i - row * p.ne20;
410

511
if (p.mode == 0) {
612
// Default
713
const uint offset = p.ne00 / 2;
14+
const uint idx = row * p.ne00 + col;
815

9-
for (uint i = col; i < offset; i += BLOCK_SIZE) {
10-
const uint idx = row * p.ne00 + i;
11-
12-
data_d[row * offset + i] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
13-
}
16+
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
1417
} else if (p.mode == 1) {
1518
// Swapped
1619
const uint offset = p.ne00 / 2;
20+
const uint idx = row * p.ne00 + col;
1721

18-
for (uint i = col; i < offset; i += BLOCK_SIZE) {
19-
const uint idx = row * p.ne00 + i;
20-
21-
data_d[row * offset + i] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
22-
}
22+
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
2323
} else {
2424
// Split
25-
for (uint i = col; i < p.ne00; i += BLOCK_SIZE) {
26-
const uint idx = row * p.ne00 + i;
25+
const uint idx = row * p.ne00 + col;
2726

28-
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
29-
}
27+
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
3028
}
3129
}

0 commit comments

Comments
 (0)