@@ -143,6 +143,13 @@ def make_model(self):
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
             )
+        elif self.backend == "vllm":
+            return VLLMChatModel(
+                model_name=self.model_name,
+                temperature=self.temperature,
+                max_tokens=self.max_new_tokens,
+                n_retry_server=self.n_retry_server,
+            )
         else:
             raise ValueError(f"Backend {self.backend} is not supported")
 
@@ -423,3 +430,27 @@ def __init__(
 
         client = InferenceClient(model=model_url, token=token)
         self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+
+
+class VLLMChatModel(ChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        n_retry_server=4,
+        min_retry_wait_time=60,
+    ):
+        super().__init__(
+            model_name=model_name,
+            api_key=api_key,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            max_retry=n_retry_server,
+            min_retry_wait_time=min_retry_wait_time,
+            api_key_env_var="VLLM_API_KEY",
+            client_class=OpenAI,
+            client_args={"base_url": "http://0.0.0.0:8000/v1"},
+            pricing_func=None,
+        )
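
# Usage sketch (not part of the diff above). It assumes a vLLM server exposing
# the OpenAI-compatible API is already running at the hard-coded base_url
# http://0.0.0.0:8000/v1 (e.g. started with
# `python -m vllm.entrypoints.openai.api_server --model <model-name>`), and
# that VLLM_API_KEY is set in the environment if the server expects a token;
# otherwise an api_key can be passed explicitly. The model name below is only
# an example of something the server could be hosting.
chat_model = VLLMChatModel(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    temperature=0.5,
    max_tokens=100,
)
# Equivalently, with the make_model() dispatch patched in the first hunk,
# setting backend="vllm" on the surrounding config object returns the same
# VLLMChatModel built from that config's model_name, temperature,
# max_new_tokens, and n_retry_server fields.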