8 changes: 7 additions & 1 deletion ggml/include/ggml.h
@@ -1301,7 +1301,7 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // change the precision of a matrix multiplication
-    // set to GGML_PREC_F32 for higher precision (useful for phi-2)
+    // set to GGML_PREC_F32 for higher precision
     GGML_API void ggml_mul_mat_set_prec(
             struct ggml_tensor * a,
             enum ggml_prec       prec);
@@ -1313,6 +1313,12 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * ids);
 
+    // change the precision of an indirect matrix multiplication
+    // set to GGML_PREC_F32 for higher precision
+    GGML_API void ggml_mul_mat_id_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
     // A: m columns, n rows,
     // B: p columns, n rows,
     // result is m columns, p rows
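The new setter is the ggml_mul_mat_id counterpart of ggml_mul_mat_set_prec. A minimal usage sketch, assuming a valid ggml_context * ctx and suitably shaped tensors; build_expert_matmul is a hypothetical helper, not part of this PR:

#include "ggml.h"

// Hypothetical helper: an expert-routed matmul with full-precision accumulators.
static struct ggml_tensor * build_expert_matmul(
        struct ggml_context * ctx,
        struct ggml_tensor  * as,    // expert weights
        struct ggml_tensor  * b,     // activations
        struct ggml_tensor  * ids) { // selected expert ids per token
    struct ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);

    // request F32 accumulation for this node, exactly as with ggml_mul_mat_set_prec
    ggml_mul_mat_id_set_prec(out, GGML_PREC_F32);

    return out;
}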
3 changes: 1 addition & 2 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -4657,8 +4657,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         return nullptr;
     }
 
-    // XXX TODO 'prec' is not actually allowed in mul_mat_id.
-    bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
+    bool prefer_fp16acc = ctx->device->fp16 && prec == GGML_PREC_DEFAULT;
     bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
     bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
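With the XXX TODO resolved, prec now participates in accumulator selection for mul_mat_id. A hedged condensation of the flags above into a free function (use_fp16_acc is a hypothetical name; the real code keeps these as locals inside ggml_vk_get_mul_mat_mat_id_pipeline):

// Sketch: fp16 accumulation is chosen only when the device supports fp16,
// the node left precision at GGML_PREC_DEFAULT, and an fp16-accumulator
// pipeline exists for src0's type; otherwise the f32acc pipeline is used.
static bool use_fp16_acc(bool device_fp16, enum ggml_prec prec, bool support_fp16acc) {
    const bool prefer_fp16acc = device_fp16 && prec == GGML_PREC_DEFAULT;
    return prefer_fp16acc && support_fp16acc;
}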
10 changes: 10 additions & 0 deletions ggml/src/ggml.c
@@ -3081,6 +3081,16 @@ struct ggml_tensor * ggml_mul_mat_id(
     return result;
 }
 
+void ggml_mul_mat_id_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT_ID);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 0, prec_i32);
+}
+
 // ggml_out_prod
 
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
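The setter stores the precision in op-param slot 0, matching how ggml_mul_mat records it. A sketch of the consumer side, assuming the internal ggml_get_op_params_i32 helper from ggml-impl.h (ggml_mul_mat_id_get_prec is a hypothetical name; backends would do something like this when compiling the node):

#include "ggml.h"
#include "ggml-impl.h" // assumed home of ggml_get_op_params_i32

// Read back the precision that ggml_mul_mat_id_set_prec stored in slot 0.
static enum ggml_prec ggml_mul_mat_id_get_prec(const struct ggml_tensor * dst) {
    GGML_ASSERT(dst->op == GGML_OP_MUL_MAT_ID);
    return (enum ggml_prec) ggml_get_op_params_i32(dst, 0);
}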
25 changes: 11 additions & 14 deletions src/llama-graph.cpp
@@ -542,8 +542,10 @@ ggml_tensor * llm_graph_context::build_cvec(
 
 ggml_tensor * llm_graph_context::build_lora_mm(
         ggml_tensor * w,
-        ggml_tensor * cur) const {
+        ggml_tensor * cur,
+        ggml_prec     prec) const {
     ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
+    ggml_mul_mat_set_prec(res, prec);
 
     for (const auto & lora : *loras) {
         llama_adapter_lora_weight * lw = lora.first->get_weight(w);
@@ -569,8 +571,11 @@ ggml_tensor * llm_graph_context::build_lora_mm(
 ggml_tensor * llm_graph_context::build_lora_mm_id(
         ggml_tensor * w,   // ggml_tensor * as
         ggml_tensor * cur, // ggml_tensor * b
-        ggml_tensor * ids) const {
+        ggml_tensor * ids,
+        ggml_prec     prec) const {
     ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    ggml_mul_mat_id_set_prec(res, prec);
+
     for (const auto & lora : *loras) {
         llama_adapter_lora_weight * lw = lora.first->get_weight(w);
         if (lw == nullptr) {
@@ -750,11 +755,7 @@ ggml_tensor * llm_graph_context::build_ffn(
     }
 
     if (down) {
-        cur = build_lora_mm(down, cur);
-        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
-            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
+        cur = build_lora_mm(down, cur, GGML_PREC_F32);
     }
 
     if (down_b) {
@@ -978,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         GGML_ABORT("fatal error");
     }
 
-    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts, GGML_PREC_F32); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     if (down_exps_b) {
@@ -1475,11 +1476,7 @@ ggml_tensor * llm_graph_context::build_attn(
     cb(cur, "kqv_out", il);
 
     if (wo) {
-        cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
-            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
+        cur = build_lora_mm(wo, cur, GGML_PREC_F32);
     }
 
     if (wo_b) {
@@ -1542,7 +1539,7 @@ ggml_tensor * llm_graph_context::build_attn(
     cb(cur, "kqv_out", il);
 
     if (wo) {
-        cur = build_lora_mm(wo, cur);
+        cur = build_lora_mm(wo, cur, GGML_PREC_F32);
     }
 
     if (wo_b) {
6 changes: 4 additions & 2 deletions src/llama-graph.h
@@ -589,13 +589,15 @@ struct llm_graph_context {
     // do mat_mul, while optionally apply lora
     ggml_tensor * build_lora_mm(
             ggml_tensor * w,
-            ggml_tensor * cur) const;
+            ggml_tensor * cur,
+            ggml_prec     prec = GGML_PREC_DEFAULT) const;
 
     // do mat_mul_id, while optionally apply lora
     ggml_tensor * build_lora_mm_id(
             ggml_tensor * w,   // ggml_tensor * as
             ggml_tensor * cur, // ggml_tensor * b
-            ggml_tensor * ids) const;
+            ggml_tensor * ids,
+            ggml_prec     prec = GGML_PREC_DEFAULT) const;
 
     ggml_tensor * build_norm(
             ggml_tensor * cur,
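Because prec defaults to GGML_PREC_DEFAULT, existing call sites compile unchanged while hot spots opt in explicitly. An illustrative pair of calls, assuming a llm_graph_context named g and tensors w and x in scope:

ggml_tensor * y0 = g.build_lora_mm(w, x);                // prec = GGML_PREC_DEFAULT
ggml_tensor * y1 = g.build_lora_mm(w, x, GGML_PREC_F32); // force F32 accumulators, as build_ffn/build_attn now do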