|
1 | 1 | # Copyright 2025 © BeeAI a Series of LF Projects, LLC |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | | -import re |
| 4 | +from typing import Literal |
5 | 5 |
|
6 | 6 | import fastapi |
7 | | -from beeai_framework.adapters.openai.backend.embedding import OpenAIEmbeddingModel |
8 | | -from beeai_framework.adapters.watsonx.backend.embedding import WatsonxEmbeddingModel |
9 | | -from beeai_framework.backend.types import EmbeddingModelOutput |
10 | | -from pydantic import BaseModel |
| 7 | +import ibm_watsonx_ai |
| 8 | +import ibm_watsonx_ai.foundation_models.embeddings |
| 9 | +import openai |
| 10 | +import openai.types |
| 11 | +import pydantic |
| 12 | +from fastapi.concurrency import run_in_threadpool |
11 | 13 |
|
12 | 14 | from beeai_server.api.dependencies import EnvServiceDependency |
13 | 15 |
|
# FastAPI router carrying the embeddings proxy endpoint registered below.
router = fastapi.APIRouter()
|
# Version marker merged into every proxied response body
# (as "beeai_proxy_version") so clients can detect the proxy revision.
BEEAI_PROXY_VERSION = 1
|
class EmbeddingsRequest(pydantic.BaseModel):
    """
    Corresponds to the arguments for OpenAI `client.embeddings.create(...)`.
    """

    # Model requested by the client; the proxy overrides it with the
    # configured EMBEDDING_MODEL when calling the backend.
    model: str
    # A single string or a batch of strings to embed.
    input: list[str] | str
    # Only float vectors are supported by this proxy. The field is optional
    # in the OpenAI API, so default it to "float" rather than rejecting
    # requests that omit it; requests asking for "base64" still fail
    # validation (422) on the Literal.
    encoding_format: Literal["float"] = "float"
|
39 | 30 |
|
@router.post("/embeddings")
async def create_embedding(env_service: EnvServiceDependency, request: EmbeddingsRequest):
    """
    OpenAI-compatible `/embeddings` endpoint.

    Proxies to watsonx.ai when LLM_API_BASE points at `*.ml.cloud.ibm.com`,
    otherwise to any OpenAI-compatible server (adding the RITS auth header for
    `*.rits.fmaas.res.ibm.com` hosts). Returns the OpenAI
    `CreateEmbeddingResponse` shape as plain JSON, with a
    `beeai_proxy_version` marker merged in.
    """
    env = await env_service.list_env()

    if pydantic.HttpUrl(env["LLM_API_BASE"]).host.endswith(".ml.cloud.ibm.com"):
        # watsonx.ai does not expose an OpenAI-compatible embeddings endpoint,
        # so call its SDK (synchronous, hence run_in_threadpool to keep the
        # event loop free) and translate the result into the OpenAI schema.
        watsonx_response = await run_in_threadpool(
            ibm_watsonx_ai.foundation_models.embeddings.Embeddings(
                model_id=env["EMBEDDING_MODEL"],
                credentials=ibm_watsonx_ai.Credentials(url=env["LLM_API_BASE"], api_key=env["LLM_API_KEY"]),
                project_id=env.get("WATSONX_PROJECT_ID"),
                space_id=env.get("WATSONX_SPACE_ID"),
            ).generate,
            inputs=[request.input] if isinstance(request.input, str) else request.input,
        )
        # BUG FIX: the watsonx.ai embeddings API reports token usage as a
        # top-level `input_token_count` field, not a nested `usage` object,
        # so the previous `usage`-based lookup always produced 0. Prefer
        # `input_token_count`; keep the old lookup as a defensive fallback.
        # Embeddings have no completion tokens, so total == prompt.
        token_count = (
            watsonx_response.get("input_token_count")
            or watsonx_response.get("usage", {}).get("prompt_tokens", 0)
        )
        return openai.types.CreateEmbeddingResponse(
            object="list",
            model=watsonx_response["model_id"],
            data=[
                openai.types.Embedding(object="embedding", index=i, embedding=result["embedding"])
                for i, result in enumerate(watsonx_response.get("results", []))
            ],
            usage=openai.types.create_embedding_response.Usage(
                prompt_tokens=token_count,
                total_tokens=token_count,
            ),
        ).model_dump(mode="json") | {"beeai_proxy_version": BEEAI_PROXY_VERSION}
    else:
        # Any OpenAI-compatible backend: forward the request unchanged except
        # for substituting the configured EMBEDDING_MODEL. RITS deployments
        # additionally require the API key in a RITS_API_KEY header.
        openai_response = await openai.AsyncOpenAI(
            api_key=env["LLM_API_KEY"],
            base_url=env["LLM_API_BASE"],
            default_headers=(
                {"RITS_API_KEY": env["LLM_API_KEY"]}
                if pydantic.HttpUrl(env["LLM_API_BASE"]).host.endswith(".rits.fmaas.res.ibm.com")
                else {}
            ),
        ).embeddings.create(
            **(request.model_dump(mode="json", exclude_none=True) | {"model": env["EMBEDDING_MODEL"]})
        )
        return openai_response.model_dump(mode="json") | {"beeai_proxy_version": BEEAI_PROXY_VERSION}
0 commit comments