@@ -717,10 +717,53 @@ def create_embedding(
         Returns:
             An embedding object.
         """
-        assert self._ctx.ctx is not None
         assert self._model.model is not None
         model_name: str = model if model is not None else self.model_path
 
+        # get numeric embeddings
+        embeds: List[List[float]]
+        total_tokens: int
+        embeds, total_tokens = self.embed(input, return_count=True)  # type: ignore
+
+        # convert to CreateEmbeddingResponse
+        data: List[Embedding] = [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": idx,
+            }
+            for idx, emb in enumerate(embeds)
+        ]
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": model_name,
+            "usage": {
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+
+    def embed(
+        self,
+        input: Union[str, List[str]],
+        normalize: bool = True,
+        truncate: bool = True,
+        return_count: bool = False,
+    ):
+        """Embed a string or list of strings.
+
+        Args:
+            input: The utf-8 encoded string (or list of strings) to embed.
+
+        Returns:
+            A list of embeddings, or a single embedding if input is a string
+        """
+        assert self._ctx.ctx is not None
+        n_embd = self.n_embd()
+        n_ctx = self.n_ctx()
+
         if self.context_params.embedding == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
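For reference, the response assembled in the hunk above follows the OpenAI embeddings schema. A minimal sketch of the shapes involved, assuming the library's `Embedding`/`CreateEmbeddingResponse` TypedDicts look roughly like this (field names taken from the diff itself):

```python
from typing import List, TypedDict

class Embedding(TypedDict):
    object: str            # always "embedding"
    embedding: List[float]
    index: int

class EmbeddingUsage(TypedDict):
    prompt_tokens: int
    total_tokens: int

class CreateEmbeddingResponse(TypedDict):
    object: str            # always "list"
    data: List[Embedding]
    model: str
    usage: EmbeddingUsage
```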
@@ -734,48 +777,72 @@ def create_embedding(
         else:
             inputs = input
 
-        data: List[Embedding] = []
+        # reset batch
+        self._batch.reset()
+
+        # decode and fetch embeddings
+        data: List[List[float]] = []
+        def decode_batch(sizes: List[int]):
+            assert self._ctx.ctx is not None
+            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            self._ctx.decode(self._batch)
+            self._batch.reset()
+
+            # store embeddings
+            for i, s in enumerate(sizes):
+                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+                    :n_embd
+                ]
+                norm = np.linalg.norm(embedding) if normalize else s
+                embedding: List[float] = [v / float(norm) for v in embedding]
+                data.append(embedding)
+
+        # init state
         total_tokens = 0
-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
+        t_batch = 0
+        s_sizes: List[int] = []
+
+        # accumulate batches and encode
+        for text in inputs:
+            tokens = self.tokenize(text.encode("utf-8"))
+            if truncate:
+                tokens = tokens[:n_ctx]
+
             n_tokens = len(tokens)
             total_tokens += n_tokens
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]
 
-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )
+            # check for overrun
+            if n_tokens > n_ctx:
+                raise ValueError(
+                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                )
+
+            # time to eval batch
+            if t_batch + n_tokens > self._n_ctx:
+                decode_batch(s_sizes)
+                t_batch = 0
+                s_sizes = []
+
+            # add to batch
+            self._batch.add_sequence(tokens, len(s_sizes), False)
+            t_batch += n_tokens
+            s_sizes.append(n_tokens)
+
+        # handle last batch
+        decode_batch(s_sizes)
+
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
 
-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
-
-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
+        output = data[0] if isinstance(input, str) else data
 
-        Args:
-            input: The utf-8 encoded string to embed.
+        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+        self.reset()
 
-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
+        if return_count:
+            return output, total_tokens
+        else:
+            return output
 
     def _create_completion(
         self,
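The batching loop above is a greedy packer: sequences accumulate in the current `llama_batch` until the next one would overflow the context window, at which point the batch is decoded and reset. A standalone sketch of the same strategy (the helper `pack_greedy` is illustrative, not part of the library):

```python
from typing import List

def pack_greedy(seq_lens: List[int], capacity: int) -> List[List[int]]:
    """Group sequence lengths into batches that each fit within capacity."""
    batches: List[List[int]] = []
    current: List[int] = []
    used = 0
    for n in seq_lens:
        if n > capacity:  # mirrors the "check for overrun" ValueError
            raise ValueError(f"Requested tokens ({n}) exceed context window of {capacity}")
        if used + n > capacity:  # "time to eval batch": flush before adding
            batches.append(current)
            current, used = [], 0
        current.append(n)
        used += n
    batches.append(current)  # handle last batch
    return batches

print(pack_greedy([3, 4, 5, 2], capacity=8))  # [[3, 4], [5, 2]]
```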
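End to end, the new interface can be exercised like this; a minimal usage sketch, assuming a local GGUF model at a placeholder path:

```python
from llama_cpp import Llama

# embedding=True is required, per the RuntimeError guard above
llm = Llama(model_path="./model.gguf", embedding=True)  # placeholder path

# single string -> a single embedding (List[float])
vec = llm.embed("hello world")

# list of strings -> one embedding per input, batched across decodes
vecs = llm.embed(["first text", "second text"])

# return_count=True also yields the total token count,
# which is how create_embedding() fills the usage fields
vecs, n_tokens = llm.embed(["first text", "second text"], return_count=True)

# OpenAI-compatible response built on top of embed()
resp = llm.create_embedding(["first text", "second text"])
print(resp["usage"]["total_tokens"])
```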