@@ -71,10 +71,18 @@ def get_api_server_url(self) -> Optional[str]:
7171 """Get the API server URL if available."""
7272 return None
7373
def get_api_key(self) -> str:
    """Return the API key.

    This implementation always returns the fixed string ``"EMPTY"``.
    """
    placeholder_key = "EMPTY"
    return placeholder_key
77+
def get_model_config(self) -> InferenceModelConfig:
    """Return the model configuration held in ``self.config``."""
    model_config = self.config
    return model_config
7781
def get_model_path(self) -> Optional[str]:
    """Return the ``model_path`` attribute of ``self.config``."""
    model_config = self.config
    return model_config.model_path
85+
7886
7987def _history_recorder (func ):
8088 """Decorator to record history of the model calls."""
@@ -118,10 +126,11 @@ def __init__(
118126 engine_type .startswith ("vllm" ) or engine_type == "tinker"
119127 ), "Only vLLM and tinker model is supported for now."
120128 self .model = model
129+ self .engine_type = engine_type
121130 self .config : InferenceModelConfig = None # init during prepare
122131 self ._model_name : str = None
123- self ._model_path : str = None
124132 self .api_address : str = None
133+ self ._api_key : str = None
125134 self .openai_client : openai .OpenAI = None
126135 self .openai_async_client : openai .AsyncOpenAI = None
127136 self .logger = get_logger (__name__ )
@@ -138,7 +147,7 @@ async def prepare(self) -> None:
138147 """Prepare the model wrapper."""
139148 self .config = await self .model .get_model_config .remote ()
140149 self ._model_name = self .config .name
141- self ._model_path = self .config . model_path
150+ self ._api_key = await self .model . get_api_key . remote ()
142151 self ._generate_kwargs = {
143152 "temperature" : self .config .temperature ,
144153 "top_p" : self .config .top_p ,
@@ -152,6 +161,8 @@ async def prepare(self) -> None:
152161 if self .api_address is None :
153162 self .logger .info ("API server is not enabled for inference model." )
154163 return
164+ if self .engine_type == "tinker" :
165+ return
155166 max_retries = 30
156167 interval = 2 # seconds
157168 for i in range (max_retries ):
@@ -285,6 +296,11 @@ async def convert_messages_to_experience_async(
285296 messages , tools = tools , temperature = temperature
286297 )
287298
@property
def api_key(self) -> str:
    """Return the API key stored in ``self._api_key``."""
    stored_key = self._api_key
    return stored_key
303+
288304 @property
289305 def model_version (self ) -> int :
290306 """Get the version of the model."""
@@ -298,7 +314,12 @@ async def model_version_async(self) -> int:
@property
def model_path(self) -> str:
    """Return the model path reported by the remote model.

    Each access issues ``get_model_path.remote()`` on ``self.model`` and
    waits for the result with ``ray.get``; nothing is cached locally.
    """
    path_ref = self.model.get_model_path.remote()
    return ray.get(path_ref)
318+
@property
async def model_path_async(self) -> str:
    """Return the model path reported by the remote model.

    Accessing this property produces a coroutine; awaiting it awaits the
    result of ``get_model_path.remote()`` on ``self.model``.
    """
    pending_path = self.model.get_model_path.remote()
    return await pending_path
302323
303324 @property
304325 def model_name (self ) -> Optional [str ]:
@@ -332,16 +353,38 @@ def get_openai_client(self) -> openai.OpenAI:
332353 openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
333354 """
334355 if self .openai_client is not None :
356+ setattr (self .openai_client , "model_path" , self .model_path )
335357 return self .openai_client
336358 if not self .api_address :
337359 raise ValueError (
338360 "API server is not enabled for this model. OpenAI client is unavailable."
339361 )
340362 self .openai_client = openai .OpenAI (
341363 base_url = f"{ self .api_address } /v1" ,
342- api_key = "EMPTY" ,
364+ api_key = self . _api_key ,
343365 )
344- if self .enable_history :
366+ if self .engine_type == "tinker" :
367+ # ! TODO: because tinker's OpenAI API interface is in beta,
368+ # we need to use the original API in tinker instead.
369+ ori_create = self .openai_async_client .chat .completions .create
370+
371+ async def chat_completions (* args , ** kwargs ):
372+ messages = kwargs .pop ("messages" )
373+ chat_response = ray .get (
374+ self .model .chat .remote (
375+ messages = messages ,
376+ with_chat_completion = True ,
377+ return_token_ids = self .enable_history ,
378+ ** kwargs ,
379+ )
380+ )
381+ response = chat_response .pop ()
382+ if self .enable_history :
383+ self .history .extend (chat_response )
384+ return response
385+
386+ self .openai_async_client .chat .completions .create = chat_completions
387+ elif self .enable_history :
345388 # add a decorator to the openai client to record history
346389
347390 ori_create = self .openai_client .chat .completions .create
@@ -359,7 +402,7 @@ def record_chat_completions(*args, **kwargs):
359402 return response
360403
361404 self .openai_client .chat .completions .create = record_chat_completions
362- setattr (self .openai_client , "model_path" , self .openai_client . models . list (). data [ 0 ]. id )
405+ setattr (self .openai_client , "model_path" , self .model_path )
363406 return self .openai_client
364407
365408 def get_openai_async_client (self ) -> openai .AsyncOpenAI :
@@ -369,6 +412,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
369412 openai.AsyncOpenAI: The async openai client. And `model_path` is added to the client which refers to the model path.
370413 """
371414 if self .openai_async_client is not None :
415+ setattr (self .openai_async_client , "model_path" , self .model_path )
372416 return self .openai_async_client
373417 if not self .api_address :
374418 raise ValueError (
@@ -377,9 +421,29 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
377421 # first make sure that we have the sync openai client
378422 self .openai_async_client = openai .AsyncOpenAI (
379423 base_url = f"{ self .api_address } /v1" ,
380- api_key = "EMPTY" ,
424+ api_key = self . _api_key ,
381425 )
382- if self .enable_history :
426+
427+ if self .engine_type == "tinker" :
428+ # ! TODO: because tinker's OpenAI API interface is in beta,
429+ # we need to use the original API in tinker instead.
430+ ori_create = self .openai_async_client .chat .completions .create
431+
432+ async def chat_completions (* args , ** kwargs ):
433+ messages = kwargs .pop ("messages" )
434+ chat_response = await self .model .chat .remote (
435+ messages = messages ,
436+ with_chat_completion = True ,
437+ return_token_ids = self .enable_history ,
438+ ** kwargs ,
439+ )
440+ response = chat_response .pop ()
441+ if self .enable_history :
442+ self .history .extend (chat_response )
443+ return response
444+
445+ self .openai_async_client .chat .completions .create = chat_completions
446+ elif self .enable_history :
383447 # add a decorator to the openai client to record history
384448
385449 ori_create = self .openai_async_client .chat .completions .create
@@ -400,8 +464,8 @@ async def record_chat_completions(*args, **kwargs):
400464
401465 self .openai_async_client .chat .completions .create = record_chat_completions
402466 # get model_path from the sync openai client to avoid async call here
403- openai_client = self .get_openai_client ()
404- setattr (self .openai_async_client , "model_path" , openai_client . models . list (). data [ 0 ]. id )
467+ # openai_client = self.get_openai_client()
468+ setattr (self .openai_async_client , "model_path" , self . model_path )
405469 return self .openai_async_client
406470
407471 async def get_current_load (self ) -> int :
0 commit comments