@@ -147,6 +147,13 @@ def make_model(self):
                 n_retry_server=self.n_retry_server,
                 log_probs=self.log_probs,
             )
+        elif self.backend == "vllm":
+            return VLLMChatModel(
+                model_name=self.model_name,
+                temperature=self.temperature,
+                max_tokens=self.max_new_tokens,
+                n_retry_server=self.n_retry_server,
+            )
         else:
             raise ValueError(f"Backend {self.backend} is not supported")
 
@@ -440,3 +447,27 @@ def __init__(
 
         client = InferenceClient(model=model_url, token=token)
         self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
+
+
+class VLLMChatModel(ChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        n_retry_server=4,
+        min_retry_wait_time=60,
+    ):
+        super().__init__(
+            model_name=model_name,
+            api_key=api_key,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            max_retry=n_retry_server,
+            min_retry_wait_time=min_retry_wait_time,
+            api_key_env_var="VLLM_API_KEY",
+            client_class=OpenAI,
+            client_args={"base_url": "http://0.0.0.0:8000/v1"},
+            pricing_func=None,
+        )
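
At runtime, the new backend amounts to an OpenAI-compatible client pointed at a locally served vLLM endpoint. The sketch below is illustrative only and is not code from this PR: it assumes a vLLM OpenAI-compatible server is already listening on `http://0.0.0.0:8000/v1` (e.g. started with `vllm serve <model>`), and the model name is a placeholder.

```python
from openai import OpenAI

# Illustrative only: roughly what VLLMChatModel configures through ChatModel,
# i.e. an OpenAI client whose base_url points at the local vLLM server.
client = OpenAI(
    api_key="EMPTY",  # vLLM typically accepts any key unless the server was started with an API key
    base_url="http://0.0.0.0:8000/v1",
)

response = client.chat.completions.create(
    model="your-served-model-name",  # placeholder; use the model the vLLM server was launched with
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.5,  # defaults mirrored from VLLMChatModel
    max_tokens=100,
)
print(response.choices[0].message.content)
```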