@@ -71,10 +71,18 @@ def get_api_server_url(self) -> Optional[str]:
7171 """Get the API server URL if available."""
7272 return None
7373
def get_api_key(self) -> str:
    """Return the API key for authenticating against the inference server.

    The base implementation carries no real credential and returns the
    conventional ``"EMPTY"`` placeholder accepted by OpenAI-compatible servers.
    """
    return "EMPTY"
77+
def get_model_config(self) -> InferenceModelConfig:
    """Return the inference model configuration held by this instance."""
    model_config = self.config
    return model_config
7781
def get_model_path(self) -> Optional[str]:
    """Return the configured path to the model weights, or ``None`` if unset."""
    cfg = self.config
    return cfg.model_path
85+
7886
7987def _history_recorder (func ):
8088 """Decorator to record history of the model calls."""
@@ -118,10 +126,11 @@ def __init__(
118126 engine_type .startswith ("vllm" ) or engine_type == "tinker"
119127 ), "Only vLLM and tinker model is supported for now."
120128 self .model = model
129+ self .engine_type = engine_type
121130 self .config : InferenceModelConfig = None # init during prepare
122131 self ._model_name : str = None
123- self ._model_path : str = None
124132 self .api_address : str = None
133+ self ._api_key : str = None
125134 self .openai_client : openai .OpenAI = None
126135 self .openai_async_client : openai .AsyncOpenAI = None
127136 self .logger = get_logger (__name__ )
@@ -138,7 +147,7 @@ async def prepare(self) -> None:
138147 """Prepare the model wrapper."""
139148 self .config = await self .model .get_model_config .remote ()
140149 self ._model_name = self .config .name
141- self ._model_path = self .config . model_path
150+ self ._api_key = await self .model . get_api_key . remote ()
142151 self ._generate_kwargs = {
143152 "temperature" : self .config .temperature ,
144153 "top_p" : self .config .top_p ,
@@ -152,6 +161,8 @@ async def prepare(self) -> None:
152161 if self .api_address is None :
153162 self .logger .info ("API server is not enabled for inference model." )
154163 return
164+ if self .engine_type == "tinker" :
165+ return
155166 max_retries = 30
156167 interval = 2 # seconds
157168 for i in range (max_retries ):
@@ -285,6 +296,11 @@ async def convert_messages_to_experience_async(
285296 messages , tools = tools , temperature = temperature
286297 )
287298
@property
def api_key(self) -> str:
    """API key used when talking to the inference API server."""
    return self._api_key
303+
288304 @property
289305 def model_version (self ) -> int :
290306 """Get the version of the model."""
@@ -297,8 +313,23 @@ async def model_version_async(self) -> int:
297313
@property
def model_path(self) -> str:
    """Resolve the current path to the model files via the remote model actor.

    - For the 'vllm' engine: the ``model_path`` from the configuration
      (``config.model_path``).
    - For the 'tinker' engine: the path of the most recent sampler weights.

    NOTE(review): this performs a blocking ``ray.get`` on every attribute
    access; prefer ``model_path_async`` from async code.
    """
    path_ref = self.model.get_model_path.remote()
    return ray.get(path_ref)
323+
@property
async def model_path_async(self) -> str:
    """Asynchronously resolve the current path to the model files.

    - For the 'vllm' engine: the ``model_path`` from the configuration
      (``config.model_path``).
    - For the 'tinker' engine: the path of the most recent sampler weights.
    """
    resolved_path = await self.model.get_model_path.remote()
    return resolved_path
302333
303334 @property
304335 def model_name (self ) -> Optional [str ]:
@@ -332,16 +363,36 @@ def get_openai_client(self) -> openai.OpenAI:
332363 openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
333364 """
334365 if self .openai_client is not None :
366+ setattr (self .openai_client , "model_path" , self .model_path )
335367 return self .openai_client
336368 if not self .api_address :
337369 raise ValueError (
338370 "API server is not enabled for this model. OpenAI client is unavailable."
339371 )
340372 self .openai_client = openai .OpenAI (
341373 base_url = f"{ self .api_address } /v1" ,
342- api_key = "EMPTY" ,
374+ api_key = self . _api_key ,
343375 )
344- if self .enable_history :
376+ if self .engine_type == "tinker" :
377+ # ! TODO: because tinker's OpenAI API interface is in beta,
378+ # we need to use original API in thinker instead.
379+ def chat_completions (* args , ** kwargs ):
380+ messages = kwargs .pop ("messages" )
381+ chat_response = ray .get (
382+ self .model .chat .remote (
383+ messages = messages ,
384+ with_chat_completion = True ,
385+ return_token_ids = self .enable_history ,
386+ ** kwargs ,
387+ )
388+ )
389+ response = chat_response .pop ()
390+ if self .enable_history :
391+ self .history .extend (chat_response )
392+ return response
393+
394+ self .openai_client .chat .completions .create = chat_completions
395+ elif self .enable_history :
345396 # add a decorator to the openai client to record history
346397
347398 ori_create = self .openai_client .chat .completions .create
@@ -359,7 +410,7 @@ def record_chat_completions(*args, **kwargs):
359410 return response
360411
361412 self .openai_client .chat .completions .create = record_chat_completions
362- setattr (self .openai_client , "model_path" , self .openai_client . models . list (). data [ 0 ]. id )
413+ setattr (self .openai_client , "model_path" , self .model_path )
363414 return self .openai_client
364415
365416 def get_openai_async_client (self ) -> openai .AsyncOpenAI :
@@ -369,6 +420,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
369420 openai.AsyncOpenAI: The async openai client. And `model_path` is added to the client which refers to the model path.
370421 """
371422 if self .openai_async_client is not None :
423+ setattr (self .openai_async_client , "model_path" , self .model_path )
372424 return self .openai_async_client
373425 if not self .api_address :
374426 raise ValueError (
@@ -377,9 +429,27 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
377429 # first make sure that we have the sync openai client
378430 self .openai_async_client = openai .AsyncOpenAI (
379431 base_url = f"{ self .api_address } /v1" ,
380- api_key = "EMPTY" ,
432+ api_key = self . _api_key ,
381433 )
382- if self .enable_history :
434+
435+ if self .engine_type == "tinker" :
436+ # ! TODO: because tinker's OpenAI API interface is in beta,
437+ # we need to use original API in thinker instead.
438+ async def chat_completions (* args , ** kwargs ):
439+ messages = kwargs .pop ("messages" )
440+ chat_response = await self .model .chat .remote (
441+ messages = messages ,
442+ with_chat_completion = True ,
443+ return_token_ids = self .enable_history ,
444+ ** kwargs ,
445+ )
446+ response = chat_response .pop ()
447+ if self .enable_history :
448+ self .history .extend (chat_response )
449+ return response
450+
451+ self .openai_async_client .chat .completions .create = chat_completions
452+ elif self .enable_history :
383453 # add a decorator to the openai client to record history
384454
385455 ori_create = self .openai_async_client .chat .completions .create
@@ -400,8 +470,7 @@ async def record_chat_completions(*args, **kwargs):
400470
401471 self .openai_async_client .chat .completions .create = record_chat_completions
402472 # get model_path from the sync openai client to avoid async call here
403- openai_client = self .get_openai_client ()
404- setattr (self .openai_async_client , "model_path" , openai_client .models .list ().data [0 ].id )
473+ setattr (self .openai_async_client , "model_path" , self .model_path )
405474 return self .openai_async_client
406475
407476 async def get_current_load (self ) -> int :
0 commit comments