Skip to content

Commit e783f1c

Browse files
committed
feat: make embedding support a list of strings as input
Makes the /v1/embedding route similar to the OpenAI API.
1 parent 01a010b commit e783f1c

File tree

2 files changed

+30
-18
lines changed

2 files changed

+30
-18
lines changed

llama_cpp/llama.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,9 @@ def generate(
531531
if tokens_or_none is not None:
532532
tokens.extend(tokens_or_none)
533533

534-
def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
534+
def create_embedding(
535+
self, input: Union[str, List[str]], model: Optional[str] = None
536+
) -> Embedding:
535537
"""Embed a string.
536538
537539
Args:
@@ -551,30 +553,40 @@ def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding
551553
if self.verbose:
552554
llama_cpp.llama_reset_timings(self.ctx)
553555

554-
tokens = self.tokenize(input.encode("utf-8"))
555-
self.reset()
556-
self.eval(tokens)
557-
n_tokens = len(tokens)
558-
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
559-
: llama_cpp.llama_n_embd(self.ctx)
560-
]
556+
if isinstance(input, str):
557+
inputs = [input]
558+
else:
559+
inputs = input
561560

562-
if self.verbose:
563-
llama_cpp.llama_print_timings(self.ctx)
561+
data = []
562+
total_tokens = 0
563+
for input in inputs:
564+
tokens = self.tokenize(input.encode("utf-8"))
565+
self.reset()
566+
self.eval(tokens)
567+
n_tokens = len(tokens)
568+
total_tokens += n_tokens
569+
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
570+
: llama_cpp.llama_n_embd(self.ctx)
571+
]
564572

565-
return {
566-
"object": "list",
567-
"data": [
573+
if self.verbose:
574+
llama_cpp.llama_print_timings(self.ctx)
575+
data.append(
568576
{
569577
"object": "embedding",
570578
"embedding": embedding,
571579
"index": 0,
572580
}
573-
],
574-
"model": model_name,
581+
)
582+
583+
return {
584+
"object": "list",
585+
"data": data,
586+
"model": self.model_path,
575587
"usage": {
576-
"prompt_tokens": n_tokens,
577-
"total_tokens": n_tokens,
588+
"prompt_tokens": total_tokens,
589+
"total_tokens": total_tokens,
578590
},
579591
}
580592

llama_cpp/server/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ async def server_sent_events(
275275

276276
class CreateEmbeddingRequest(BaseModel):
277277
model: Optional[str] = model_field
278-
input: str = Field(description="The input to embed.")
278+
input: Union[str, List[str]] = Field(description="The input to embed.")
279279
user: Optional[str]
280280

281281
class Config:

0 commit comments

Comments
 (0)