Skip to content

Commit 7d8c9c8

Browse files
author
Vishal Patil
committed
Create the new DemoChatBedrock POC.
1 parent ac7ec07 commit 7d8c9c8

File tree

5 files changed

+866
-2
lines changed

5 files changed

+866
-2
lines changed

libs/aws/langchain_aws/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from langchain_aws.chat_model_adapter import BedrockClaudeAdapter, ModelAdapter
12
from langchain_aws.chains import (
23
create_neptune_opencypher_qa_chain,
34
create_neptune_sparql_qa_chain,
45
)
5-
from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse
6+
from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse, DemoChatBedrock
67
from langchain_aws.embeddings import BedrockEmbeddings
78
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
89
from langchain_aws.llms import BedrockLLM, SagemakerEndpoint
@@ -20,6 +21,9 @@
2021
"BedrockLLM",
2122
"ChatBedrock",
2223
"ChatBedrockConverse",
24+
"DemoChatBedrock",
25+
"ModelAdapter",
26+
"BedrockClaudeAdapter",
2327
"SagemakerEndpoint",
2428
"AmazonKendraRetriever",
2529
"AmazonKnowledgeBasesRetriever",
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from langchain_aws.chat_model_adapter.demo_chat_adapter import (
2+
BedrockClaudeAdapter,
3+
ModelAdapter,
4+
)
5+
6+
__all__ = ["ModelAdapter", "BedrockClaudeAdapter"]
Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
from typing import (
2+
Any,
3+
Iterator,
4+
List,
5+
Optional,
6+
Sequence,
7+
Union,
8+
Dict,
9+
Callable,
10+
Literal,
11+
Type,
12+
TypeVar,
13+
Tuple,
14+
cast,
15+
)
16+
17+
from langchain_core.language_models import BaseChatModel, LanguageModelInput
18+
from langchain_core.callbacks import CallbackManagerForLLMRun
19+
from langchain_core.messages import (
20+
BaseMessage,
21+
AIMessage,
22+
AIMessageChunk,
23+
HumanMessage,
24+
SystemMessage,
25+
ToolMessage,
26+
ChatMessage,
27+
)
28+
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
29+
from langchain_core.runnables import Runnable
30+
from langchain_core.tools import BaseTool
31+
from langchain_core.utils.pydantic import TypeBaseModel
32+
from pydantic import BaseModel
33+
34+
from abc import ABC, abstractmethod
35+
import re
36+
37+
# ModelAdapter might also need access to the data that the wrapper ChatModel class has
38+
# for example, the provider or custom inputs passed in by the user
39+
40+
41+
class ModelAdapter(ABC):
    """Abstract base class for model-specific adaptation strategies.

    A concrete adapter translates between LangChain's message/tool
    abstractions and the request/response format of one Bedrock model
    provider.
    """

    @abstractmethod
    def convert_messages_to_payload(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Convert LangChain messages to a model-specific request payload."""
        ...

    @abstractmethod
    def convert_response_to_chat_result(self, response: Any) -> ChatResult:
        """Convert a model-specific response into a LangChain ``ChatResult``."""
        ...

    @abstractmethod
    def convert_stream_response_to_chunks(
        self, response: Any
    ) -> Iterator[ChatGenerationChunk]:
        """Convert a model-specific streaming response into generation chunks."""
        ...

    @abstractmethod
    def format_tools(
        self, tools: Sequence[Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool]]
    ) -> Any:
        """Format tool definitions for the specific model."""
        ...
72+
73+
74+
# Example concrete implementation for a specific model
class BedrockClaudeAdapter(ModelAdapter):
    """Adapter for Anthropic Claude models served through Amazon Bedrock.

    Converts LangChain messages into the legacy Anthropic text-completion
    prompt format (``Human:`` / ``Assistant:`` turns) and provides helpers
    for the Anthropic messages-API format (system string plus a list of
    role/content-block dicts).
    """

    # LangChain message type -> Anthropic role name.
    message_type_lookups = {
        "human": "user",
        "ai": "assistant",
        "AIMessageChunk": "assistant",
        "HumanMessageChunk": "user",
    }

    def convert_messages_to_payload(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Convert LangChain messages to a Claude text-completion prompt.

        Args:
            messages: Conversation history to convert.
            stop: Stop sequences. NOTE(review): accepted for interface
                compatibility but not yet applied — TODO wire into payload.
            **kwargs: Extra model parameters; currently unused as well.

        Returns:
            The combined prompt string (see
            :meth:`convert_messages_to_prompt_anthropic`).
        """
        # Fix: the previous version also assembled a messages-API dict here
        # (max_tokens / stop_sequences) and then silently discarded it; that
        # dead code has been removed. The return annotation was also corrected
        # from Dict[str, Any] to str to match the actual return value.
        return self.convert_messages_to_prompt_anthropic(messages=messages)

    def _convert_message(self, msg: BaseMessage) -> Dict[str, str]:
        """Convert one LangChain message to a messages-API style dict.

        NOTE(review): ``content`` is passed through unchanged; the Anthropic
        messages API expects a list of ``{"type": ..., "text": ...}`` blocks
        for structured content — confirm before sending this payload.
        """
        role_map = {"human": "user", "ai": "assistant", "system": "system"}
        return {
            "role": role_map.get(msg.type, "user"),
            "content": msg.content,
        }

    def convert_response_to_chat_result(self, response: Any) -> ChatResult:
        """Convert a Claude response to a ``ChatResult``.

        TODO: not implemented yet (POC stub — currently returns ``None``).
        """
        pass

    def convert_stream_response_to_chunks(
        self, response: Any
    ) -> Iterator[ChatGenerationChunk]:
        """Convert a Claude streaming response to LangChain chunks.

        TODO: not implemented yet (POC stub — currently returns ``None``).
        """
        pass

    def format_tools(
        self, tools: Sequence[Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool]]
    ) -> Any:
        """Format tools for Claude.

        TODO: not implemented yet (POC stub — currently returns ``None``).
        """
        pass

    def _format_image(self, image_url: str) -> Dict:
        """Parse a ``data:image/...;base64,...`` URL into an Anthropic image source.

        Returns a dict of the form::

            {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": "/9j/4AAQSkZJRg...",
            }

        Raises:
            ValueError: If ``image_url`` is not a base64 data URL.
        """
        regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
        match = re.match(regex, image_url)
        if match is None:
            raise ValueError(
                "Anthropic only supports base64-encoded images currently."
                " Example: data:image/png;base64,'/9j/4AAQSk'..."
            )
        return {
            "type": "base64",
            "media_type": match.group("media_type"),
            "data": match.group("data"),
        }

    def _merge_messages(
        self,
        messages: Sequence[BaseMessage],
    ) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
        """Merge runs of human/tool messages into single human messages with content blocks."""  # noqa: E501
        merged: list = []
        for curr in messages:
            # Deep-copy so the caller's messages are never mutated.
            curr = curr.model_copy(deep=True)
            if isinstance(curr, ToolMessage):
                # Re-express tool output as a human message carrying
                # tool_result content blocks, as the Anthropic API expects.
                if isinstance(curr.content, list) and all(
                    isinstance(block, dict) and block.get("type") == "tool_result"
                    for block in curr.content
                ):
                    curr = HumanMessage(curr.content)  # type: ignore[misc]
                else:
                    curr = HumanMessage(  # type: ignore[misc]
                        [
                            {
                                "type": "tool_result",
                                "content": curr.content,
                                "tool_use_id": curr.tool_call_id,
                            }
                        ]
                    )
            last = merged[-1] if merged else None
            if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
                # Fold consecutive human messages into one content-block list.
                if isinstance(last.content, str):
                    new_content: List = [{"type": "text", "text": last.content}]
                else:
                    new_content = last.content
                if isinstance(curr.content, str):
                    new_content.append({"type": "text", "text": curr.content})
                else:
                    new_content.extend(curr.content)
                last.content = new_content
            else:
                merged.append(curr)
        return merged

    def format_anthropic_messages(
        self,
        messages: List[BaseMessage],
    ) -> Tuple[Optional[str], List[Dict]]:
        """Format messages for anthropic.

        Args:
            messages: Conversation history; an optional system message must
                be first.

        Returns:
            A ``(system, formatted_messages)`` tuple where ``system`` is the
            system prompt string (or ``None``) and ``formatted_messages`` is
            a list of ``{"role": ..., "content": ...}`` dicts.

        Raises:
            ValueError: If a system message is not first, is not a string,
                or a content item has an unsupported type.
        """
        system: Optional[str] = None
        formatted_messages: List[Dict] = []

        merged_messages = self._merge_messages(messages)
        for i, message in enumerate(merged_messages):
            if message.type == "system":
                if i != 0:
                    raise ValueError(
                        "System message must be at beginning of message list."
                    )
                if not isinstance(message.content, str):
                    raise ValueError(
                        "System message must be a string, "
                        f"instead was: {type(message.content)}"
                    )
                system = message.content
                continue

            role = self.message_type_lookups[message.type]
            content: Union[str, List]

            if not isinstance(message.content, str):
                # parse as dict
                assert isinstance(
                    message.content, list
                ), "Anthropic message content must be str or list of dicts"

                # populate content
                content = []
                for item in message.content:
                    if isinstance(item, str):
                        content.append({"type": "text", "text": item})
                    elif isinstance(item, dict):
                        if "type" not in item:
                            raise ValueError("Dict content item must have a type key")
                        elif item["type"] == "image_url":
                            # Convert data-URL images to Anthropic source dicts.
                            source = self._format_image(item["image_url"]["url"])
                            content.append({"type": "image", "source": source})
                        elif item["type"] == "tool_use":
                            # If a tool_call with the same id as a tool_use
                            # content block exists, the tool_call is preferred.
                            if isinstance(message, AIMessage) and item["id"] in [
                                tc["id"] for tc in message.tool_calls
                            ]:
                                # TODO: emit the overlapping tool_calls as
                                # anthropic tool_use blocks; for now such
                                # items are intentionally dropped.
                                pass
                            else:
                                item.pop("text", None)
                                content.append(item)
                        elif item["type"] == "text":
                            text = item.get("text", "")
                            # Only add non-empty strings for now as empty ones are not
                            # accepted.
                            # https://github.com/anthropics/anthropic-sdk-python/issues/461
                            if text.strip():
                                content.append({"type": "text", "text": text})
                        else:
                            content.append(item)
                    else:
                        raise ValueError(
                            f"Content items must be str or dict, instead was: {type(item)}"
                        )
            elif isinstance(message, AIMessage) and message.tool_calls:
                content = (
                    []
                    if not message.content
                    else [{"type": "text", "text": message.content}]
                )
                # Note: Anthropic can't have invalid tool calls as presently defined,
                # since the model already returns dicts args not JSON strings, and invalid
                # tool calls are those with invalid JSON for args.
                # TODO: append anthropic tool_use blocks built from message.tool_calls.
            else:
                content = message.content

            formatted_messages.append({"role": role, "content": content})
        return system, formatted_messages

    def _convert_one_message_to_text_anthropic(
        self,
        message: BaseMessage,
        human_prompt: str,
        ai_prompt: str,
    ) -> str:
        """Render one message as a text-completion turn.

        Raises:
            ValueError: For message types other than Chat/Human/AI/System.
        """
        content = cast(str, message.content)
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AIMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemMessage):
            message_text = content
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def convert_messages_to_prompt_anthropic(
        self,
        messages: List[BaseMessage],
        *,
        human_prompt: str = "\n\nHuman:",
        ai_prompt: str = "\n\nAssistant:",
    ) -> str:
        """Format a list of messages into a full prompt for the Anthropic model.

        Args:
            messages (List[BaseMessage]): List of BaseMessage to combine.
            human_prompt (str, optional): Human prompt tag. Defaults to ``\\n\\nHuman:``.
            ai_prompt (str, optional): AI prompt tag. Defaults to ``\\n\\nAssistant:``.

        Returns:
            str: Combined string with necessary human_prompt and ai_prompt tags.
        """

        messages = messages.copy()  # don't mutate the original list
        # Robustness fix: an empty history previously raised IndexError on
        # messages[-1]; treat it like any history not ending in an AIMessage
        # so the prompt still ends with the assistant tag.
        if not messages or not isinstance(messages[-1], AIMessage):
            messages.append(AIMessage(content=""))

        text = "".join(
            self._convert_one_message_to_text_anthropic(
                message, human_prompt, ai_prompt
            )
            for message in messages
        )

        # trim off the trailing ' ' that might come from the "Assistant: "
        return text.rstrip()

    # Implement other abstract methods similarly...
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from langchain_aws.chat_models.bedrock import ChatBedrock
22
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
3+
from langchain_aws.chat_models.demo_chat import DemoChatBedrock
34

4-
__all__ = ["ChatBedrock", "ChatBedrockConverse"]
5+
__all__ = ["ChatBedrock", "ChatBedrockConverse", "DemoChatBedrock"]

0 commit comments

Comments
 (0)