Skip to content
19 changes: 12 additions & 7 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,12 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
struct vk_matmul_pipeline_struct {
vk_pipeline l, m, s;
vk_pipeline a_l, a_m, a_s;
// Returns true when all member pipelines are null
bool is_empty() const {
return l == nullptr && m == nullptr && s == nullptr &&
a_l == nullptr && a_m == nullptr && a_s == nullptr;
}
};

typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;

struct vk_matmul_pipeline2 {
Expand Down Expand Up @@ -5079,7 +5083,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;

if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
if (pipelines->is_empty()) {
return nullptr;
}

Expand Down Expand Up @@ -5228,7 +5232,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;

if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
if (pipelines->is_empty()) {
return nullptr;
}

Expand Down Expand Up @@ -5263,16 +5267,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
return nullptr;
}

vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
bool support_fp16acc = !mmp.f16acc->is_empty();
bool support_fp32acc = !mmp.f32acc->is_empty();

if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
return mmp.f16acc;
} else {
GGML_ASSERT(support_fp32acc);
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
return mmp.f32acc;
}
}

Expand Down
Loading