vulkan: add mul_mat variant for embedded gpus #15800
@@ -77,6 +77,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define VK_VENDOR_ID_APPLE 0x106b
#define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de
#define VK_VENDOR_ID_ARM 0x13B5
#define VK_VENDOR_ID_QUALCOMM 0x5143

#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
@@ -448,6 +451,8 @@ struct vk_device_struct {
    vk_matmul_pipeline pipeline_matmul_bf16 {};
    vk_matmul_pipeline2 pipeline_matmul_f16;
    vk_matmul_pipeline2 pipeline_matmul_f16_f32;
    vk_pipeline pipeline_matmul_f16_f32_embed;
    vk_pipeline pipeline_matmul_f32_f32_embed;

    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
@@ -2901,6 +2906,40 @@ static void ggml_vk_load_shaders(vk_device& device) {
            CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
        }
    }

    if (device->vendor_id == VK_VENDOR_ID_ARM || device->vendor_id == VK_VENDOR_ID_QUALCOMM) {
        // Shader workgroup size is 16x8 = 128
        const uint32_t wg_x = 16;
        const uint32_t wg_y = 8;

        // Tile sizes for the workgroup
        uint32_t bm, bn, bk;

        if (device->vendor_id == VK_VENDOR_ID_QUALCOMM) {
            bm = 32;
            bn = 128;
            bk = 8;
        } else {
            bm = 64;
            bn = 64;
            bk = 16;
        }

        // Per-thread tile size (output rows/columns handled by each invocation)
        const uint32_t tm = bm / wg_y;
        const uint32_t tn = bn / wg_x;

        const std::vector<uint32_t> embed_spec_constants = {bm, bn, bk, tm, tn};
        const std::array<uint32_t, 3> embed_wg_denoms = {bm, bn, 1};

        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32_embed, "mul_mat_f16_f32_embed",
            mul_mat_f16_f32_embed_len, mul_mat_f16_f32_embed_data, "main", 3,
            sizeof(vk_mat_mat_push_constants), embed_wg_denoms, embed_spec_constants, 1);
        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f32_embed, "mul_mat_f32_f32_embed",
            mul_mat_f32_f32_embed_len, mul_mat_f32_f32_embed_data, "main", 3,
            sizeof(vk_mat_mat_push_constants), embed_wg_denoms, embed_spec_constants, 1);
    }
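For reference, the per-thread tile shapes implied by these constants work out as follows; this is just the arithmetic spelled out, not additional code from the PR:

```cpp
// Arithmetic implied by the constants above (illustration only, not part of the diff):
// each of the 16x8 = 128 invocations accumulates a (bm / wg_y) x (bn / wg_x) block of D.
static_assert(32 / 8 == 4 && 128 / 16 == 8, "Qualcomm tile: 4 x 8 outputs per thread");
static_assert(64 / 8 == 8 &&  64 / 16 == 4, "ARM tile: 8 x 4 outputs per thread");
```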

    // reusing CREATE_MM from the fp32 path
    if ((device->coopmat2 || device->coopmat_support)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@@ -5726,6 +5765,114 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
    const uint64_t ne12 = src1->ne[2];
    const uint64_t ne13 = src1->ne[3];

    if ((ctx->device->vendor_id == VK_VENDOR_ID_ARM || ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) &&
        (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) &&
        src1->type == GGML_TYPE_F32 && ggml_vk_dim01_contiguous(src1) &&
        ne02 == 1 && ne03 == 1 &&
        ne12 == 1 && ne13 == 1) {
        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
        ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
        ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
        vk_buffer d_Qx = nullptr, d_Qy = nullptr, d_D = nullptr;
        size_t qx_buf_offset = 0, qy_buf_offset = 0, d_buf_offset = 0;
        bool src0_uma = false, src1_uma = false, dst_uma = false;

        if (ctx->device->uma) {
            ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
            ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
            ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
            src0_uma = d_Qx != nullptr;
            src1_uma = d_Qy != nullptr;
            dst_uma = d_D != nullptr;
        }

        if (!src0_uma) { d_Qx = src0_buf_ctx->dev_buffer; qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; }
        if (!src1_uma) { d_Qy = src1_buf_ctx->dev_buffer; qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; }
        if (!dst_uma)  { d_D = dst_buf_ctx->dev_buffer; d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; }

        const uint32_t M = ne01;
        const uint32_t N = ne11;
        const uint32_t K = ne10;

        vk_pipeline pipeline = nullptr;
        vk_buffer d_X;
        uint64_t x_buf_offset;
        uint32_t stride_a;
        bool dequantized = false;

        if (ggml_is_quantized(src0->type)) {
            vk_pipeline dequant_pipeline = ggml_vk_get_to_fp16(ctx, src0->type);

            if (dequant_pipeline) {
                dequantized = true;
                const uint64_t x_sz = sizeof(ggml_fp16_t) * M * K;

                if (dryrun) {
                    if (ctx->prealloc_size_x < x_sz) {
                        ctx->prealloc_size_x = x_sz;
                    }
                    ggml_pipeline_request_descriptor_sets(ctx, dequant_pipeline, 1);
                    ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_f16_f32_embed, 1);
                    return;
                }

                const std::vector<uint32_t> pc = { (uint32_t)M, (uint32_t)K, (uint32_t)K, (uint32_t)K, (uint32_t)(ggml_nelements(src0)) };
Review comment: I think this path is not handling noncontiguous src0. Like @0cc4m said, it'll be better to let this run through the existing code paths rather than having this separate code path.
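A minimal way to address that concern, sketched here rather than taken from the PR, would be to also require contiguous src0 in the entry condition, mirroring the existing src1 check, so non-contiguous inputs keep going through the regular paths:

```cpp
// Hypothetical tightening of the dispatch condition (uses the same helper the
// diff already applies to src1); non-contiguous src0 would then fall back to
// the existing mul_mat code paths instead of this one.
if ((ctx->device->vendor_id == VK_VENDOR_ID_ARM || ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) &&
    (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) &&
    ggml_vk_dim01_contiguous(src0) &&
    src1->type == GGML_TYPE_F32 && ggml_vk_dim01_contiguous(src1) &&
    ne02 == 1 && ne03 == 1 && ne12 == 1 && ne13 == 1) {
    // ... embedded-GPU path as in the diff ...
}
```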
                ggml_vk_dispatch_pipeline(ctx, subctx, dequant_pipeline, {
                    vk_subbuffer{ d_Qx, qx_buf_offset, VK_WHOLE_SIZE },
                    vk_subbuffer{ ctx->prealloc_x, 0, VK_WHOLE_SIZE }
                }, pc, { (uint32_t)(ggml_nelements(src0)), 1, 1});

                d_X = ctx->prealloc_x;
                x_buf_offset = 0;
                stride_a = K;
                pipeline = ctx->device->pipeline_matmul_f16_f32_embed;
            }
        } else {
            if (src0->type == GGML_TYPE_F16) {
                pipeline = ctx->device->pipeline_matmul_f16_f32_embed;
            } else {
                pipeline = ctx->device->pipeline_matmul_f32_f32_embed;
            }

            if (dryrun) {
                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
                return;
            }

            d_X = d_Qx;
            x_buf_offset = qx_buf_offset;
            stride_a = src0->nb[1] / ggml_type_size(src0->type);
        }

        if (pipeline != nullptr) {
            if (dequantized) {
                ggml_vk_sync_buffers(ctx, subctx); // Ensure dequant is finished
            }

            const uint32_t stride_b = src1->nb[1] / ggml_type_size(src1->type);
            const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);

            const vk_mat_mat_push_constants pc = { M, N, K, stride_a, stride_b, stride_d, M * K, K * N, M * N, K, 1, 1, 1, 1, N };

            vk_buffer d_Y = d_Qy;
            const uint64_t y_buf_offset = qy_buf_offset;

            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
                vk_subbuffer{ d_X, x_buf_offset, VK_WHOLE_SIZE },
                vk_subbuffer{ d_Y, y_buf_offset, VK_WHOLE_SIZE },
                vk_subbuffer{ d_D, d_buf_offset, VK_WHOLE_SIZE },
            }, pc, { M, N, 1 });
            if (dequantized) {
                ctx->prealloc_x_need_sync = true;
            }

            return;
        }
    }
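Note that the new shader only declares the first six of these push-constant fields; the rest of vk_mat_mat_push_constants is passed along but, as far as this diff shows, unused by the embed kernel. The consumed prefix is effectively:

```cpp
// Illustrative struct (not defined anywhere in the codebase): the prefix of
// vk_mat_mat_push_constants that the embed shader's push_constant block reads,
// with names taken from the shader below. The remaining members are uploaded
// but ignored by this kernel.
struct embed_push_constants_view {
    uint32_t M, N, K;
    uint32_t stride_a, stride_b, stride_d;
};
```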

    const uint64_t ne20 = dst->ne[0];
    const uint64_t ne21 = dst->ne[1];
New file: the GLSL compute shader for the embedded-GPU mul_mat path.

@@ -0,0 +1,167 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_control_flow_attributes : require

#ifdef A_TYPE_FP16
#define A_VEC4_TYPE f16vec4
#define A_SCALAR_TYPE float16_t
#define A_VEC4_ZERO f16vec4(0.0hf)
#define A_VEC4_CAST(v) vec4(v)
#else
#define A_VEC4_TYPE vec4
#define A_SCALAR_TYPE float
#define A_VEC4_ZERO vec4(0.0f)
#define A_VEC4_CAST(v) (v)
#endif

layout(constant_id = 0) const uint BM = 64;
layout(constant_id = 1) const uint BN = 64;
layout(constant_id = 2) const uint BK = 16;
layout(constant_id = 3) const uint TM = 4;
layout(constant_id = 4) const uint TN = 8;

const uint WG_X = BN / TN;
const uint WG_Y = BM / TM;
const uint WG_SIZE = WG_X * WG_Y;
const uint VEC_K = BK / 4;

layout(local_size_x = 16, local_size_y = 8, local_size_z = 1) in;

layout (binding = 0) readonly buffer A_BUFFER { A_SCALAR_TYPE data_a[]; };
layout (binding = 1) readonly buffer B_BUFFER { float data_b[]; };
layout (binding = 2) writeonly buffer D_BUFFER { float data_d[]; };

layout (push_constant) uniform parameter
{
    uint M;
    uint N;
    uint K;
    uint stride_a;
    uint stride_b;
    uint stride_d;
} p;

shared A_VEC4_TYPE buf_a[BM][VEC_K];
shared vec4 buf_b[BN][VEC_K];

void main() {
    const uint lidx = gl_LocalInvocationID.x;
    const uint lidy = gl_LocalInvocationID.y;
    const uint lid = lidy * WG_X + lidx;

    const uint group_m = gl_WorkGroupID.x * BM;
    const uint group_n = gl_WorkGroupID.y * BN;

    float sums[TM][TN];
    [[unroll]]
    for (uint i = 0; i < TM; i++) {
        [[unroll]]
        for (uint j = 0; j < TN; j++) {
            sums[i][j] = 0.0f;
        }
    }

    const uint num_k_tiles = (p.K + BK - 1) / BK;
Review comments on this hunk:

- I think this is not robust enough and might be wrong for the adreno case, but it passes the tests on
- Shouldn't be hard to add a case or two with odd K. I suggest having relatively small M,N to avoid the error being hidden.
- I misquoted, I was thinking more about the adreno case: so theoretically it shouldn't be able to load matrix A regardless of the dimensions, but the tests are passing, so I'm a bit confused.
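A couple of odd-K cases along those lines could go into the backend test suite. A rough sketch, assuming the test_mul_mat helper in tests/test-backend-ops.cpp still takes (type_a, type_b, m, n, k, batch, repeat) as its leading arguments; the exact signature and defaults may differ:

```cpp
// Hypothetical test additions (argument order assumed, not verified against the
// current helper): K values that are not multiples of BK or of 4, with small
// M and N so a wrong tail contribution is not hidden by the zero-padded tile.
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 33, 5, 67, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 3, 130, {1, 1}, {1, 1}));
```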
    const uint A_LOADS_PER_THREAD = (BM * VEC_K) / WG_SIZE;
    const uint B_LOADS_PER_THREAD = (BN * VEC_K) / WG_SIZE;

    for (uint t = 0; t < num_k_tiles; t++) {
        const uint k_tile_start = t * BK;

        [[unroll]]
        for (uint i = 0; i < A_LOADS_PER_THREAD; ++i) {
            uint load_idx = lid + i * WG_SIZE;
            uint m = load_idx / VEC_K;
            uint k = load_idx % VEC_K;
            uint global_m = group_m + m;
            uint k_scalar = k_tile_start + k * 4;

            if (global_m < p.M && k_scalar < p.K) {
                uint base_idx = global_m * p.stride_a + k_scalar;
                if (k_scalar + 3 < p.K) {
                    buf_a[m][k] = A_VEC4_TYPE(data_a[base_idx], data_a[base_idx+1], data_a[base_idx+2], data_a[base_idx+3]);
                } else {
                    A_SCALAR_TYPE temp[4] = {A_SCALAR_TYPE(0), A_SCALAR_TYPE(0), A_SCALAR_TYPE(0), A_SCALAR_TYPE(0)};
                    if (k_scalar < p.K) temp[0] = data_a[base_idx];
                    if (k_scalar + 1 < p.K) temp[1] = data_a[base_idx+1];
                    if (k_scalar + 2 < p.K) temp[2] = data_a[base_idx+2];
                    buf_a[m][k] = A_VEC4_TYPE(temp[0], temp[1], temp[2], temp[3]);
                }
            } else {
                buf_a[m][k] = A_VEC4_ZERO;
            }
        }

        [[unroll]]
        for (uint i = 0; i < B_LOADS_PER_THREAD; ++i) {
            uint load_idx = lid + i * WG_SIZE;
            uint n = load_idx / VEC_K;
            uint k = load_idx % VEC_K;
            uint global_n = group_n + n;
            uint k_scalar = k_tile_start + k * 4;

            if (global_n < p.N && k_scalar < p.K) {
                uint base_idx = global_n * p.stride_b + k_scalar;
                if (k_scalar + 3 < p.K) {
                    buf_b[n][k] = vec4(data_b[base_idx], data_b[base_idx+1], data_b[base_idx+2], data_b[base_idx+3]);
                } else {
                    float temp[4] = {0.0f, 0.0f, 0.0f, 0.0f};
                    if (k_scalar < p.K) temp[0] = data_b[base_idx];
                    if (k_scalar + 1 < p.K) temp[1] = data_b[base_idx+1];
                    if (k_scalar + 2 < p.K) temp[2] = data_b[base_idx+2];
                    buf_b[n][k] = vec4(temp[0], temp[1], temp[2], temp[3]);
                }
            } else {
                buf_b[n][k] = vec4(0.0f);
            }
        }

        barrier();

        [[unroll]]
        for (uint k = 0; k < VEC_K; k++) {
            A_VEC4_TYPE a_reg[TM];
            [[unroll]]
            for (uint i = 0; i < TM; i++) {
                a_reg[i] = buf_a[lidy + i * WG_Y][k];
            }

            vec4 b_reg[TN];
            [[unroll]]
            for (uint j = 0; j < TN; j++) {
                b_reg[j] = buf_b[lidx + j * WG_X][k];
            }

            [[unroll]]
            for (uint i = 0; i < TM; i++) {
                vec4 a_f32 = A_VEC4_CAST(a_reg[i]);

                sums[i][0] += a_f32.x * b_reg[0].x + a_f32.y * b_reg[0].y + a_f32.z * b_reg[0].z + a_f32.w * b_reg[0].w;
                sums[i][1] += a_f32.x * b_reg[1].x + a_f32.y * b_reg[1].y + a_f32.z * b_reg[1].z + a_f32.w * b_reg[1].w;
                sums[i][2] += a_f32.x * b_reg[2].x + a_f32.y * b_reg[2].y + a_f32.z * b_reg[2].z + a_f32.w * b_reg[2].w;
                sums[i][3] += a_f32.x * b_reg[3].x + a_f32.y * b_reg[3].y + a_f32.z * b_reg[3].z + a_f32.w * b_reg[3].w;
                sums[i][4] += a_f32.x * b_reg[4].x + a_f32.y * b_reg[4].y + a_f32.z * b_reg[4].z + a_f32.w * b_reg[4].w;
                sums[i][5] += a_f32.x * b_reg[5].x + a_f32.y * b_reg[5].y + a_f32.z * b_reg[5].z + a_f32.w * b_reg[5].w;
                sums[i][6] += a_f32.x * b_reg[6].x + a_f32.y * b_reg[6].y + a_f32.z * b_reg[6].z + a_f32.w * b_reg[6].w;
                sums[i][7] += a_f32.x * b_reg[7].x + a_f32.y * b_reg[7].y + a_f32.z * b_reg[7].z + a_f32.w * b_reg[7].w;
            }
        }
        barrier();
    }

    [[unroll]]
    for (uint i = 0; i < TM; i++) {
        uint global_m = group_m + lidy + i * WG_Y;
        if (global_m < p.M) {
            [[unroll]]
            for (uint j = 0; j < TN; j++) {
                uint global_n = group_n + lidx + j * WG_X;
                if (global_n < p.N) {
                    data_d[global_n * p.stride_d + global_m] = sums[i][j];
                }
            }
        }
    }
}
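For the tile sizes chosen on the host side, the shared-memory footprint of buf_a and buf_b comes to a few KiB per workgroup. A small sketch of the arithmetic, assuming f16vec4 occupies 8 bytes and vec4 occupies 16 bytes of shared memory:

```cpp
// Shared-memory use of the tile buffers per workgroup, derived from the
// declarations above (illustration only, not code from the PR).
#include <cstdint>

constexpr uint32_t smem_bytes(uint32_t bm, uint32_t bn, uint32_t bk, uint32_t a_elem_bytes) {
    const uint32_t vec_k = bk / 4;
    return bm * vec_k * 4 * a_elem_bytes                // buf_a: BM x VEC_K vectors of 4 A elements
         + bn * vec_k * 4 * (uint32_t)sizeof(float);    // buf_b: BN x VEC_K vec4
}

static_assert(smem_bytes(64,  64, 16, 2) == 6144, "ARM tile, f16 A: 6 KiB");
static_assert(smem_bytes(64,  64, 16, 4) == 8192, "ARM tile, f32 A: 8 KiB");
static_assert(smem_bytes(32, 128,  8, 2) == 4608, "Qualcomm tile, f16 A: 4.5 KiB");
static_assert(smem_bytes(32, 128,  8, 4) == 5120, "Qualcomm tile, f32 A: 5 KiB");
```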
Review comment: it also seems to have an impact with Intel integrated GPUs in some cases: