Skip to content

Commit a0554c3

Browse files
committed
context : always use non-causal attention for encoder graphs
ggml-ci
1 parent d9a1452 commit a0554c3

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

src/llama-context.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1627,7 +1627,16 @@ llm_graph_result_ptr llama_context::graph_build(
16271627
ggml_cgraph * gf,
16281628
const llama_ubatch & ubatch,
16291629
llm_graph_type gtype) {
1630-
return model.build_graph(
1630+
const auto causal_attn_org = cparams.causal_attn;
1631+
1632+
// always use non-causal attention for encoder graphs
1633+
// TODO: this is a tmp solution until we have a proper way to support enc-dec models
1634+
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
1635+
if (gtype == LLM_GRAPH_TYPE_ENCODER) {
1636+
cparams.causal_attn = false;
1637+
}
1638+
1639+
auto res = model.build_graph(
16311640
{
16321641
/*.ctx =*/ ctx,
16331642
/*.arch =*/ model.arch,
@@ -1643,6 +1652,12 @@ llm_graph_result_ptr llama_context::graph_build(
16431652
/*.n_outputs =*/ n_outputs,
16441653
/*.cb =*/ graph_get_cb(),
16451654
}, gf, gtype);
1655+
1656+
if (gtype == LLM_GRAPH_TYPE_ENCODER) {
1657+
cparams.causal_attn = causal_attn_org;
1658+
}
1659+
1660+
return res;
16461661
}
16471662

16481663
ggml_status llama_context::graph_compute(

0 commit comments

Comments
 (0)