@@ -531,7 +531,9 @@ def generate(
531531 if tokens_or_none is not None :
532532 tokens .extend (tokens_or_none )
533533
534- def create_embedding (self , input : str , model : Optional [str ] = None ) -> Embedding :
534+ def create_embedding (
535+ self , input : Union [str , List [str ]], model : Optional [str ] = None
536+ ) -> Embedding :
535537 """Embed a string.
536538
537539 Args:
@@ -551,30 +553,40 @@ def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding
551553 if self .verbose :
552554 llama_cpp .llama_reset_timings (self .ctx )
553555
554- tokens = self .tokenize (input .encode ("utf-8" ))
555- self .reset ()
556- self .eval (tokens )
557- n_tokens = len (tokens )
558- embedding = llama_cpp .llama_get_embeddings (self .ctx )[
559- : llama_cpp .llama_n_embd (self .ctx )
560- ]
556+ if isinstance (input , str ):
557+ inputs = [input ]
558+ else :
559+ inputs = input
561560
562- if self .verbose :
563- llama_cpp .llama_print_timings (self .ctx )
561+ data = []
562+ total_tokens = 0
563+ for input in inputs :
564+ tokens = self .tokenize (input .encode ("utf-8" ))
565+ self .reset ()
566+ self .eval (tokens )
567+ n_tokens = len (tokens )
568+ total_tokens += n_tokens
569+ embedding = llama_cpp .llama_get_embeddings (self .ctx )[
570+ : llama_cpp .llama_n_embd (self .ctx )
571+ ]
564572
565- return {
566- "object" : "list" ,
567- " data" : [
573+ if self . verbose :
574+ llama_cpp . llama_print_timings ( self . ctx )
575+ data . append (
568576 {
569577 "object" : "embedding" ,
570578 "embedding" : embedding ,
571579 "index" : 0 ,
572580 }
573- ],
574- "model" : model_name ,
581+ )
582+
583+ return {
584+ "object" : "list" ,
585+ "data" : data ,
586+ "model" : self .model_path ,
575587 "usage" : {
576- "prompt_tokens" : n_tokens ,
577- "total_tokens" : n_tokens ,
588+ "prompt_tokens" : total_tokens ,
589+ "total_tokens" : total_tokens ,
578590 },
579591 }
580592
0 commit comments