diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b61879aa5d312..c3e5a9eccc3aa 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -145,8 +145,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); struct vk_matmul_pipeline_struct { vk_pipeline l, m, s; vk_pipeline a_l, a_m, a_s; + // Returns true when all unaligned pipelines are null. + // We only check for unaligned variants since one of the unaligned pipelines must exist + // while aligned pipelines are optional + bool is_empty() const { + return l == nullptr && m == nullptr && s == nullptr; + } }; - typedef std::shared_ptr vk_matmul_pipeline; struct vk_matmul_pipeline2 { @@ -5079,7 +5084,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte if (src1_type == GGML_TYPE_Q8_1) { vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; - if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + if (pipelines->is_empty()) { return nullptr; } @@ -5228,7 +5233,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co if (src1_type == GGML_TYPE_Q8_1) { vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc; - if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + if (pipelines->is_empty()) { return nullptr; } @@ -5263,16 +5268,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co return nullptr; } + vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type]; // XXX TODO 'prec' is not actually allowed in mul_mat_id. bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/; - bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr; - bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr; + bool support_fp16acc = !mmp.f16acc->is_empty(); + bool support_fp32acc = !mmp.f32acc->is_empty(); if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) { - return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc; + return mmp.f16acc; } else { GGML_ASSERT(support_fp32acc); - return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; + return mmp.f32acc; } }