@@ -143,6 +143,13 @@ def make_model(self):
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
             )
+        elif self.backend == "vllm":
+            return VLLMChatModel(
+                model_name=self.model_name,
+                temperature=self.temperature,
+                max_tokens=self.max_new_tokens,
+                n_retry_server=self.n_retry_server,
+            )
         else:
             raise ValueError(f"Backend {self.backend} is not supported")
 
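A minimal sketch of how this new branch would be selected, assuming the enclosing config class (called ChatModelArgs here purely for illustration; only make_model() and the field names appear in the diff) is instantiated directly:

args = ChatModelArgs(
    backend="vllm",
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # example served model id
    temperature=0.2,
    max_new_tokens=256,
    n_retry_server=4,
)
chat_model = args.make_model()  # falls through to the new elif branch and returns a VLLMChatModel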
@@ -423,3 +430,27 @@ def __init__(
 
         client = InferenceClient(model=model_url, token=token)
         self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+
+
+class VLLMChatModel(ChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        n_retry_server=4,
+        min_retry_wait_time=60,
+    ):
+        super().__init__(
+            model_name=model_name,
+            api_key=api_key,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            max_retry=n_retry_server,
+            min_retry_wait_time=min_retry_wait_time,
+            api_key_env_var="VLLM_API_KEY",
+            client_class=OpenAI,
+            client_args={"base_url": "http://0.0.0.0:8000/v1"},
+            pricing_func=None,
+        )
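
A usage sketch of the new class, assuming a vLLM OpenAI-compatible server is already listening at the hardcoded base_url (for example started with `vllm serve <model>` or `python -m vllm.entrypoints.openai.api_server --model <model>`) and that ChatModel instances are callable on an OpenAI-style message list, as with the other backends:

import os

# vLLM ignores the bearer token unless the server was started with --api-key,
# but the base class is configured with api_key_env_var="VLLM_API_KEY", so a
# placeholder keeps the client happy when no key is required.
os.environ.setdefault("VLLM_API_KEY", "EMPTY")

llm = VLLMChatModel(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # must match the model the server is serving
    temperature=0.2,
    max_tokens=256,
)
messages = [{"role": "user", "content": "Summarize the page in one sentence."}]
print(llm(messages))

One design note: client_args hardcodes base_url="http://0.0.0.0:8000/v1", so the server must be local (or port-forwarded to that address) unless that value is made configurable.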