diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index cd1c66ba7b476..f92686566efb0 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -77,6 +77,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 #define VK_VENDOR_ID_APPLE 0x106b
 #define VK_VENDOR_ID_INTEL 0x8086
 #define VK_VENDOR_ID_NVIDIA 0x10de
+#define VK_VENDOR_ID_ARM 0x13B5
+#define VK_VENDOR_ID_QUALCOMM 0x5143
+
 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
@@ -448,6 +451,8 @@ struct vk_device_struct {
     vk_matmul_pipeline pipeline_matmul_bf16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
+    vk_pipeline pipeline_matmul_f16_f32_embed;
+    vk_pipeline pipeline_matmul_f32_f32_embed;
 
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
@@ -2901,6 +2906,40 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
         }
     }
+
+    if (device->vendor_id == VK_VENDOR_ID_ARM || device->vendor_id == VK_VENDOR_ID_QUALCOMM) {
+        // Shader workgroup size is 16x8 = 128
+        const uint32_t wg_x = 16;
+        const uint32_t wg_y = 8;
+
+        // Tile sizes for the workgroup
+        uint32_t bm, bn, bk;
+
+        if (device->vendor_id == VK_VENDOR_ID_QUALCOMM) {
+            bm = 32;
+            bn = 128;
+            bk = 8;
+        } else {
+            bm = 64;
+            bn = 64;
+            bk = 16;
+        }
+
+        // Threads per tile
+        const uint32_t tm = bm / wg_y;
+        const uint32_t tn = bn / wg_x;
+
+        const std::vector<uint32_t> embed_spec_constants = {bm, bn, bk, tm, tn};
+        const std::array<uint32_t, 3> embed_wg_denoms = {bm, bn, 1};
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32_embed, "mul_mat_f16_f32_embed",
+                                mul_mat_f16_f32_embed_len, mul_mat_f16_f32_embed_data, "main", 3,
+                                sizeof(vk_mat_mat_push_constants), embed_wg_denoms, embed_spec_constants, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f32_embed, "mul_mat_f32_f32_embed",
+                                mul_mat_f32_f32_embed_len, mul_mat_f32_f32_embed_data, "main", 3,
+                                sizeof(vk_mat_mat_push_constants), embed_wg_denoms, embed_spec_constants, 1);
+    }
+
     // reusing CREATE_MM from the fp32 path
     if ((device->coopmat2 || device->coopmat_support)
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@@ -5726,6 +5765,114 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint64_t ne12 = src1->ne[2];
     const uint64_t ne13 = src1->ne[3];
 
+    if ((ctx->device->vendor_id == VK_VENDOR_ID_ARM || ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) &&
+        (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) &&
+        src1->type == GGML_TYPE_F32 && ggml_vk_dim01_contiguous(src1) &&
+        ne02 == 1 && ne03 == 1 &&
+        ne12 == 1 && ne13 == 1) {
+
+        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+        ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
+        ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
+        vk_buffer d_Qx = nullptr, d_Qy = nullptr, d_D = nullptr;
+        size_t qx_buf_offset = 0, qy_buf_offset = 0, d_buf_offset = 0;
+        bool src0_uma = false, src1_uma = false, dst_uma = false;
+
+        if (ctx->device->uma) {
+            ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+            ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+            ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
+            src0_uma = d_Qx != nullptr;
+            src1_uma = d_Qy != nullptr;
+            dst_uma = d_D != nullptr;
+        }
+
+        if (!src0_uma) { d_Qx = src0_buf_ctx->dev_buffer; qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; }
+        if (!src1_uma) { d_Qy = src1_buf_ctx->dev_buffer; qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; }
+        if (!dst_uma) { d_D = dst_buf_ctx->dev_buffer; d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; }
+
+        const uint32_t M = ne01;
+        const uint32_t N = ne11;
+        const uint32_t K = ne10;
+
+        vk_pipeline pipeline = nullptr;
+        vk_buffer d_X;
+        uint64_t x_buf_offset;
+        uint32_t stride_a;
+        bool dequantized = false;
+
+        if (ggml_is_quantized(src0->type)) {
+            vk_pipeline dequant_pipeline = ggml_vk_get_to_fp16(ctx, src0->type);
+
+            if (dequant_pipeline) {
+                dequantized = true;
+                const uint64_t x_sz = sizeof(ggml_fp16_t) * M * K;
+
+                if (dryrun) {
+                    if (ctx->prealloc_size_x < x_sz) {
+                        ctx->prealloc_size_x = x_sz;
+                    }
+                    ggml_pipeline_request_descriptor_sets(ctx, dequant_pipeline, 1);
+                    ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_f16_f32_embed, 1);
+                    return;
+                }
+
+                const std::vector<uint32_t> pc = { (uint32_t)M, (uint32_t)K, (uint32_t)K, (uint32_t)K, (uint32_t)(ggml_nelements(src0)) };
+
+                ggml_vk_dispatch_pipeline(ctx, subctx, dequant_pipeline, {
+                    vk_subbuffer{ d_Qx, qx_buf_offset, VK_WHOLE_SIZE },
+                    vk_subbuffer{ ctx->prealloc_x, 0, VK_WHOLE_SIZE }
+                }, pc, { (uint32_t)(ggml_nelements(src0)), 1, 1});
+
+                d_X = ctx->prealloc_x;
+                x_buf_offset = 0;
+                stride_a = K;
+                pipeline = ctx->device->pipeline_matmul_f16_f32_embed;
+            }
+        } else {
+            if (src0->type == GGML_TYPE_F16) {
+                pipeline = ctx->device->pipeline_matmul_f16_f32_embed;
+            } else {
+                pipeline = ctx->device->pipeline_matmul_f32_f32_embed;
+            }
+
+            if (dryrun) {
+                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+                return;
+            }
+
+            d_X = d_Qx;
+            x_buf_offset = qx_buf_offset;
+            stride_a = src0->nb[1] / ggml_type_size(src0->type);
+        }
+
+        if (pipeline != nullptr) {
+            if (dequantized) {
+                ggml_vk_sync_buffers(ctx, subctx); // Ensure dequant is finished
+            }
+
+            const uint32_t stride_b = src1->nb[1] / ggml_type_size(src1->type);
+            const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
+
+            const vk_mat_mat_push_constants pc = { M, N, K, stride_a, stride_b, stride_d, M * K, K * N, M * N, K, 1, 1, 1, 1, N };
+
+            vk_buffer d_Y = d_Qy;
+            const uint64_t y_buf_offset = qy_buf_offset;
+
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+                vk_subbuffer{ d_X, x_buf_offset, VK_WHOLE_SIZE },
+                vk_subbuffer{ d_Y, y_buf_offset, VK_WHOLE_SIZE },
+                vk_subbuffer{ d_D, d_buf_offset, VK_WHOLE_SIZE },
+            }, pc, { M, N, 1 });
+
+            if (dequantized) {
+                ctx->prealloc_x_need_sync = true;
+            }
+
+            return;
+        }
+    }
+
     const uint64_t ne20 = dst->ne[0];
     const uint64_t ne21 = dst->ne[1];
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp
new file mode 100644
index 0000000000000..c67ad5a45bb3a
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp
@@ -0,0 +1,167 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_control_flow_attributes : require
+
+#ifdef A_TYPE_FP16
+    #define A_VEC4_TYPE f16vec4
+    #define A_SCALAR_TYPE float16_t
+    #define A_VEC4_ZERO f16vec4(0.0hf)
+    #define A_VEC4_CAST(v) vec4(v)
+#else
+    #define A_VEC4_TYPE vec4
+    #define A_SCALAR_TYPE float
+    #define A_VEC4_ZERO vec4(0.0f)
+    #define A_VEC4_CAST(v) (v)
+#endif
+
+layout(constant_id = 0) const uint BM = 64;
+layout(constant_id = 1) const uint BN = 64;
+layout(constant_id = 2) const uint BK = 16;
+layout(constant_id = 3) const uint TM = 4;
+layout(constant_id = 4) const uint TN = 8;
+
+const uint WG_X = BN / TN;
+const uint WG_Y = BM / TM;
+const uint WG_SIZE = WG_X * WG_Y;
+const uint VEC_K = BK / 4;
+
+layout(local_size_x = 16, local_size_y = 8, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A_BUFFER { A_SCALAR_TYPE data_a[]; };
+layout (binding = 1) readonly buffer B_BUFFER { float data_b[]; };
+layout (binding = 2) writeonly buffer D_BUFFER { float data_d[]; };
+
+layout (push_constant) uniform parameter
+{
+    uint M;
+    uint N;
+    uint K;
+    uint stride_a;
+    uint stride_b;
+    uint stride_d;
+} p;
+
+shared A_VEC4_TYPE buf_a[BM][VEC_K];
+shared vec4 buf_b[BN][VEC_K];
+
+void main() {
+    const uint lidx = gl_LocalInvocationID.x;
+    const uint lidy = gl_LocalInvocationID.y;
+    const uint lid = lidy * WG_X + lidx;
+
+    const uint group_m = gl_WorkGroupID.x * BM;
+    const uint group_n = gl_WorkGroupID.y * BN;
+
+    float sums[TM][TN];
+    [[unroll]]
+    for (uint i = 0; i < TM; i++) {
+        [[unroll]]
+        for (uint j = 0; j < TN; j++) {
+            sums[i][j] = 0.0f;
+        }
+    }
+
+    const uint num_k_tiles = (p.K + BK - 1) / BK;
+    // Round up so tile shapes where BM*VEC_K (or BN*VEC_K) is smaller than WG_SIZE
+    // still issue one bounds-guarded load per thread.
+    const uint A_LOADS_PER_THREAD = (BM * VEC_K + WG_SIZE - 1) / WG_SIZE;
+    const uint B_LOADS_PER_THREAD = (BN * VEC_K + WG_SIZE - 1) / WG_SIZE;
+
+    for (uint t = 0; t < num_k_tiles; t++) {
+        const uint k_tile_start = t * BK;
+
+        [[unroll]]
+        for(uint i = 0; i < A_LOADS_PER_THREAD; ++i) {
+            uint load_idx = lid + i * WG_SIZE;
+            if (load_idx >= BM * VEC_K) {
+                continue;
+            }
+            uint m = load_idx / VEC_K;
+            uint k = load_idx % VEC_K;
+            uint global_m = group_m + m;
+            uint k_scalar = k_tile_start + k * 4;
+
+            if (global_m < p.M && k_scalar < p.K) {
+                uint base_idx = global_m * p.stride_a + k_scalar;
+                if (k_scalar + 3 < p.K) {
+                    buf_a[m][k] = A_VEC4_TYPE(data_a[base_idx], data_a[base_idx+1], data_a[base_idx+2], data_a[base_idx+3]);
+                } else {
+                    A_SCALAR_TYPE temp[4] = {A_SCALAR_TYPE(0), A_SCALAR_TYPE(0), A_SCALAR_TYPE(0), A_SCALAR_TYPE(0)};
+                    if (k_scalar < p.K) temp[0] = data_a[base_idx];
+                    if (k_scalar + 1 < p.K) temp[1] = data_a[base_idx+1];
+                    if (k_scalar + 2 < p.K) temp[2] = data_a[base_idx+2];
+                    buf_a[m][k] = A_VEC4_TYPE(temp[0], temp[1], temp[2], temp[3]);
+                }
+            } else {
+                buf_a[m][k] = A_VEC4_ZERO;
+            }
+        }
+
+        [[unroll]]
+        for(uint i = 0; i < B_LOADS_PER_THREAD; ++i) {
+            uint load_idx = lid + i * WG_SIZE;
+            if (load_idx >= BN * VEC_K) {
+                continue;
+            }
+            uint n = load_idx / VEC_K;
+            uint k = load_idx % VEC_K;
+            uint global_n = group_n + n;
+            uint k_scalar = k_tile_start + k * 4;
+
+            if (global_n < p.N && k_scalar < p.K) {
+                uint base_idx = global_n * p.stride_b + k_scalar;
+                if (k_scalar + 3 < p.K) {
+                    buf_b[n][k] = vec4(data_b[base_idx], data_b[base_idx+1], data_b[base_idx+2], data_b[base_idx+3]);
+                } else {
+                    float temp[4] = {0.0f, 0.0f, 0.0f, 0.0f};
+                    if (k_scalar < p.K) temp[0] = data_b[base_idx];
+                    if (k_scalar + 1 < p.K) temp[1] = data_b[base_idx+1];
+                    if (k_scalar + 2 < p.K) temp[2] = data_b[base_idx+2];
+                    buf_b[n][k] = vec4(temp[0], temp[1], temp[2], temp[3]);
+                }
+            } else {
+                buf_b[n][k] = vec4(0.0f);
+            }
+        }
+
+        barrier();
+
+        [[unroll]]
+        for (uint k = 0; k < VEC_K; k++) {
+            A_VEC4_TYPE a_reg[TM];
+            [[unroll]]
+            for (uint i = 0; i < TM; i++) {
+                a_reg[i] = buf_a[lidy + i * WG_Y][k];
+            }
+
+            vec4 b_reg[TN];
+            [[unroll]]
+            for (uint j = 0; j < TN; j++) {
+                b_reg[j] = buf_b[lidx + j * WG_X][k];
+            }
+
+            [[unroll]]
+            for (uint i = 0; i < TM; i++) {
+                vec4 a_f32 = A_VEC4_CAST(a_reg[i]);
+                // Accumulate the TN partial dot products for this thread's row of the tile.
+                [[unroll]]
+                for (uint j = 0; j < TN; j++) {
+                    sums[i][j] += a_f32.x * b_reg[j].x + a_f32.y * b_reg[j].y + a_f32.z * b_reg[j].z + a_f32.w * b_reg[j].w;
+                }
+            }
+        }
+        barrier();
+    }
+
+    [[unroll]]
+    for (uint i = 0; i < TM; i++) {
+        uint global_m = group_m + lidy + i * WG_Y;
+        if (global_m < p.M) {
+            [[unroll]]
+            for (uint j = 0; j < TN; j++) {
+                uint global_n = group_n + lidx + j * WG_X;
+                if (global_n < p.N) {
+                    data_d[global_n * p.stride_d + global_m] = sums[i][j];
+                }
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 613498d0d50b7..997a904f499de 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -454,6 +454,9 @@ void process_shaders() {
         }
     }
 
+    string_to_spv("mul_mat_f16_f32_embed", "mul_mm_embed.comp", {{"A_TYPE_FP16", "1"}});
+    string_to_spv("mul_mat_f32_f32_embed", "mul_mm_embed.comp", {});
+
     // flash attention
     for (const auto& f16acc : {false, true}) {
         std::map<std::string, std::string> fa_base_dict = base_dict;
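
Side note, not part of the patch: the {BM, BN, BK, TM, TN} specialization constants chosen in ggml_vk_load_shaders fully determine the shader's derived sizes (WG_X = BN/TN, WG_Y = BM/TM, WG_SIZE = WG_X*WG_Y, VEC_K = BK/4) and the number of bounds-guarded shared-memory loads each thread performs per K tile. The standalone C++ sketch below only re-derives that arithmetic for the two vendor configurations so the tile choices can be sanity-checked in isolation; the helper name check_embed_tile_config is illustrative and does not exist in ggml.

// Standalone sketch (illustrative only): re-derive the per-thread tile geometry
// that mul_mm_embed.comp computes from its specialization constants.
#include <cassert>
#include <cstdint>
#include <cstdio>

static void check_embed_tile_config(uint32_t bm, uint32_t bn, uint32_t bk) {
    const uint32_t wg_x = 16, wg_y = 8;           // fixed local_size_x/y of the shader
    const uint32_t wg_size = wg_x * wg_y;         // 128 invocations per workgroup
    assert(bk % 4 == 0);                          // K is loaded as vec4s
    assert(bm % wg_y == 0 && bn % wg_x == 0);     // tile must divide evenly across threads
    const uint32_t tm = bm / wg_y;                // output rows per thread (TM)
    const uint32_t tn = bn / wg_x;                // output cols per thread (TN)
    const uint32_t vec_k = bk / 4;                // VEC_K
    // vec4 shared-memory loads per thread and per K tile (rounded up, loads are guarded).
    const uint32_t a_loads = (bm * vec_k + wg_size - 1) / wg_size;
    const uint32_t b_loads = (bn * vec_k + wg_size - 1) / wg_size;
    printf("BM=%u BN=%u BK=%u -> TM=%u TN=%u, A loads/thread=%u, B loads/thread=%u\n",
           bm, bn, bk, tm, tn, a_loads, b_loads);
}

int main() {
    check_embed_tile_config(32, 128, 8);  // Qualcomm tile
    check_embed_tile_config(64, 64, 16);  // Arm tile
    return 0;
}

With the tiles chosen above this prints TM=4/TN=8 with 1 A load and 2 B loads per thread for the Qualcomm configuration, and TM=8/TN=4 with 2 A loads and 2 B loads for the Arm one.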