147 changes: 147 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -77,6 +77,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define VK_VENDOR_ID_APPLE 0x106b
#define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de
#define VK_VENDOR_ID_ARM 0x13B5
#define VK_VENDOR_ID_QUALCOMM 0x5143


#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256

@@ -448,6 +451,8 @@ struct vk_device_struct {
vk_matmul_pipeline pipeline_matmul_bf16 {};
vk_matmul_pipeline2 pipeline_matmul_f16;
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
vk_pipeline pipeline_matmul_f16_f32_embed;
vk_pipeline pipeline_matmul_f32_f32_embed;

vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
@@ -2901,6 +2906,40 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
}
}

if (device->vendor_id == VK_VENDOR_ID_ARM || device->vendor_id == VK_VENDOR_ID_QUALCOMM) {
// Shader workgroup size is 16x8 = 128
const uint32_t wg_x = 16;
const uint32_t wg_y = 8;

// Tile sizes for the workgroup
uint32_t bm, bn, bk;

if (device->vendor_id == VK_VENDOR_ID_QUALCOMM) {
bm = 32;
bn = 128;
bk = 8;
} else {
bm = 64;
bn = 64;
bk = 16;
}

// Threads per tile
const uint32_t tm = bm / wg_y;
const uint32_t tn = bn / wg_x;

const std::vector<uint32_t> embed_spec_constants = {bm, bn, bk, tm, tn};
const std::array<uint32_t, 3> embed_wg_denoms = {bm, bn, 1};

ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32_embed, "mul_mat_f16_f32_embed",
mul_mat_f16_f32_embed_len, mul_mat_f16_f32_embed_data, "main", 3,
sizeof(vk_mat_mat_push_constants), embed_wg_denoms, embed_spec_constants, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f32_embed, "mul_mat_f32_f32_embed",
mul_mat_f32_f32_embed_len, mul_mat_f32_f32_embed_data, "main", 3,
sizeof(vk_mat_mat_push_constants), embed_wg_denoms, embed_spec_constants, 1);
}

// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@@ -5726,6 +5765,114 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const uint64_t ne12 = src1->ne[2];
const uint64_t ne13 = src1->ne[3];

if ((ctx->device->vendor_id == VK_VENDOR_ID_ARM || ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) &&
Contributor comment:
It also seems to have an impact with Intel integrated GPUs in some cases:

before:

| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | Vulkan     | 999 |           pp512 |         60.12 ± 0.00 |
| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | Vulkan     | 999 |           tg128 |          8.02 ± 0.00 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     | 999 |           pp512 |        109.99 ± 0.00 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     | 999 |           tg128 |          8.12 ± 0.00 |

after:

| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | Vulkan     | 999 |           pp512 |         86.34 ± 0.00 |
| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | Vulkan     | 999 |           tg128 |          8.30 ± 0.00 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     | 999 |           pp512 |        101.52 ± 0.00 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     | 999 |           tg128 |          8.01 ± 0.00 |

(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 && ggml_vk_dim01_contiguous(src1) &&
ne02 == 1 && ne03 == 1 &&
ne12 == 1 && ne13 == 1) {

ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
vk_buffer d_Qx = nullptr, d_Qy = nullptr, d_D = nullptr;
size_t qx_buf_offset = 0, qy_buf_offset = 0, d_buf_offset = 0;
bool src0_uma = false, src1_uma = false, dst_uma = false;

if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
src0_uma = d_Qx != nullptr;
src1_uma = d_Qy != nullptr;
dst_uma = d_D != nullptr;
}

if (!src0_uma) { d_Qx = src0_buf_ctx->dev_buffer; qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; }
if (!src1_uma) { d_Qy = src1_buf_ctx->dev_buffer; qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; }
if (!dst_uma) { d_D = dst_buf_ctx->dev_buffer; d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; }

const uint32_t M = ne01;
const uint32_t N = ne11;
const uint32_t K = ne10;

vk_pipeline pipeline = nullptr;
vk_buffer d_X;
uint64_t x_buf_offset;
uint32_t stride_a;
bool dequantized = false;

if (ggml_is_quantized(src0->type)) {
vk_pipeline dequant_pipeline = ggml_vk_get_to_fp16(ctx, src0->type);

if (dequant_pipeline) {
dequantized = true;
const uint64_t x_sz = sizeof(ggml_fp16_t) * M * K;

if (dryrun) {
if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
}
ggml_pipeline_request_descriptor_sets(ctx, dequant_pipeline, 1);
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_f16_f32_embed, 1);
return;
}

const std::vector<uint32_t> pc = { (uint32_t)M, (uint32_t)K, (uint32_t)K, (uint32_t)K, (uint32_t)(ggml_nelements(src0)) };
Collaborator comment:
I think this path does not handle noncontiguous src0. As @0cc4m said, it would be better to let this run through the existing code paths rather than having this separate code path.
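
A minimal sketch of how that concern could be made explicit in the guard that opens this fast path (illustration of the comment only; the fix actually being requested is to drop the separate path and reuse the existing mul_mat code). It assumes the existing ggml_vk_dim01_contiguous helper, already used on src1 above, is also applicable to src0:

```cpp
// Sketch: extend the vendor fast-path guard with a contiguity check on src0,
// mirroring the existing ggml_vk_dim01_contiguous(src1) check.
if ((ctx->device->vendor_id == VK_VENDOR_ID_ARM || ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) &&
    (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) &&
    ggml_vk_dim01_contiguous(src0) &&   // additional guard: skip noncontiguous src0
    src1->type == GGML_TYPE_F32 && ggml_vk_dim01_contiguous(src1) &&
    ne02 == 1 && ne03 == 1 &&
    ne12 == 1 && ne13 == 1) {
    // ... existing embed path ...
}
```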


ggml_vk_dispatch_pipeline(ctx, subctx, dequant_pipeline, {
vk_subbuffer{ d_Qx, qx_buf_offset, VK_WHOLE_SIZE },
vk_subbuffer{ ctx->prealloc_x, 0, VK_WHOLE_SIZE }
}, pc, { (uint32_t)(ggml_nelements(src0)), 1, 1});

d_X = ctx->prealloc_x;
x_buf_offset = 0;
stride_a = K;
pipeline = ctx->device->pipeline_matmul_f16_f32_embed;
}
} else {
if (src0->type == GGML_TYPE_F16) {
pipeline = ctx->device->pipeline_matmul_f16_f32_embed;
} else {
pipeline = ctx->device->pipeline_matmul_f32_f32_embed;
}

if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}

d_X = d_Qx;
x_buf_offset = qx_buf_offset;
stride_a = src0->nb[1] / ggml_type_size(src0->type);
}

if (pipeline != nullptr) {
if (dequantized) {
ggml_vk_sync_buffers(ctx, subctx); // Ensure dequant is finished
}

const uint32_t stride_b = src1->nb[1] / ggml_type_size(src1->type);
const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);

const vk_mat_mat_push_constants pc = { M, N, K, stride_a, stride_b, stride_d, M * K, K * N, M * N, K, 1, 1, 1, 1, N };

vk_buffer d_Y = d_Qy;
const uint64_t y_buf_offset = qy_buf_offset;

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_X, x_buf_offset, VK_WHOLE_SIZE },
vk_subbuffer{ d_Y, y_buf_offset, VK_WHOLE_SIZE },
vk_subbuffer{ d_D, d_buf_offset, VK_WHOLE_SIZE },
}, pc, { M, N, 1 });

if (dequantized) {
ctx->prealloc_x_need_sync = true;
}

return;
}
}

const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];

167 changes: 167 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp
@@ -0,0 +1,167 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_control_flow_attributes : require

#ifdef A_TYPE_FP16
#define A_VEC4_TYPE f16vec4
#define A_SCALAR_TYPE float16_t
#define A_VEC4_ZERO f16vec4(0.0hf)
#define A_VEC4_CAST(v) vec4(v)
#else
#define A_VEC4_TYPE vec4
#define A_SCALAR_TYPE float
#define A_VEC4_ZERO vec4(0.0f)
#define A_VEC4_CAST(v) (v)
#endif

layout(constant_id = 0) const uint BM = 64;
layout(constant_id = 1) const uint BN = 64;
layout(constant_id = 2) const uint BK = 16;
layout(constant_id = 3) const uint TM = 4;
layout(constant_id = 4) const uint TN = 8;

const uint WG_X = BN / TN;
const uint WG_Y = BM / TM;
const uint WG_SIZE = WG_X * WG_Y;
const uint VEC_K = BK / 4;

layout(local_size_x = 16, local_size_y = 8, local_size_z = 1) in;

layout (binding = 0) readonly buffer A_BUFFER { A_SCALAR_TYPE data_a[]; };
layout (binding = 1) readonly buffer B_BUFFER { float data_b[]; };
layout (binding = 2) writeonly buffer D_BUFFER { float data_d[]; };

layout (push_constant) uniform parameter
{
uint M;
uint N;
uint K;
uint stride_a;
uint stride_b;
uint stride_d;
} p;

shared A_VEC4_TYPE buf_a[BM][VEC_K];
shared vec4 buf_b[BN][VEC_K];

void main() {
const uint lidx = gl_LocalInvocationID.x;
const uint lidy = gl_LocalInvocationID.y;
const uint lid = lidy * WG_X + lidx;

const uint group_m = gl_WorkGroupID.x * BM;
const uint group_n = gl_WorkGroupID.y * BN;

float sums[TM][TN];
[[unroll]]
for (uint i = 0; i < TM; i++) {
[[unroll]]
for (uint j = 0; j < TN; j++) {
sums[i][j] = 0.0f;
}
}

const uint num_k_tiles = (p.K + BK - 1) / BK;
Collaborator (author) comment:
I think this is not robust enough and might be wrong for the Adreno case, but it passes the tests in test-backend-ops, and I feel like it shouldn't.

Collaborator comment:
Shouldn't be hard to add a case or two with odd K. I suggest having relatively small M,N to avoid the error being hidden.
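
A hedged sketch of what such cases could look like in tests/test-backend-ops.cpp. The test_mul_mat constructor arguments used here (type_a, type_b, m, n, k, batch dims, repeat dims) are an assumption about the current signature and should be checked against the file before use:

```cpp
// Hypothetical test additions (constructor signature assumed, verify against
// tests/test-backend-ops.cpp): small M/N so a wrong K tail cannot be hidden,
// and K values that are not multiples of BK (8/16) or of the vec4 load width.
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, 17, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32,  8, 16, 33, {1, 1}, {1, 1}));
```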

Collaborator (author) comment:
I misspoke; I was thinking more about the Adreno case:

BM = 32, BK = 8 -> VEC_K = 2
WG_SIZE = 128
A_LOADS_PER_THREAD = (32 * 2) / 128 = 64 / 128 = 0

So theoretically it shouldn't be able to load matrix A regardless of the dimensions, but the tests are passing, so I'm a bit confused.
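
To make the arithmetic in this thread easy to check, here is a small standalone C++ sketch (the helper and program are local to this example, not part of the PR) that mirrors how the shader derives its per-thread load counts from the two specialization-constant sets created in ggml_vk_load_shaders:

```cpp
#include <cstdio>

// Mirrors the derived constants in mul_mm_embed.comp for the two
// specialization-constant sets created in ggml_vk_load_shaders.
static void embed_tile_report(const char * name, unsigned bm, unsigned bn, unsigned bk) {
    const unsigned wg_size = 16 * 8;                 // local_size_x * local_size_y = 128
    const unsigned vec_k   = bk / 4;                 // VEC_K = BK / 4
    const unsigned a_loads = (bm * vec_k) / wg_size; // A_LOADS_PER_THREAD (integer division)
    const unsigned b_loads = (bn * vec_k) / wg_size; // B_LOADS_PER_THREAD
    printf("%s: VEC_K=%u A_LOADS_PER_THREAD=%u B_LOADS_PER_THREAD=%u\n", name, vec_k, a_loads, b_loads);
}

int main() {
    embed_tile_report("Arm      BM=64 BN=64  BK=16", 64, 64, 16);  // VEC_K=4, A=2, B=2
    embed_tile_report("Qualcomm BM=32 BN=128 BK=8",  32, 128, 8);  // VEC_K=2, A=0, B=2
    return 0;
}
```

With the Qualcomm constants the A-load loop runs zero iterations per thread while the B-load loop still runs twice, which is exactly the mismatch flagged above.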

const uint A_LOADS_PER_THREAD = (BM * VEC_K) / WG_SIZE;
const uint B_LOADS_PER_THREAD = (BN * VEC_K) / WG_SIZE;

for (uint t = 0; t < num_k_tiles; t++) {
const uint k_tile_start = t * BK;

[[unroll]]
for(uint i = 0; i < A_LOADS_PER_THREAD; ++i) {
uint load_idx = lid + i * WG_SIZE;
uint m = load_idx / VEC_K;
uint k = load_idx % VEC_K;
uint global_m = group_m + m;
uint k_scalar = k_tile_start + k * 4;

if (global_m < p.M && k_scalar < p.K) {
uint base_idx = global_m * p.stride_a + k_scalar;
if (k_scalar + 3 < p.K) {
buf_a[m][k] = A_VEC4_TYPE(data_a[base_idx], data_a[base_idx+1], data_a[base_idx+2], data_a[base_idx+3]);
} else {
A_SCALAR_TYPE temp[4] = {A_SCALAR_TYPE(0), A_SCALAR_TYPE(0), A_SCALAR_TYPE(0), A_SCALAR_TYPE(0)};
if (k_scalar < p.K) temp[0] = data_a[base_idx];
if (k_scalar + 1 < p.K) temp[1] = data_a[base_idx+1];
if (k_scalar + 2 < p.K) temp[2] = data_a[base_idx+2];
buf_a[m][k] = A_VEC4_TYPE(temp[0], temp[1], temp[2], temp[3]);
}
} else {
buf_a[m][k] = A_VEC4_ZERO;
}
}

[[unroll]]
for(uint i = 0; i < B_LOADS_PER_THREAD; ++i) {
uint load_idx = lid + i * WG_SIZE;
uint n = load_idx / VEC_K;
uint k = load_idx % VEC_K;
uint global_n = group_n + n;
uint k_scalar = k_tile_start + k * 4;

if (global_n < p.N && k_scalar < p.K) {
uint base_idx = global_n * p.stride_b + k_scalar;
if (k_scalar + 3 < p.K) {
buf_b[n][k] = vec4(data_b[base_idx], data_b[base_idx+1], data_b[base_idx+2], data_b[base_idx+3]);
} else {
float temp[4] = {0.0f, 0.0f, 0.0f, 0.0f};
if (k_scalar < p.K) temp[0] = data_b[base_idx];
if (k_scalar + 1 < p.K) temp[1] = data_b[base_idx+1];
if (k_scalar + 2 < p.K) temp[2] = data_b[base_idx+2];
buf_b[n][k] = vec4(temp[0], temp[1], temp[2], temp[3]);
}
} else {
buf_b[n][k] = vec4(0.0f);
}
}

barrier();

[[unroll]]
for (uint k = 0; k < VEC_K; k++) {
A_VEC4_TYPE a_reg[TM];
[[unroll]]
for (uint i = 0; i < TM; i++) {
a_reg[i] = buf_a[lidy + i * WG_Y][k];
}

vec4 b_reg[TN];
[[unroll]]
for (uint j = 0; j < TN; j++) {
b_reg[j] = buf_b[lidx + j * WG_X][k];
}

[[unroll]]
for (uint i = 0; i < TM; i++) {
vec4 a_f32 = A_VEC4_CAST(a_reg[i]);

sums[i][0] += a_f32.x * b_reg[0].x + a_f32.y * b_reg[0].y + a_f32.z * b_reg[0].z + a_f32.w * b_reg[0].w;
sums[i][1] += a_f32.x * b_reg[1].x + a_f32.y * b_reg[1].y + a_f32.z * b_reg[1].z + a_f32.w * b_reg[1].w;
sums[i][2] += a_f32.x * b_reg[2].x + a_f32.y * b_reg[2].y + a_f32.z * b_reg[2].z + a_f32.w * b_reg[2].w;
sums[i][3] += a_f32.x * b_reg[3].x + a_f32.y * b_reg[3].y + a_f32.z * b_reg[3].z + a_f32.w * b_reg[3].w;
sums[i][4] += a_f32.x * b_reg[4].x + a_f32.y * b_reg[4].y + a_f32.z * b_reg[4].z + a_f32.w * b_reg[4].w;
sums[i][5] += a_f32.x * b_reg[5].x + a_f32.y * b_reg[5].y + a_f32.z * b_reg[5].z + a_f32.w * b_reg[5].w;
sums[i][6] += a_f32.x * b_reg[6].x + a_f32.y * b_reg[6].y + a_f32.z * b_reg[6].z + a_f32.w * b_reg[6].w;
sums[i][7] += a_f32.x * b_reg[7].x + a_f32.y * b_reg[7].y + a_f32.z * b_reg[7].z + a_f32.w * b_reg[7].w;
}
}
barrier();
}

[[unroll]]
for (uint i = 0; i < TM; i++) {
uint global_m = group_m + lidy + i * WG_Y;
if (global_m < p.M) {
[[unroll]]
for (uint j = 0; j < TN; j++) {
uint global_n = group_n + lidx + j * WG_X;
if (global_n < p.N) {
data_d[global_n * p.stride_d + global_m] = sums[i][j];
}
}
}
}
}
3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -454,6 +454,9 @@ void process_shaders() {
}
}

string_to_spv("mul_mat_f16_f32_embed", "mul_mm_embed.comp", {{"A_TYPE_FP16", "1"}});
string_to_spv("mul_mat_f32_f32_embed", "mul_mm_embed.comp", {});

// flash attention
for (const auto& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;