Skip to content

Commit 22ffc4f

Browse files
committed
integrate gemini and claude as ai providers to mcp endpoint
1 parent 0df91e8 commit 22ffc4f

File tree

3 files changed

+45
-20
lines changed

3 files changed

+45
-20
lines changed

apps/meshjs-rag/app/api/v1/ask_mesh_ai.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ async def ask_mesh_ai(body: ChatCompletionRequest, credentials: HTTPAuthorizatio
4848
if openai_api_key is None:
4949
raise ValueError("OpenAI api key is missing")
5050

51-
openai_service = OpenAIService(openai_api_key)
51+
openai_service = OpenAIService(
52+
embedding_api_key=openai_api_key,
53+
completion_api_key=openai_api_key,
54+
completion_model="gpt-4o-mini"
55+
)
5256

5357
try:
5458
question = body.messages[-1].content
@@ -71,35 +75,47 @@ async def ask_mesh_ai(body: ChatCompletionRequest, credentials: HTTPAuthorizatio
7175

7276
###########################################################################################################
7377
@router.post("/mcp")
async def mesh_mcp(body: MCPRequestBody, authorization: str = Header(None), supabase: AsyncClient = Depends(get_db_client)):
    """Answer an MCP query using RAG context and a caller-selected model.

    The caller supplies their own completion API key in the Authorization
    header ("Bearer <key>"); embeddings are always produced with the
    server-side OpenAI key from the OPENAI_KEY environment variable.

    Raises:
        HTTPException: 401 when the Authorization header is missing or not
            a Bearer token, 503 on upstream OpenAI API / auth / rate-limit
            errors, 500 on any other failure.
        ValueError: when the server-side OPENAI_KEY env var is unset.
    """
    if not authorization or not authorization.startswith("Bearer"):
        print("error")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="You are not authorized"
        )

    # Embeddings always use the server's own OpenAI key, regardless of which
    # completion provider the caller selects.
    embedding_api_key = os.getenv("OPENAI_KEY") or None
    if embedding_api_key is None:
        raise ValueError("Embedding api key is missing")

    try:
        completion_api_key = authorization.split(" ")[-1]
        question = body.query
        model = body.model

        # Route non-OpenAI models through their OpenAI-compatible endpoints.
        # BUG FIX: base_url was previously assigned only in the gemini/claude
        # branches, so a plain OpenAI model (e.g. "gpt-4o-mini") raised
        # UnboundLocalError below. Default to None so the SDK falls back to
        # the standard OpenAI endpoint.
        base_url = None
        if model.startswith("gemini"):
            base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
        elif model.startswith("claude"):
            base_url = "https://api.anthropic.com/v1/"

        openai_service = OpenAIService(
            embedding_api_key=embedding_api_key,
            completion_api_key=completion_api_key,
            completion_model=model,
            base_url=base_url
        )

        embedded_query = await openai_service.embed_query(question)
        context = await get_context(embedded_query, supabase)
        response = await openai_service.get_mcp_answer(question=question, context=context)
        return response

    except (openai.APIError, openai.AuthenticationError, openai.RateLimitError) as e:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"An OpenAI API error occurred: {e}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred: {e}"
        )

apps/meshjs-rag/app/services/openai.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,18 @@
3636
"""
3737

3838
class OpenAIService:
def __init__(self, embedding_api_key: str, completion_api_key: str, completion_model: str, base_url: str = None):
    """Configure separate async clients for embeddings and completions.

    Args:
        embedding_api_key: OpenAI key used exclusively for embedding calls.
        completion_api_key: Key for the chat-completion provider.
        completion_model: Model name sent with every completion request.
        base_url: OpenAI-compatible endpoint for third-party providers;
            None targets the default OpenAI endpoint.
            NOTE(review): annotation would be more precise as Optional[str] /
            `str | None`, since None is the documented default.
    """
    # Embeddings always go to OpenAI proper — no base_url override.
    self.embedding_client = AsyncOpenAI(api_key=embedding_api_key)
    # Completions may target a third-party OpenAI-compatible API instead.
    self.completion_client = AsyncOpenAI(api_key=completion_api_key, base_url=base_url)
    self.model = completion_model
4146

4247
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
43-
async def _chat(self, messages, model="gpt-4o-mini", temperature=0.0, max_tokens=None, prompt_cache_key=None, stream: bool = False):
48+
async def _chat(self, messages, temperature=0.0, max_tokens=None, prompt_cache_key=None, stream: bool = False):
4449
kwargs = {
45-
"model": model,
50+
"model": self.model,
4651
"messages": messages,
4752
"temperature": temperature,
4853
"stream": stream
@@ -53,7 +58,7 @@ async def _chat(self, messages, model="gpt-4o-mini", temperature=0.0, max_tokens
5358
if max_tokens:
5459
kwargs["max_tokens"] = max_tokens
5560

56-
return await self.client.chat.completions.create(**kwargs)
61+
return await self.completion_client.chat.completions.create(**kwargs)
5762

5863
async def situate_context(self, doc: str, chunk: str, cache_key: str) -> str:
5964
messages = [
@@ -72,7 +77,7 @@ async def situate_context(self, doc: str, chunk: str, cache_key: str) -> str:
7277

7378
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), reraise=True)
7479
async def get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
75-
response = await self.client.embeddings.create(
80+
response = await self.embedding_client.embeddings.create(
7681
model="text-embedding-3-small",
7782
input=texts,
7883
encoding_format="float"
@@ -82,7 +87,7 @@ async def get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
8287

8388
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), reraise=True)
8489
async def embed_query(self, text: str) -> List[float]:
85-
response = await self.client.embeddings.create(
90+
response = await self.embedding_client.embeddings.create(
8691
model="text-embedding-3-small",
8792
input=text,
8893
encoding_format="float"
@@ -91,7 +96,7 @@ async def embed_query(self, text: str) -> List[float]:
9196
return response.data[0].embedding
9297

9398
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), reraise=True)
94-
async def get_answer(self, question: str, context: str, model="gpt-4o-mini"):
99+
async def get_answer(self, question: str, context: str):
95100
messages = [
96101
{
97102
"role": "system",
@@ -103,15 +108,15 @@ async def get_answer(self, question: str, context: str, model="gpt-4o-mini"):
103108
}
104109
]
105110

106-
stream = await self._chat(messages=messages, stream=True, model=model)
111+
stream = await self._chat(messages=messages, stream=True)
107112

108113
async for chunk in stream:
109114
yield f"data: {json.dumps(chunk.model_dump())}\n\n"
110115

111116
yield "data: [DONE]\n\n"
112117

113118
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), reraise=True)
114-
async def get_mcp_answer(self, question: str, context: str, model="gpt-4o-mini"):
119+
async def get_mcp_answer(self, question: str, context: str):
115120
messages = [
116121
{
117122
"role": "system",
@@ -123,5 +128,5 @@ async def get_mcp_answer(self, question: str, context: str, model="gpt-4o-mini")
123128
}
124129
]
125130

126-
response = await self._chat(messages=messages, model=model)
131+
response = await self._chat(messages=messages)
127132
return response.choices[0].message.content

apps/meshjs-rag/app/utils/process_chunks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
if openai_api_key is None:
1414
raise ValueError("OpenAI api key is missing")
1515

16-
openai_service = OpenAIService(openai_api_key=openai_api_key)
16+
openai_service = OpenAIService(
17+
embedding_api_key=openai_api_key,
18+
completion_api_key=openai_api_key,
19+
completion_model="gpt-4o-mini"
20+
)
1721

1822
async def process_chunks_and_update_db(
1923
chunks: List[str],

0 commit comments

Comments
 (0)