Commit d384b8a

Add VLLMChatModel support to chat API (#220)
Parent: 4ccbf41

1 file changed, 31 insertions(+), 0 deletions(-)

src/agentlab/llm/chat_api.py

@@ -143,6 +143,13 @@ def make_model(self):
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
             )
+        elif self.backend == "vllm":
+            return VLLMChatModel(
+                model_name=self.model_name,
+                temperature=self.temperature,
+                max_tokens=self.max_new_tokens,
+                n_retry_server=self.n_retry_server,
+            )
         else:
             raise ValueError(f"Backend {self.backend} is not supported")

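With this branch in place, a model-args object whose backend field is "vllm" gets a VLLMChatModel (added further down in this diff) back from make_model(); any other unrecognized backend still falls through to the ValueError. Below is a minimal sketch of that dispatch. ToyModelArgs is a hypothetical stand-in, not an agentlab class; it only mirrors the fields the new branch reads (backend, model_name, temperature, max_new_tokens, n_retry_server), and the model id is a placeholder.

# Hypothetical sketch only; ToyModelArgs is not part of agentlab.
from dataclasses import dataclass

from agentlab.llm.chat_api import VLLMChatModel  # class added by this commit


@dataclass
class ToyModelArgs:
    model_name: str
    backend: str = "vllm"
    temperature: float = 0.5
    max_new_tokens: int = 100
    n_retry_server: int = 4

    def make_model(self):
        if self.backend == "vllm":
            # Same wiring as the new elif branch above.
            return VLLMChatModel(
                model_name=self.model_name,
                temperature=self.temperature,
                max_tokens=self.max_new_tokens,
                n_retry_server=self.n_retry_server,
            )
        raise ValueError(f"Backend {self.backend} is not supported")


# Assumes VLLM_API_KEY is set and a vLLM server is listening on
# http://0.0.0.0:8000/v1; the model id is a placeholder.
model = ToyModelArgs(model_name="meta-llama/Meta-Llama-3-8B-Instruct").make_model()
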
@@ -423,3 +430,27 @@ def __init__(
 
         client = InferenceClient(model=model_url, token=token)
         self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+
+
+class VLLMChatModel(ChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        n_retry_server=4,
+        min_retry_wait_time=60,
+    ):
+        super().__init__(
+            model_name=model_name,
+            api_key=api_key,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            max_retry=n_retry_server,
+            min_retry_wait_time=min_retry_wait_time,
+            api_key_env_var="VLLM_API_KEY",
+            client_class=OpenAI,
+            client_args={"base_url": "http://0.0.0.0:8000/v1"},
+            pricing_func=None,
+        )

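The subclass itself only changes constructor defaults: it points ChatModel at client_class=OpenAI with a hard-coded base_url and, via api_key_env_var, presumably picks the key up from the VLLM_API_KEY environment variable. The request it ends up issuing should therefore be equivalent to calling a vLLM OpenAI-compatible server directly. A minimal sketch of that raw call, assuming such a server is already running at http://0.0.0.0:8000/v1; the model id and prompt are placeholders.

# Raw OpenAI-compatible call roughly equivalent to what VLLMChatModel wraps.
import os

from openai import OpenAI

client = OpenAI(
    base_url="http://0.0.0.0:8000/v1",                # hard-coded in client_args above
    api_key=os.environ.get("VLLM_API_KEY", "EMPTY"),  # key read from VLLM_API_KEY; "EMPTY" as a fallback placeholder
)

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder: whatever model the server serves
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0.5,  # same defaults as VLLMChatModel
    max_tokens=100,
)
print(response.choices[0].message.content)

Because the base URL is fixed in client_args, the served model has to be reachable on port 8000 of the machine running agentlab.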