3636"""
3737
class OpenAIService:
    """Async wrapper holding separate OpenAI clients for embeddings and chat completions.

    Embeddings always go through the default OpenAI endpoint, while chat
    completions may be routed to an alternative OpenAI-compatible endpoint
    via ``base_url``.
    """

    def __init__(
        self,
        embedding_api_key: str,
        completion_api_key: str,
        completion_model: str,
        base_url: "str | None" = None,
    ):
        """Create the two API clients and remember the completion model.

        Args:
            embedding_api_key: API key used by the embeddings client.
            completion_api_key: API key used by the chat-completions client.
            completion_model: Model name used for all chat completions.
            base_url: Optional OpenAI-compatible base URL for the completions
                client; ``None`` keeps the library's default endpoint.
                (Annotation fixed: the previous ``base_url: str = None``
                contradicted its ``None`` default.)
        """
        # Two separate clients on purpose: embeddings stay on the default
        # endpoint even when completions point at a custom base_url.
        self.embedding_client = AsyncOpenAI(api_key=embedding_api_key)
        self.completion_client = AsyncOpenAI(
            api_key=completion_api_key,
            base_url=base_url,
        )
        self.model = completion_model
4146
4247 @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ))
43- async def _chat (self , messages , model = "gpt-4o-mini" , temperature = 0.0 , max_tokens = None , prompt_cache_key = None , stream : bool = False ):
48+ async def _chat (self , messages , temperature = 0.0 , max_tokens = None , prompt_cache_key = None , stream : bool = False ):
4449 kwargs = {
45- "model" : model ,
50+ "model" : self . model ,
4651 "messages" : messages ,
4752 "temperature" : temperature ,
4853 "stream" : stream
@@ -53,7 +58,7 @@ async def _chat(self, messages, model="gpt-4o-mini", temperature=0.0, max_tokens
5358 if max_tokens :
5459 kwargs ["max_tokens" ] = max_tokens
5560
56- return await self .client .chat .completions .create (** kwargs )
61+ return await self .completion_client .chat .completions .create (** kwargs )
5762
5863 async def situate_context (self , doc : str , chunk : str , cache_key : str ) -> str :
5964 messages = [
@@ -72,7 +77,7 @@ async def situate_context(self, doc: str, chunk: str, cache_key: str) -> str:
7277
7378 @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ), reraise = True )
7479 async def get_batch_embeddings (self , texts : List [str ]) -> List [List [float ]]:
75- response = await self .client .embeddings .create (
80+ response = await self .embedding_client .embeddings .create (
7681 model = "text-embedding-3-small" ,
7782 input = texts ,
7883 encoding_format = "float"
@@ -82,7 +87,7 @@ async def get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
8287
8388 @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ), reraise = True )
8489 async def embed_query (self , text : str ) -> List [float ]:
85- response = await self .client .embeddings .create (
90+ response = await self .embedding_client .embeddings .create (
8691 model = "text-embedding-3-small" ,
8792 input = text ,
8893 encoding_format = "float"
@@ -91,7 +96,7 @@ async def embed_query(self, text: str) -> List[float]:
9196 return response .data [0 ].embedding
9297
9398 @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ), reraise = True )
94- async def get_answer (self , question : str , context : str , model = "gpt-4o-mini" ):
99+ async def get_answer (self , question : str , context : str ):
95100 messages = [
96101 {
97102 "role" : "system" ,
@@ -103,15 +108,15 @@ async def get_answer(self, question: str, context: str, model="gpt-4o-mini"):
103108 }
104109 ]
105110
106- stream = await self ._chat (messages = messages , stream = True , model = model )
111+ stream = await self ._chat (messages = messages , stream = True )
107112
108113 async for chunk in stream :
109114 yield f"data: { json .dumps (chunk .model_dump ())} \n \n "
110115
111116 yield "data: [DONE]\n \n "
112117
113118 @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ), reraise = True )
114- async def get_mcp_answer (self , question : str , context : str , model = "gpt-4o-mini" ):
119+ async def get_mcp_answer (self , question : str , context : str ):
115120 messages = [
116121 {
117122 "role" : "system" ,
@@ -123,5 +128,5 @@ async def get_mcp_answer(self, question: str, context: str, model="gpt-4o-mini")
123128 }
124129 ]
125130
126- response = await self ._chat (messages = messages , model = model )
131+ response = await self ._chat (messages = messages )
127132 return response .choices [0 ].message .content
0 commit comments