Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def get_clean_message_list(
Args:
message_list (`list[ChatMessage | dict]`): List of chat messages. Mixed types are allowed.
role_conversions (`dict[MessageRole, MessageRole]`, *optional* ): Mapping to convert roles.
convert_images_to_image_urls (`bool`, default `False`): Whether to convert images to image URLs.
convert_images_to_image_urls (`bool`, default `False`, *deprecated*): botocore now takes care of converting data to base64
flatten_messages_as_text (`bool`, default `False`): Whether to flatten messages as text.
"""
output_message_list: list[dict[str, Any]] = []
Expand All @@ -302,16 +302,11 @@ def get_clean_message_list(
assert isinstance(element, dict), "Error: this element should be a dict:" + str(element)
if element["type"] == "image":
assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}"
if convert_images_to_image_urls:
element.update(
{
"type": "image_url",
"image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))},
}
)
else:
element["image"] = encode_image_base64(element["image"])

# see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html for the format
element["image"] = {
"format": "png",
"source": {"bytes": get_image_bytes(element["image"])}
}
if len(output_message_list) > 0 and message.role == output_message_list[-1]["role"]:
assert isinstance(message.content, list), "Error: wrong content:" + str(message.content)
if flatten_messages_as_text:
Expand Down
6 changes: 6 additions & 0 deletions src/smolagents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,12 @@ def encode_image_base64(image):
return base64.b64encode(buffered.getvalue()).decode("utf-8")


def get_image_bytes(image):
buffered = BytesIO()
image.save(buffered, format="PNG")
return buffered.getvalue()


def make_image_url(base64_image):
return f"data:image/png;base64,{base64_image}"

Expand Down
6 changes: 3 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,8 +665,8 @@ def test_get_clean_message_list_role_conversions():
dict(
role=MessageRole.USER,
content=[
{"type": "image_url", "image_url": {"url": "_image"}},
{"type": "image_url", "image_url": {"url": "_encoded_image"}},
{"format": "png", "source": {"bytes": "encoded_image"}},
{"format": "png", "source": {"bytes": "second_encoded_image"}},
],
),
),
Expand All @@ -677,7 +677,7 @@ def test_get_clean_message_list_image_encoding(convert_images_to_image_urls, exp
role=MessageRole.USER,
content=[{"type": "image", "image": b"image_data"}, {"type": "image", "image": b"second_image_data"}],
)
with patch("smolagents.models.encode_image_base64") as mock_encode:
with patch("smolagents.models.get_image_bytes") as mock_encode:
mock_encode.side_effect = ["encoded_image", "second_encoded_image"]
result = get_clean_message_list([message], convert_images_to_image_urls=convert_images_to_image_urls)
mock_encode.assert_any_call(b"image_data")
Expand Down