diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 7e9c3c8c7a0..d5f54915a33 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1301,7 +1301,7 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // change the precision of a matrix multiplication
-    // set to GGML_PREC_F32 for higher precision (useful for phi-2)
+    // set to GGML_PREC_F32 for higher precision
     GGML_API void ggml_mul_mat_set_prec(
             struct ggml_tensor * a,
             enum   ggml_prec     prec);
@@ -1313,6 +1313,12 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * ids);
 
+    // change the precision of an indirect matrix multiplication
+    // set to GGML_PREC_F32 for higher precision
+    GGML_API void ggml_mul_mat_id_set_prec(
+            struct ggml_tensor * a,
+            enum   ggml_prec     prec);
+
     // A: m columns, n rows,
     // B: p columns, n rows,
     // result is m columns, p rows
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 04ad664e61c..f1969db1473 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -4657,8 +4657,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         return nullptr;
     }
 
-    // XXX TODO 'prec' is not actually allowed in mul_mat_id. 
-    bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
+    bool prefer_fp16acc = ctx->device->fp16 && prec == GGML_PREC_DEFAULT;
 
     bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
     bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index d76ea58f789..ff0588d1b24 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3081,6 +3081,16 @@ struct ggml_tensor * ggml_mul_mat_id(
     return result;
 }
 
+void ggml_mul_mat_id_set_prec(
+        struct ggml_tensor * a,
+        enum   ggml_prec     prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT_ID);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 0, prec_i32);
+}
+
 // ggml_out_prod
 
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b928e9e16ea..d554d2f0d86 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -542,8 +542,10 @@ ggml_tensor * llm_graph_context::build_cvec(
 
 ggml_tensor * llm_graph_context::build_lora_mm(
           ggml_tensor * w,
-          ggml_tensor * cur) const {
+          ggml_tensor * cur,
+            ggml_prec   prec) const {
     ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
+    ggml_mul_mat_set_prec(res, prec);
 
     for (const auto & lora : *loras) {
         llama_adapter_lora_weight * lw = lora.first->get_weight(w);
@@ -569,8 +571,11 @@ ggml_tensor * llm_graph_context::build_lora_mm(
 ggml_tensor * llm_graph_context::build_lora_mm_id(
           ggml_tensor * w,   // ggml_tensor * as
           ggml_tensor * cur, // ggml_tensor * b
-          ggml_tensor * ids) const {
+          ggml_tensor * ids,
+            ggml_prec   prec) const {
     ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    ggml_mul_mat_id_set_prec(res, prec);
+
     for (const auto & lora : *loras) {
         llama_adapter_lora_weight * lw = lora.first->get_weight(w);
         if (lw == nullptr) {
@@ -750,11 +755,7 @@ ggml_tensor * llm_graph_context::build_ffn(
     }
 
     if (down) {
-        cur = build_lora_mm(down, cur);
-        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
-            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
+        cur = build_lora_mm(down, cur, GGML_PREC_F32);
     }
 
     if (down_b) {
@@ -978,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             GGML_ABORT("fatal error");
     }
 
-    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts, GGML_PREC_F32); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     if (down_exps_b) {
@@ -1475,11 +1476,7 @@ ggml_tensor * llm_graph_context::build_attn(
         cb(cur, "kqv_out", il);
 
         if (wo) {
-            cur = build_lora_mm(wo, cur);
-            if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
-                // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
-                ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-            }
+            cur = build_lora_mm(wo, cur, GGML_PREC_F32);
         }
 
         if (wo_b) {
@@ -1542,7 +1539,7 @@ ggml_tensor * llm_graph_context::build_attn(
         cb(cur, "kqv_out", il);
 
         if (wo) {
-            cur = build_lora_mm(wo, cur);
+            cur = build_lora_mm(wo, cur, GGML_PREC_F32);
         }
 
         if (wo_b) {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index e11d91d5293..6d4dec5cc18 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -589,13 +589,15 @@ struct llm_graph_context {
     // do mat_mul, while optionally apply lora
     ggml_tensor * build_lora_mm(
               ggml_tensor * w,
-              ggml_tensor * cur) const;
+              ggml_tensor * cur,
+                ggml_prec   prec = GGML_PREC_DEFAULT) const;
 
     // do mat_mul_id, while optionally apply lora
     ggml_tensor * build_lora_mm_id(
               ggml_tensor * w,   // ggml_tensor * as
               ggml_tensor * cur, // ggml_tensor * b
-              ggml_tensor * ids) const;
+              ggml_tensor * ids,
+                ggml_prec   prec = GGML_PREC_DEFAULT) const;
 
     ggml_tensor * build_norm(
               ggml_tensor * cur,