@@ -145,8 +145,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
145145struct vk_matmul_pipeline_struct {
146146 vk_pipeline l, m, s;
147147 vk_pipeline a_l, a_m, a_s;
148+ // Returns true when all three unaligned pipelines (l, m and s) are null.
149+ // Only the unaligned variants are checked, because at least one of them must exist,
150+ // while the aligned pipelines (a_l, a_m and a_s) are optional and are ignored here.
151+ bool is_empty() const {
152+ return l == nullptr && m == nullptr && s == nullptr;
153+ }
148154};
149-
150155typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
151156
152157struct vk_matmul_pipeline2 {
@@ -5079,7 +5084,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
50795084 if (src1_type == GGML_TYPE_Q8_1) {
50805085 vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
50815086
5082- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr ) {
5087+ if (pipelines->is_empty() ) {
50835088 return nullptr;
50845089 }
50855090
@@ -5228,7 +5233,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
52285233 if (src1_type == GGML_TYPE_Q8_1) {
52295234 vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
52305235
5231- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr ) {
5236+ if (pipelines->is_empty() ) {
52325237 return nullptr;
52335238 }
52345239
@@ -5263,16 +5268,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
52635268 return nullptr;
52645269 }
52655270
5271+ vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
52665272 // XXX TODO 'prec' is not actually allowed in mul_mat_id.
52675273 bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
5268- bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type] .f16acc != nullptr ;
5269- bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type] .f32acc != nullptr ;
5274+ bool support_fp16acc = !mmp .f16acc->is_empty() ;
5275+ bool support_fp32acc = !mmp .f32acc->is_empty() ;
52705276
52715277 if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
5272- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type] .f16acc;
5278+ return mmp .f16acc;
52735279 } else {
52745280 GGML_ASSERT(support_fp32acc);
5275- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type] .f32acc;
5281+ return mmp .f32acc;
52765282 }
52775283}
52785284
0 commit comments