@@ -1572,10 +1572,13 @@ void llm_graph_context::build_pooling(
15721572 ggml_tensor * inp_cls = build_inp_cls ();
15731573 inp = ggml_get_rows (ctx0, inp, inp_cls);
15741574
1575- if (cls != nullptr && cls_b != nullptr ) {
1575+ if (cls != nullptr ) {
15761576 // classification head
15771577 // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1578- cur = ggml_add (ctx0, ggml_mul_mat (ctx0, cls, inp), cls_b);
1578+ cur = ggml_mul_mat (ctx0, cls, inp);
1579+ if (cls_b != nullptr ) {
1580+ cur = ggml_add (ctx0, cur, cls_b);
1581+ }
15791582 cur = ggml_tanh (ctx0, cur);
15801583
15811584 if (cls_norm) {
@@ -1586,16 +1589,22 @@ void llm_graph_context::build_pooling(
15861589 // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
15871590 // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
15881591 if (cls_out) {
1589- GGML_ASSERT (cls_out_b != nullptr );
1590- cur = ggml_add (ctx0, ggml_mul_mat (ctx0, cls_out, cur), cls_out_b);
1592+ cur = ggml_mul_mat (ctx0, cls_out, cur);
1593+ if (cls_out_b != nullptr ) {
1594+ cur = ggml_add (ctx0, cur, cls_out_b);
1595+ }
15911596 }
15921597 } else if (cls_out) {
15931598 // Single layer classification head (direct projection)
15941599 // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1595- GGML_ASSERT (cls_out_b != nullptr );
1596- cur = ggml_add (ctx0, ggml_mul_mat (ctx0, cls_out, inp), cls_out_b);
1600+ cur = ggml_mul_mat (ctx0, cls_out, inp);
1601+ if (cls_out_b != nullptr ) {
1602+ cur = ggml_add (ctx0, cur, cls_out_b);
1603+ }
15971604 } else {
1598- GGML_ABORT (" RANK pooling requires either cls+cls_b or cls_out+cls_out_b" );
1605+ // Some models have neither a `cls` nor a `cls_out` classification head.
1606+ // In that case, use the CLS/pooled embedding directly.
1607+ cur = inp;
15991608 }
16001609 } break ;
16011610 default :
0 commit comments