Skip to content

Commit 86754ed

Browse files
authored
Merge branch 'main' into tlsdc/log_prob
2 parents faac3bf + af0742a commit 86754ed

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

src/agentlab/agents/generic_agent/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
AGENT_4o,
1717
AGENT_4o_MINI,
1818
AGENT_CLAUDE_SONNET_35,
19+
AGENT_CLAUDE_SONNET_35_VISION,
1920
AGENT_4o_VISION,
21+
AGENT_4o_MINI_VISION,
2022
AGENT_o3_MINI,
2123
AGENT_o1_MINI,
2224
)

src/agentlab/llm/chat_api.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,13 @@ def make_model(self):
147147
n_retry_server=self.n_retry_server,
148148
log_probs=self.log_probs,
149149
)
150+
elif self.backend == "vllm":
151+
return VLLMChatModel(
152+
model_name=self.model_name,
153+
temperature=self.temperature,
154+
max_tokens=self.max_new_tokens,
155+
n_retry_server=self.n_retry_server,
156+
)
150157
else:
151158
raise ValueError(f"Backend {self.backend} is not supported")
152159

@@ -440,3 +447,27 @@ def __init__(
440447

441448
client = InferenceClient(model=model_url, token=token)
442449
self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
450+
451+
452+
class VLLMChatModel(ChatModel):
    """Chat model backed by a vLLM server's OpenAI-compatible endpoint.

    Thin wrapper around ChatModel that points the OpenAI client at a
    locally served vLLM instance and maps the generic retry arguments
    onto ChatModel's parameter names.

    Args:
        model_name: Name of the model as registered with the vLLM server.
        api_key: API key for the server; falls back to the VLLM_API_KEY
            environment variable via ChatModel.
        temperature: Sampling temperature passed through to ChatModel.
        max_tokens: Maximum number of tokens to generate per completion.
        n_retry_server: Number of retries on server errors (forwarded as
            ChatModel's ``max_retry``).
        min_retry_wait_time: Minimum seconds to wait between retries.
        base_url: Base URL of the vLLM OpenAI-compatible endpoint.
            Defaults to a vLLM server on the local machine.
    """

    def __init__(
        self,
        model_name,
        api_key=None,
        temperature=0.5,
        max_tokens=100,
        n_retry_server=4,
        min_retry_wait_time=60,
        base_url="http://0.0.0.0:8000/v1",
    ):
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            temperature=temperature,
            max_tokens=max_tokens,
            max_retry=n_retry_server,  # ChatModel names this max_retry
            min_retry_wait_time=min_retry_wait_time,
            api_key_env_var="VLLM_API_KEY",
            client_class=OpenAI,
            # vLLM exposes an OpenAI-compatible REST API; only the base
            # URL differs from the hosted OpenAI service.
            client_args={"base_url": base_url},
            pricing_func=None,  # self-hosted serving: no per-token pricing
        )

0 commit comments

Comments
 (0)