
Commit 6604dbc

added the MessageBuilder class, which should help when interfacing with APIs
1 parent c11db49 commit 6604dbc

File tree

1 file changed: +154, -5 lines changed


src/agentlab/llm/response_api.py

Lines changed: 154 additions & 5 deletions
@@ -1,11 +1,103 @@
 import logging
 from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union

 import openai
 from openai import OpenAI

 from .base_api import AbstractChatModel, BaseModelArgs

+type ContentItem = Dict[str, Any]
+type Message = Dict[str, Union[str, List[ContentItem]]]
+
+
+class MessageBuilder:
+    def __init__(self, role: str):
+        self.role = role
+        self.content: List[ContentItem] = []
+        self.tool_call_id = None
+
+    @staticmethod
+    def system() -> "MessageBuilder":
+        return MessageBuilder(role="system")
+
+    @staticmethod
+    def user() -> "MessageBuilder":
+        return MessageBuilder(role="user")
+
+    @staticmethod
+    def assistant() -> "MessageBuilder":
+        return MessageBuilder(role="assistant")
+
+    @staticmethod
+    def tool() -> "MessageBuilder":
+        return MessageBuilder(role="tool")
+
+    def add_text(self, text: str) -> "MessageBuilder":
+        self.content.append({"text": text})
+        return self
+
+    def add_image(self, image: str) -> "MessageBuilder":
+        self.content.append({"image": image})
+        return self
+
+    def add_tool_id(self, tool_id: str) -> "MessageBuilder":
+        self.tool_call_id = tool_id
+        return self
+
+    def to_openai(self) -> List[Message]:
+        content = []
+        for item in self.content:
+            if "text" in item:
+                content.append({"type": "input_text", "text": item["text"]})
+            elif "image" in item:
+                content.append({"type": "input_image", "image_url": item["image"]})
+        res = [{"role": self.role, "content": content}]
+
+        if self.role == "tool":
+            assert self.tool_call_id is not None, "Tool call ID is required for tool messages"
+            # tool messages can only take text with openai
+            # we need to split the first content element if it's text and use it
+            # then open a new (user) message with the rest
+            res[0]["tool_call_id"] = self.tool_call_id
+            text_content = (
+                content.pop(0)["text"]
+                if "text" in content[0]
+                else "Tool call answer in next message"
+            )
+            res[0]["content"] = text_content
+            res.append({"role": "user", "content": content})
+
+        return res
+
+    def to_anthropic(self) -> List[Message]:
+        content = []
+        for item in self.content:
+            if "text" in item:
+                content.append({"type": "text", "text": item["text"]})
+            elif "image" in item:
+                content.append(
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",  # currently only base64 is supported
+                            "media_type": "image/png",  # currently only png is supported
+                            "data": item["image"],
+                        },
+                    }
+                )
+        res = [{"role": self.role, "content": content}]
+
+        if self.role == "tool":
+            assert self.tool_call_id is not None, "Tool call ID is required for tool messages"
+            res[0]["role"] = "user"
+            res[0]["content"] = {
+                "type": "tool_result",
+                "tool_use_id": self.tool_call_id,
+                "content": res[0]["content"],
+            }
+        return res
+

 class ResponseModel(AbstractChatModel):
     def __init__(
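
To make the conversions concrete, here is a minimal usage sketch based only on the methods added in this hunk; the base64 string and tool call id are placeholders, and the commented-out dict shapes follow the to_openai conversion above:

    # Build a user message with text plus a screenshot, then convert it to
    # the shape each API expects.
    msg = MessageBuilder.user().add_text("Describe this page").add_image("<base64-png>")

    openai_input = msg.to_openai()
    # -> [{"role": "user", "content": [
    #        {"type": "input_text", "text": "Describe this page"},
    #        {"type": "input_image", "image_url": "<base64-png>"}]}]

    anthropic_messages = msg.to_anthropic()
    # -> the same content rendered as "text" / "image" blocks with a base64 PNG source

    # Tool results carry the tool call id; to_openai() keeps the leading text on the
    # tool message and moves any remaining content into a follow-up user message,
    # while to_anthropic() wraps everything in a single tool_result block.
    tool_msg = MessageBuilder.tool().add_tool_id("call_123").add_text("clicked the button")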
@@ -29,15 +121,15 @@ def __call__(self, content: dict, temperature: float = None) -> dict:
             response = self.client.responses.create(
                 model=self.model_name,
                 input=content,
-                # temperature=temperature,
+                temperature=temperature,
                 # previous_response_id=content.get("previous_response_id", None),
                 max_output_tokens=self.max_tokens,
                 **self.extra_kwargs,
                 tool_choice="required",
-                reasoning={
-                    "effort": "low",
-                    "summary": "detailed",
-                },
+                # reasoning={
+                #     "effort": "low",
+                #     "summary": "detailed",
+                # },
             )
             return response
         except openai.OpenAIError as e:
@@ -165,3 +257,60 @@ def make_model(self, extra_kwargs=None):
             max_tokens=self.max_new_tokens,
             extra_kwargs=extra_kwargs,
         )
+
+
+import anthropic
+
+
+class ClaudeResponseModel(ResponseModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        extra_kwargs=None,
+    ):
+        super().__init__(
+            model_name=model_name,
+            api_key=api_key,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            extra_kwargs=extra_kwargs,
+        )
+        self.client = anthropic.Client(api_key=api_key)
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.extra_kwargs = extra_kwargs or {}
+        self.api_key = api_key
+
+    def __call__(self, messages: list[dict], temperature: float = None) -> dict:
+        temperature = temperature if temperature is not None else self.temperature
+        try:
+            response = self.client.messages.create(
+                model=self.model_name,
+                messages=messages,
+                temperature=temperature,
+                max_tokens=self.max_tokens,
+                **self.extra_kwargs,
+            )
+            return response
+        except Exception as e:
+            logging.error(f"Failed to get a response from the API: {e}")
+            raise e
+
+
+@dataclass
+class ClaudeResponseModelArgs(BaseModelArgs):
+    """Serializable object for instantiating a generic chat model with an
+    Anthropic model."""
+
+    def make_model(self, extra_kwargs=None):
+        return ClaudeResponseModel(
+            model_name=self.model_name,
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            extra_kwargs=extra_kwargs,
+        )
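
For context, a hedged sketch of how the new pieces could be wired together; the field names passed to ClaudeResponseModelArgs (model_name, temperature, max_new_tokens) are assumed to be defined on BaseModelArgs, as implied by make_model above, and the model id is illustrative:

    # Instantiate the Claude-backed model from its serializable args
    # (assumes ANTHROPIC_API_KEY is set, since make_model passes no api_key).
    args = ClaudeResponseModelArgs(
        model_name="claude-3-7-sonnet-latest",  # illustrative model id
        temperature=0.5,
        max_new_tokens=256,
    )
    llm = args.make_model()

    # MessageBuilder.to_anthropic() returns a list of message dicts, so flatten
    # the converted output before handing it to the model.
    builders = [MessageBuilder.user().add_text("Click the login button.")]
    messages = [m for b in builders for m in b.to_anthropic()]
    response = llm(messages)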
