Skip to content

Commit d8da9c7

Browse files
committed
feat(beeai-server): switch embeddings proxy to client libs
Signed-off-by: Jan Pokorný <JenomPokorny@gmail.com>
1 parent 53c6856 commit d8da9c7

File tree

5 files changed

+61
-359
lines changed

5 files changed

+61
-359
lines changed

apps/beeai-cli/src/beeai_cli/commands/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ async def setup(
120120
),
121121
Choice(
122122
name="IBM watsonx".ljust(25),
123-
value=("watsonx", None, "ibm/granite-3-3-8b-instruct", "granite-embedding-278m-multilingual"),
123+
value=("watsonx", None, "ibm/granite-3-3-8b-instruct", "ibm/granite-embedding-278m-multilingual"),
124124
),
125125
Choice(name="Jan".ljust(25) + "💻 local", value=("Jan", "http://localhost:1337/v1", None, None)),
126126
Choice(

apps/beeai-server/Dockerfile

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
FROM python:3.13-alpine3.21 AS builder
22
WORKDIR /app
33
COPY --from=ghcr.io/astral-sh/uv:0.6.2 /uv /bin/
4-
# tiktoken builds using rust and cargo
5-
RUN apk add --no-cache rust cargo
64
COPY pyproject.toml dist/requirements.txt ./
75
RUN uv pip install --system -r requirements.txt
86
COPY dist/*.tar.gz ./
97
RUN uv pip install --system ./*.tar.gz
108

119
FROM python:3.13-alpine3.21
1210
WORKDIR /app
13-
# tiktoken requires libgcc
14-
RUN apk add --no-cache libgcc
1511
COPY --from=builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/
1612
COPY --from=builder /usr/local/bin/beeai-server /usr/local/bin/beeai-server
1713
COPY --from=builder /usr/local/bin/migrate /usr/local/bin/migrate

apps/beeai-server/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ dependencies = [
2626
"cachetools>=5.5.2",
2727
"python-multipart>=0.0.20",
2828
"kr8s>=0.20.7",
29-
"beeai-framework~=0.1.29",
3029
"alembic>=1.15.2",
3130
"asyncpg>=0.30.0",
3231
"sqlalchemy[asyncio]>=2.0.41",
@@ -40,6 +39,7 @@ dependencies = [
4039
"sqlparse>=0.5.3",
4140
"pgvector>=0.4.1",
4241
"ibm-watsonx-ai>=1.3.28",
42+
"openai>=1.97.0",
4343
]
4444

4545
[project.scripts]
Lines changed: 54 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,74 @@
11
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import re
4+
from typing import Literal
55

66
import fastapi
7-
from beeai_framework.adapters.openai.backend.embedding import OpenAIEmbeddingModel
8-
from beeai_framework.adapters.watsonx.backend.embedding import WatsonxEmbeddingModel
9-
from beeai_framework.backend.types import EmbeddingModelOutput
10-
from pydantic import BaseModel
7+
import ibm_watsonx_ai
8+
import ibm_watsonx_ai.foundation_models.embeddings
9+
import openai
10+
import openai.types
11+
import pydantic
12+
from fastapi.concurrency import run_in_threadpool
1113

1214
from beeai_server.api.dependencies import EnvServiceDependency
1315

1416
router = fastapi.APIRouter()
1517

16-
17-
class EmbeddingsRequest(BaseModel):
18-
model: str
19-
input: list[str] | str
18+
BEEAI_PROXY_VERSION = 1
2019

2120

22-
class EmbeddingsDataItem(BaseModel):
23-
object: str = "embedding"
24-
index: int
25-
embedding: list[float]
21+
class EmbeddingsRequest(pydantic.BaseModel):
22+
"""
23+
Corresponds to the arguments for OpenAI `client.embeddings.create(...)`.
24+
"""
2625

27-
28-
class EmbeddingsResponse(BaseModel):
29-
object: str = "list"
30-
system_fingerprint: str = "beeai-embeddings-gateway"
3126
model: str
32-
usage: dict[str, int] = {
33-
"prompt_tokens": int,
34-
"total_tokens": int,
35-
"completion_tokens": int,
36-
}
37-
data: list[EmbeddingsDataItem]
27+
input: list[str] | str
28+
encoding_format: Literal["float"]
3829

3930

4031
@router.post("/embeddings")
41-
async def create_embeddings(
42-
env_service: EnvServiceDependency,
43-
request: EmbeddingsRequest,
44-
):
32+
async def create_embedding(env_service: EnvServiceDependency, request: EmbeddingsRequest):
4533
env = await env_service.list_env()
4634

47-
is_rits = re.match(r"^https://[a-z0-9.-]+\.rits\.fmaas\.res\.ibm.com/.*$", env["LLM_API_BASE"])
48-
is_watsonx = re.match(r"^https://[a-z0-9.-]+\.ml\.cloud\.ibm\.com.*?$", env["LLM_API_BASE"])
49-
50-
embeddings = (
51-
WatsonxEmbeddingModel(
52-
model_id=env["EMBEDDING_MODEL"],
53-
api_key=env["LLM_API_KEY"],
54-
base_url=env["LLM_API_BASE"],
55-
project_id=env.get("WATSONX_PROJECT_ID"),
56-
space_id=env.get("WATSONX_SPACE_ID"),
57-
)
58-
if is_watsonx
59-
else OpenAIEmbeddingModel(
60-
env["EMBEDDING_MODEL"],
61-
api_key=env["LLM_API_KEY"],
62-
base_url=env["LLM_API_BASE"],
63-
extra_headers={"RITS_API_KEY": env["LLM_API_KEY"]} if is_rits else {},
35+
if pydantic.HttpUrl(env["LLM_API_BASE"]).host.endswith(".ml.cloud.ibm.com"):
36+
watsonx_response = await run_in_threadpool(
37+
ibm_watsonx_ai.foundation_models.embeddings.Embeddings(
38+
model_id=env["EMBEDDING_MODEL"],
39+
credentials=ibm_watsonx_ai.Credentials(url=env["LLM_API_BASE"], api_key=env["LLM_API_KEY"]),
40+
project_id=env.get("WATSONX_PROJECT_ID"),
41+
space_id=env.get("WATSONX_SPACE_ID"),
42+
).generate,
43+
inputs=[request.input] if isinstance(request.input, str) else request.input,
6444
)
65-
)
66-
67-
output: EmbeddingModelOutput = await embeddings.create(
68-
values=(request.input if isinstance(request.input, list) else [request.input]),
69-
)
70-
71-
return EmbeddingsResponse(
72-
model=request.model,
73-
data=[EmbeddingsDataItem(index=i, embedding=embedding) for i, embedding in enumerate(output.embeddings)],
74-
usage={
75-
"prompt_tokens": output.usage.prompt_tokens,
76-
"completion_tokens": output.usage.completion_tokens,
77-
"total_tokens": output.usage.total_tokens,
78-
},
79-
)
45+
return openai.types.CreateEmbeddingResponse(
46+
object="list",
47+
model=watsonx_response["model_id"],
48+
data=[
49+
openai.types.Embedding(
50+
object="embedding",
51+
index=i,
52+
embedding=result["embedding"],
53+
)
54+
for i, result in enumerate(watsonx_response.get("results", []))
55+
],
56+
usage=openai.types.create_embedding_response.Usage(
57+
prompt_tokens=watsonx_response.get("usage", {}).get("prompt_tokens", 0),
58+
total_tokens=watsonx_response.get("usage", {}).get("total_tokens", 0),
59+
),
60+
).model_dump(mode="json") | {"beeai_proxy_version": BEEAI_PROXY_VERSION}
61+
else:
62+
return (
63+
await openai.AsyncOpenAI(
64+
api_key=env["LLM_API_KEY"],
65+
base_url=env["LLM_API_BASE"],
66+
default_headers=(
67+
{"RITS_API_KEY": env["LLM_API_KEY"]}
68+
if pydantic.HttpUrl(env["LLM_API_BASE"]).host.endswith(".rits.fmaas.res.ibm.com")
69+
else {}
70+
),
71+
).embeddings.create(
72+
**(request.model_dump(mode="json", exclude_none=True) | {"model": env["EMBEDDING_MODEL"]})
73+
)
74+
).model_dump(mode="json") | {"beeai_proxy_version": BEEAI_PROXY_VERSION}

0 commit comments

Comments (0)