@@ -470,6 +470,7 @@ class llama_model_params(Structure):
470470#     bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) 
471471#     bool embedding;   // embedding mode only 
472472#     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU 
473+ #     bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) 
473474# }; 
474475class  llama_context_params (Structure ):
475476    """Parameters for llama_context 
@@ -496,6 +497,7 @@ class llama_context_params(Structure):
496497        logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) 
497498        embedding (bool): embedding mode only 
498499        offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU 
500+         do_pooling (bool): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) 
499501    """ 
500502
501503    _fields_  =  [
@@ -520,6 +522,7 @@ class llama_context_params(Structure):
520522        ("logits_all" , c_bool ),
521523        ("embedding" , c_bool ),
522524        ("offload_kqv" , c_bool ),
525+         ("do_pooling" , c_bool ),
523526    ]
524527
525528
@@ -1699,6 +1702,21 @@ def llama_get_embeddings(
16991702_lib .llama_get_embeddings .restype  =  c_float_p 
17001703
17011704
# // Get the embeddings for the ith sequence
# // llama_get_embeddings(ctx) + i*n_embd
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
def llama_get_embeddings_ith(
    ctx: llama_context_p, i: Union[c_int32, int]
):  # type: (...) -> Array[float] # type: ignore
    """Get the embeddings for the ith sequence.

    Per the C header comment above, this is equivalent to
    ``llama_get_embeddings(ctx) + i * n_embd``: a pointer into the
    context's embeddings buffer, offset to sequence ``i``.

    Args:
        ctx: The llama context (opaque ``llama_context_p`` handle).
        i: Index of the sequence whose embeddings to fetch.

    Returns:
        A ``c_float_p`` (float pointer) into the context-owned buffer;
        the data is not copied. NOTE(review): presumably valid for
        ``n_embd`` floats and only until the next decode call — confirm
        against llama.h.
    """
    return _lib.llama_get_embeddings_ith(ctx, i)


# Register the C function's signature with ctypes so calls marshal
# correctly: (llama_context*, int32_t) -> float*.
_lib.llama_get_embeddings_ith.argtypes = [llama_context_p, c_int32]
_lib.llama_get_embeddings_ith.restype = c_float_p
1719+ 
17021720# // 
17031721# // Vocab 
17041722# // 
0 commit comments